diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 1871013..348e93a 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -2,6 +2,7 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; +import com.naaturel.ANN.implementation.gradientDescent.Linear; import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; @@ -18,12 +19,12 @@ public class Main { public static void main(String[] args){ - int nbrClass = 3; + int nbrClass = 1; DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass); - int[] neuronPerLayer = new int[]{27, dataset.getNbrLabels()}; + int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()}; int nbrInput = dataset.getNbrInputs(); List layers = new ArrayList<>(); @@ -51,7 +52,7 @@ public class Main { FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0])); Trainer trainer = new GradientBackpropagationTraining(); - trainer.train(0.0001F, 15000, network, dataset); + trainer.train(0.0005F, 15000, network, dataset); GraphVisualizer visualizer = new GraphVisualizer(); @@ -60,13 +61,13 @@ public class Main { visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); } - float min = -5F; - float max = 5F; - float step = 0.025F; + float min = 0F; + float max = 15F; + float step = 0.03F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); - float predSeries = prediction > 0.5F ? 1 : 0; + float predSeries = prediction > 0.5F ? 1 : -1; visualizer.addPoint(Float.toString(predSeries), x, y); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java index 25ec57d..ed999d9 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java @@ -4,7 +4,7 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep; public class BackpropagationCorrectionStep implements AlgorithmStep { - private GradientBackpropagationContext context; + private GradientBackpropagationContext context; public BackpropagationCorrectionStep(GradientBackpropagationContext context){ this.context = context; diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index 3c443b1..a6608c1 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -30,12 +30,12 @@ public class GradientBackpropagationTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch) + .stopCondition(ctx -> ctx.globalLoss <= 0.001F || ctx.epoch > epoch) .beforeEpoch(ctx -> { ctx.globalLoss = 0.0F; }) .afterEpoch(ctx -> ctx.globalLoss /= dataset.size()) - .withVerbose(true, epoch/10) + .withVerbose(true,epoch/10) .withTimeMeasurement(true) .run(context); } diff --git a/src/main/resources/assets/table_2_9.csv b/src/main/resources/assets/table_2_9.csv index 3bfd684..ed8c03f 100644 --- a/src/main/resources/assets/table_2_9.csv +++ b/src/main/resources/assets/table_2_9.csv @@ -19,3 +19,10 @@ 4,7,-1 4,9,1 4,10,1 +2,6,-1 +7,7,-1 +5,9,1 +9,10,1 +7,1,-1 +5,0,1 +9,5,1 \ No newline at end of file