diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 8f52b09..c6a5562 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -38,7 +38,7 @@ public class Main { Layer layer = new Layer(neurons); Network network = new Network(List.of(layer)); - Trainer trainer = new AdalineTraining(); + Trainer trainer = new GradientDescentTraining(); trainer.train(network, dataset); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index fc8aa1a..5e8865d 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -81,25 +81,7 @@ public class TrainingPipeline { System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0); } - AtomicInteger neuronIndex = new AtomicInteger(0); - ctx.model.forEachNeuron(n -> { - List weights = new ArrayList<>(); - n.forEachSynapse(syn -> weights.add(syn.getWeight())); - - float b = weights.get(0); - float w1 = weights.get(1); - float w2 = weights.get(2); - - this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3); - }); - int i = 0; - for(DataSetEntry entry : ctx.dataset){ - List inputs = entry.getData(); - this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue()); - this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F); - i++; - } - this.visualizer.build(); + if(this.visualization) this.visualize(ctx); } private void executeSteps(TrainingContext ctx){ @@ -123,4 +105,26 @@ public class TrainingPipeline { ctx.epoch += 1; } + private void visualize(TrainingContext ctx){ + AtomicInteger neuronIndex = new AtomicInteger(0); + ctx.model.forEachNeuron(n -> { + List weights = new ArrayList<>(); + n.forEachSynapse(syn -> weights.add(syn.getWeight())); + + float b = weights.get(0); + float w1 = weights.get(1); + float w2 = weights.get(2); + + this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3); + }); + int i = 0; + for(DataSetEntry entry : ctx.dataset){ + List inputs = entry.getData(); + this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue()); + this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F); + i++; + } + this.visualizer.build(); + } + } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 2d98ccb..d40ab28 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -12,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; import com.naaturel.ANN.implementation.training.steps.*; +import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import java.util.ArrayList; import java.util.List; @@ -27,7 +28,7 @@ public class GradientDescentTraining implements Trainer { GradientDescentTrainingContext context = new GradientDescentTrainingContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.0011F; + context.learningRate = 0.0005F; context.correctorTerms = new ArrayList<>(); List steps = List.of( @@ -38,7 +39,7 @@ public class GradientDescentTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 500) + .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 50000) .beforeEpoch(ctx -> { GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx; gdCtx.globalLoss = 0.0F; @@ -49,8 +50,9 @@ public class GradientDescentTraining implements Trainer { context.globalLoss /= context.dataset.size(); new GradientDescentCorrectionStrategy(context).apply(); }) - .withVerbose(true) + //.withVerbose(true) .withTimeMeasurement(true) + .withVisualization(false, new GraphVisualizer()) .run(context); }