diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 58de78f..1871013 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -23,7 +23,7 @@ public class Main { DataSet dataset = new DatasetExtractor() .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass); - int[] neuronPerLayer = new int[]{3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 37, dataset.getNbrLabels()}; + int[] neuronPerLayer = new int[]{27, dataset.getNbrLabels()}; int nbrInput = dataset.getNbrInputs(); List layers = new ArrayList<>(); @@ -51,7 +51,7 @@ public class Main { FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0])); Trainer trainer = new GradientBackpropagationTraining(); - trainer.train(0.01F, 5000, network, dataset); + trainer.train(0.0001F, 15000, network, dataset); GraphVisualizer visualizer = new GraphVisualizer(); @@ -60,9 +60,9 @@ public class Main { visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); } - float min = -3F; - float max = 2F; - float step = 0.01F; + float min = -5F; + float max = 5F; + float step = 0.025F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); @@ -71,8 +71,7 @@ public class Main { } } - - visualizer.buildScatterGraph(); + visualizer.buildScatterGraph((int)min-1, (int)max+1); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index 2383b95..3c443b1 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer { ctx.globalLoss = 0.0F; }) .afterEpoch(ctx -> ctx.globalLoss /= dataset.size()) - .withVerbose(false, epoch/10) + .withVerbose(true, epoch/10) .withTimeMeasurement(true) .run(context); } diff --git a/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java b/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java index cec7d14..9314764 100644 --- a/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java @@ -42,13 +42,13 @@ public class GraphVisualizer { } - public void buildScatterGraph(){ + public void buildScatterGraph(int lower, int upper){ JFreeChart chart = ChartFactory.createScatterPlot( "Predictions", "X", "Y", dataset ); XYPlot plot = chart.getXYPlot(); - plot.getDomainAxis().setRange(-2, 2); - plot.getRangeAxis().setRange(-2, 2); + plot.getDomainAxis().setRange(lower, upper); + plot.getRangeAxis().setRange(lower, upper); JFrame frame = new JFrame("Predictions"); frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);