From 76bc791889e274e553087c5d4180827fff1db1b3 Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 22 Mar 2026 23:36:44 +0100 Subject: [PATCH] Just a regular commit --- src/main/java/com/naaturel/ANN/Main.java | 8 ++--- .../ANN/domain/abstraction/Trainer.java | 8 +++++ .../ANN/domain/model/neuron/Synapse.java | 2 ++ .../training/AdalineTraining.java | 34 +++++++------------ .../training/GradientDescentTraining.java | 33 +++++++++++++++--- .../training/SimpleTraining.java | 3 +- 6 files changed, 57 insertions(+), 31 deletions(-) create mode 100644 src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 856eca2..218d3d8 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -1,6 +1,7 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.Label; @@ -8,7 +9,6 @@ import com.naaturel.ANN.domain.model.neuron.Bias; import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Weight; -import com.naaturel.ANN.implementation.activationFunction.Heaviside; import com.naaturel.ANN.implementation.activationFunction.Linear; import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.training.AdalineTraining; @@ -29,8 +29,8 @@ public class Main { DataSet andDataSet = new DataSet(Map.ofEntries( Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)), Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)), + Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)), Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F)) )); @@ -65,11 +65,11 @@ public class Main { Bias bias = new Bias(new Weight(0)); Neuron n = new SimplePerceptron(syns, bias, new Linear()); - AdalineTraining st = new AdalineTraining(); + Trainer trainer = new AdalineTraining(); long start = System.currentTimeMillis(); - st.train(n, 0.03F, andDataSet); + trainer.train(n, 0.03F, andDataSet); long end = System.currentTimeMillis(); System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java new file mode 100644 index 0000000..f8c44dc --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java @@ -0,0 +1,8 @@ +package com.naaturel.ANN.domain.abstraction; + +import com.naaturel.ANN.domain.model.dataset.DataSet; + +public interface Trainer { + + void train(Neuron n, float learningRate, DataSet dataSet); +} diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java index 3900de9..e6ed930 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java @@ -27,4 +27,6 @@ public class Synapse { } + + } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 0c2790d..6c22a5d 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -1,17 +1,15 @@ package com.naaturel.ANN.implementation.training; import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; -import com.naaturel.ANN.domain.model.neuron.Bias; import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Weight; -import java.util.ArrayList; -import java.util.List; -public class AdalineTraining { +public class AdalineTraining implements Trainer { public AdalineTraining(){ @@ -19,15 +17,14 @@ public class AdalineTraining { public void train(Neuron n, float learningRate, DataSet dataSet) { int epoch = 1; - int maxEpoch = 1000; + int maxEpoch = 202; float errorThreshold = 0.0F; float mse; do { if(epoch > maxEpoch) break; mse = 0; - - for(DataSetEntry entry : dataSet) { + for(DataSetEntry entry : dataSet) { this.updateInputs(n, entry); float prediction = n.predict(); float expectation = dataSet.getLabel(entry).getValue(); @@ -49,23 +46,22 @@ public class AdalineTraining { System.out.printf("predicted : %.2f, ", prediction); System.out.printf("expected : %.2f, ", expectation); System.out.printf("delta : %.2f, ", delta); - System.out.printf("loss : %.2f\n", loss); + System.out.printf("loss : %.5f\n", loss); } + mse /= dataSet.size(); System.out.printf("[Total error : %f]\n", mse); - + System.out.println("[Final weights]"); + System.out.printf("Bias: %f\n", n.getBias().getWeight()); + int i = 1; + for(Synapse syn : n.getSynapses()){ + System.out.printf("Syn %d: %f\n", i, syn.getWeight()); + i++; + } epoch++; } while(mse > errorThreshold); } - private List initCorrectorTerms(int number){ - List res = new ArrayList<>(); - for(int i = 0; i < number; i++){ - res.add(0F); - } - return res; - } - private void updateInputs(Neuron n, DataSetEntry entry){ int index = 0; for(float value : entry){ @@ -82,8 +78,4 @@ public class AdalineTraining { return (float) Math.pow(delta, 2)/2; } - private float calculateWeightCorrection(float value, float delta){ - return value * delta; - } - } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index f86c91b..fce18a2 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -1,6 +1,7 @@ package com.naaturel.ANN.implementation.training; import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.neuron.Bias; @@ -9,9 +10,10 @@ import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Weight; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; -public class GradientDescentTraining { +public class GradientDescentTraining implements Trainer { public GradientDescentTraining(){ @@ -19,8 +21,8 @@ public class GradientDescentTraining { public void train(Neuron n, float learningRate, DataSet dataSet) { int epoch = 1; - int maxEpoch = 1000; - float errorThreshold = 0.0F; + int maxEpoch = 402; + float errorThreshold = 0F; float mse; do { @@ -54,6 +56,7 @@ public class GradientDescentTraining { System.out.printf("delta : %.2f, ", delta); System.out.printf("loss : %.2f\n", loss); } + mse /= dataSet.size(); System.out.printf("[Total error : %f]\n", mse); float currentBias = n.getBias().getWeight(); @@ -69,6 +72,13 @@ public class GradientDescentTraining { epoch++; } while(mse > errorThreshold); + System.out.println("[Final weights]"); + System.out.printf("Bias: %f\n", n.getBias().getWeight()); + int i = 1; + for(Synapse syn : n.getSynapses()){ + System.out.printf("Syn %d: %f\n", i, syn.getWeight()); + i++; + } } private List initCorrectorTerms(int number){ @@ -95,8 +105,21 @@ public class GradientDescentTraining { return (float) Math.pow(delta, 2)/2; } - private float calculateWeightCorrection(float value, float delta){ - return value * delta; + public float computeThreshold(DataSet dataSet) { + float sum = 0; + for (DataSetEntry entry : dataSet) { + sum += dataSet.getLabel(entry).getValue(); + } + float mean = sum / dataSet.size(); + + float variance = 0; + for (DataSetEntry entry : dataSet) { + float diff = dataSet.getLabel(entry).getValue() - mean; + variance += diff * diff; + } + variance /= dataSet.size(); + + return variance; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 64ced34..1fea010 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -1,13 +1,14 @@ package com.naaturel.ANN.implementation.training; import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Weight; -public class SimpleTraining { +public class SimpleTraining implements Trainer { public SimpleTraining() {