Just a regular commit
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
@@ -11,6 +12,7 @@ import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.activation.Heaviside;
|
||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import java.util.*;
|
||||
@@ -67,25 +69,8 @@ public class Main {
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
Network network = new Network(List.of(layer));
|
||||
|
||||
TrainingContext context = new TrainingContext();
|
||||
context.dataset = orDataSet;
|
||||
context.model = network;
|
||||
context.learningRate = 0.3F;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(),
|
||||
new DeltaStep(),
|
||||
new SimpleLossStep(),
|
||||
new SimpleErrorDetectionStep(),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy())
|
||||
);
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
pipeline
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||
.withVerbose(true)
|
||||
.run(context);
|
||||
Trainer trainer = new SimpleTraining();
|
||||
trainer.train(network, orDataSet);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user