From ada01d350b27a4d56d9d9f05f40c880cc5accbf3 Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 30 Mar 2026 18:28:21 +0200 Subject: [PATCH] Change signature of train method --- src/main/java/com/naaturel/ANN/Main.java | 37 +++++++++++++++---- .../ANN/domain/abstraction/Trainer.java | 2 +- .../ANN/domain/model/neuron/Neuron.java | 7 ++-- .../model/training/TrainingPipeline.java | 2 +- .../training/AdalineTraining.java | 4 +- .../GradientBackpropagationTraining.java | 11 +++--- .../training/GradientDescentTraining.java | 4 +- .../training/SimpleTraining.java | 4 +- .../ANN/infrastructure/dataset/DataSet.java | 31 +++++++++++++++- .../infrastructure/graph/GraphVisualizer.java | 20 +++++++++- src/main/resources/assets/xor.csv | 4 ++ 11 files changed, 100 insertions(+), 26 deletions(-) create mode 100644 src/main/resources/assets/xor.csv diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index a285ce4..b345dd1 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -6,8 +6,10 @@ import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.infrastructure.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; +import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import java.util.*; @@ -15,13 +17,13 @@ public class Main { public static void main(String[] args){ - int nbrInput = 25; - int nbrClass = 4; - - int[] neuronPerLayer = new int[]{10, nbrClass}; + int nbrClass = 1; DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv", nbrClass); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass); + + int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()}; + int nbrInput = dataset.getNbrInputs(); List layers = new ArrayList<>(); for (int i = 0; i < neuronPerLayer.length; i++){ @@ -38,7 +40,7 @@ public class Main { Bias bias = new Bias(new Weight()); - Neuron n = new Neuron(syns, bias, new TanH()); + Neuron n = new Neuron(syns, bias, new Sigmoid(2)); neurons.add(n); } Layer layer = new Layer(neurons); @@ -48,7 +50,28 @@ public class Main { FullyConnectedNetwork network = new FullyConnectedNetwork(layers); Trainer trainer = new GradientBackpropagationTraining(); - trainer.train(network, dataset); + trainer.train(0.5F, network, dataset); + /*GraphVisualizer visualizer = new GraphVisualizer(); + + for (DataSetEntry entry : dataset) { + List label = dataset.getLabelsAsFloat(entry); + visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); + } + + float min = -2F; + float max = 2F; + float step = 0.01F; + for (float x = min; x < max; x+=step){ + for (float y = min; y < max; y+=step){ + float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); + float predSeries = prediction > 0.5F ? 1 : 0; + visualizer.addPoint(Float.toString(predSeries), x, y); + } + } + + + visualizer.buildScatterGraph();*/ } + } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java index 867341d..80b321e 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java @@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction; import com.naaturel.ANN.infrastructure.dataset.DataSet; public interface Trainer { - void train(Model model, DataSet dataset); + void train(float learningRate, Model model, DataSet dataset); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java index ea6cb97..cd4272f 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -53,12 +53,11 @@ public class Neuron implements Model { } public float calculateWeightedSum() { - float res = 0; - res += this.bias.getWeight() * this.bias.getInput(); + this.weightedSum = 0F; + this.weightedSum += this.bias.getWeight() * this.bias.getInput(); for(Synapse syn : this.synapses){ - res += syn.getWeight() * syn.getInput(); + this.weightedSum += syn.getWeight() * syn.getInput(); } - this.weightedSum = res; return this.weightedSum; } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index de70dd6..afbecff 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -124,7 +124,7 @@ public class TrainingPipeline { this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F); i++; } - this.visualizer.build(); + this.visualizer.buildLineGraph(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index b641df8..b45c364 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -23,11 +23,11 @@ public class AdalineTraining implements Trainer { } @Override - public void train(Model model, DataSet dataset) { + public void train(float learningRate, Model model, DataSet dataset) { AdalineTrainingContext context = new AdalineTrainingContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.003F; + context.learningRate = learningRate; List steps = List.of( new SimplePredictionStep(context), diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index 0c3df4f..54a9ff5 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -13,14 +13,13 @@ import com.naaturel.ANN.infrastructure.dataset.DataSet; import java.util.List; - public class GradientBackpropagationTraining implements Trainer { @Override - public void train(Model model, DataSet dataset) { + public void train(float learningRate, Model model, DataSet dataset) { GradientBackpropagationContext context = new GradientBackpropagationContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.1F; + context.learningRate = learningRate; List steps = List.of( new SimplePredictionStep(context), @@ -30,8 +29,10 @@ public class GradientBackpropagationTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.epoch == 250) - .withVerbose(true) + .beforeEpoch(ctx -> ctx.globalLoss = 0.0F) + .afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size()) + .stopCondition(ctx -> ctx.epoch > 1000000) + .withVerbose(false) .withTimeMeasurement(true) .run(context); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 1fc558b..3a849bc 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -23,11 +23,11 @@ public class GradientDescentTraining implements Trainer { } @Override - public void train(Model model, DataSet dataset) { + public void train(float learningRate, Model model, DataSet dataset) { GradientDescentTrainingContext context = new GradientDescentTrainingContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.0008F; + context.learningRate = learningRate; context.correctorTerms = new ArrayList<>(); List steps = List.of( diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 6e73eef..b2cb32a 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -16,11 +16,11 @@ public class SimpleTraining implements Trainer { } @Override - public void train(Model model, DataSet dataset) { + public void train(float learningRate, Model model, DataSet dataset) { SimpleTrainingContext context = new SimpleTrainingContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.3F; + context.learningRate = learningRate; List steps = List.of( new SimplePredictionStep(context), diff --git a/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java index 7fe085d..8573314 100644 --- a/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java @@ -3,23 +3,52 @@ package com.naaturel.ANN.infrastructure.dataset; import com.naaturel.ANN.domain.model.neuron.Input; import java.util.*; +import java.util.stream.Stream; public class DataSet implements Iterable{ private final Map data; + private final int nbrInputs; + private final int nbrLabels; + public DataSet() { - this(new LinkedHashMap<>()); + this(new LinkedHashMap<>()); //ensure iteration order is the same as insertion order } public DataSet(Map data){ this.data = data; + this.nbrInputs = this.calculateNbrInput(); + this.nbrLabels = this.calculateNbrLabel(); } + private int calculateNbrInput(){ + //assumes every entry are the same length + Stream keyStream = this.data.keySet().stream(); + Optional firstEntry = keyStream.findFirst(); + return firstEntry.map(inputs -> inputs.getData().size()).orElse(0); + } + + private int calculateNbrLabel(){ + //assumes every label are the same length + Stream keyStream = this.data.keySet().stream(); + Optional firstEntry = keyStream.findFirst(); + return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0); + } + + public int size() { return data.size(); } + public int getNbrInputs() { + return this.nbrInputs; + } + + public int getNbrLabels(){ + return this.nbrLabels; + } + public List getData(){ return new ArrayList<>(this.data.keySet()); } diff --git a/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java b/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java index 67d2feb..cec7d14 100644 --- a/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java @@ -3,6 +3,7 @@ package com.naaturel.ANN.infrastructure.graph; import org.jfree.chart.ChartFactory; import org.jfree.chart.ChartPanel; import org.jfree.chart.JFreeChart; +import org.jfree.chart.plot.XYPlot; import org.jfree.data.xy.XYSeries; import org.jfree.data.xy.XYSeriesCollection; @@ -29,7 +30,7 @@ public class GraphVisualizer { } } - public void build(){ + public void buildLineGraph(){ JFreeChart chart = ChartFactory.createXYLineChart( "Model learning", "X", "Y", dataset ); @@ -39,4 +40,21 @@ public class GraphVisualizer { frame.pack(); frame.setVisible(true); } + + + public void buildScatterGraph(){ + JFreeChart chart = ChartFactory.createScatterPlot( + "Predictions", "X", "Y", dataset + ); + XYPlot plot = chart.getXYPlot(); + plot.getDomainAxis().setRange(-2, 2); + plot.getRangeAxis().setRange(-2, 2); + + JFrame frame = new JFrame("Predictions"); + frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); + frame.add(new ChartPanel(chart)); + frame.pack(); + frame.setVisible(true); + } + } diff --git a/src/main/resources/assets/xor.csv b/src/main/resources/assets/xor.csv new file mode 100644 index 0000000..8da7332 --- /dev/null +++ b/src/main/resources/assets/xor.csv @@ -0,0 +1,4 @@ +0,0,0 +0,1,1 +1,0,1 +1,1,0 \ No newline at end of file