From 2936bf33bf6c693a78e36d6d33180fbc97414380 Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 23 Mar 2026 23:12:52 +0100 Subject: [PATCH] Start to reimplement gradient descent --- src/main/java/com/naaturel/ANN/Main.java | 42 ++-------------- ...onStrategy.java => AlgorithmStrategy.java} | 4 +- .../ANN/domain/model/dataset/DataSet.java | 12 +++-- .../domain/model/dataset/DataSetEntry.java | 12 +++-- .../model/dataset/DatasetExtractor.java | 36 ++++++++++++++ .../GradientDescentCorrectionStrategy.java | 16 +++++- .../correction/SimpleCorrectionStrategy.java | 4 +- .../loss/SimpleLossStrategy.java | 11 +++++ .../loss/SquareLossStrategy.java | 11 +++++ .../neuron/SimplePerceptron.java | 3 -- .../training/GradientDescentTraining.java | 49 ++++++++++++++----- .../training/SimpleTraining.java | 3 +- .../training/steps/LossStep.java | 19 +++++++ .../training/steps/PredictionStep.java | 6 +-- .../training/steps/SimpleLossStep.java | 12 ----- .../training/steps/WeightCorrectionStep.java | 6 +-- 16 files changed, 157 insertions(+), 89 deletions(-) rename src/main/java/com/naaturel/ANN/domain/abstraction/{CorrectionStrategy.java => AlgorithmStrategy.java} (59%) create mode 100644 src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/loss/SimpleLossStrategy.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/loss/SquareLossStrategy.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java delete mode 100644 src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleLossStep.java diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 434e03a..7c4d2bb 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -5,6 +5,7 @@ import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.dataset.Label; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingContext; @@ -15,49 +16,14 @@ import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.training.SimpleTraining; import com.naaturel.ANN.implementation.training.steps.*; +import javax.xml.crypto.Data; import java.util.*; public class Main { public static void main(String[] args){ - DataSet orDataSet = new DataSet(Map.ofEntries( - Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)), - Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F)) - )); - - DataSet andDataSet = new DataSet(Map.ofEntries( - Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F)) - )); - - DataSet dataSet = new DataSet(Map.ofEntries( - Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)), - Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)), - Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F)) - )); + DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv"); List syns = new ArrayList<>(); syns.add(new Synapse(new Input(0), new Weight(0))); @@ -70,7 +36,7 @@ public class Main { Network network = new Network(List.of(layer)); Trainer trainer = new SimpleTraining(); - trainer.train(network, orDataSet); + trainer.train(network, dataset); } } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/CorrectionStrategy.java b/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java similarity index 59% rename from src/main/java/com/naaturel/ANN/domain/abstraction/CorrectionStrategy.java rename to src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java index 58dd0ae..160f83c 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/CorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java @@ -2,8 +2,8 @@ package com.naaturel.ANN.domain.abstraction; import com.naaturel.ANN.domain.model.training.TrainingContext; -public interface CorrectionStrategy { +public interface AlgorithmStrategy { - void apply(TrainingContext context); + void apply(TrainingContext ctx); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java index d0df51e..fbcd459 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java @@ -1,12 +1,14 @@ package com.naaturel.ANN.domain.model.dataset; +import com.naaturel.ANN.domain.model.neuron.Input; + import java.util.*; public class DataSet implements Iterable{ private Map data; - public DataSet(){ + public DataSet() { this(new HashMap<>()); } @@ -31,15 +33,17 @@ public class DataSet implements Iterable{ float maxAbs = entries.stream() .flatMap(e -> e.getData().stream()) + .map(Input::getValue) .map(Math::abs) .max(Float::compare) .orElse(1.0F); Map normalized = new HashMap<>(); for (DataSetEntry entry : entries) { - List normalizedData = new ArrayList<>(); - for (float value : entry.getData()) { - normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F); + List normalizedData = new ArrayList<>(); + for (Input input : entry.getData()) { + Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F); + normalizedData.add(normalizedInput); } normalized.put(new DataSetEntry(normalizedData), this.data.get(entry)); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSetEntry.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSetEntry.java index b2f667d..c87bafe 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSetEntry.java +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSetEntry.java @@ -1,16 +1,18 @@ package com.naaturel.ANN.domain.model.dataset; +import com.naaturel.ANN.domain.model.neuron.Input; + import java.util.*; -public class DataSetEntry implements Iterable { +public class DataSetEntry implements Iterable { - private List data; + private List data; - public DataSetEntry(List data){ + public DataSetEntry(List data){ this.data = data; } - public List getData() { + public List getData() { return new ArrayList<>(data); } @@ -28,7 +30,7 @@ public class DataSetEntry implements Iterable { } @Override - public Iterator iterator() { + public Iterator iterator() { return this.data.iterator(); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java new file mode 100644 index 0000000..66c744c --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java @@ -0,0 +1,36 @@ +package com.naaturel.ANN.domain.model.dataset; + +import com.naaturel.ANN.domain.model.neuron.Input; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class DatasetExtractor { + + public DataSet extract(String path) { + Map data = new HashMap<>(); + + try (BufferedReader reader = new BufferedReader(new FileReader(path))) { + String line; + while ((line = reader.readLine()) != null) { + String[] parts = line.split(","); + List inputs = new ArrayList<>(); + for (int i = 0; i < parts.length - 1; i++) { + inputs.add(new Input(Float.parseFloat(parts[i].trim()))); + } + float label = Float.parseFloat(parts[parts.length - 1].trim()); + data.put(new DataSetEntry(inputs), new Label(label)); + } + } catch (IOException e) { + throw new RuntimeException("Failed to read dataset from: " + path, e); + } + + return new DataSet(data); + } + +} diff --git a/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java index bfaa79b..b9a1f6c 100644 --- a/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java @@ -1,9 +1,21 @@ package com.naaturel.ANN.implementation.correction; -import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.model.training.TrainingContext; -public class GradientDescentCorrectionStrategy implements CorrectionStrategy { +import java.util.ArrayList; +import java.util.List; + +public class GradientDescentCorrectionStrategy implements AlgorithmStrategy { + + List correctorTerms; + + public GradientDescentCorrectionStrategy(int nbrCorrectors){ + this.correctorTerms = new ArrayList<>(); + for (int i = 0; i < nbrCorrectors; i++){ + correctorTerms.add(0F); + } + } @Override public void apply(TrainingContext context) { diff --git a/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java index 0efcd53..efabb8b 100644 --- a/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java @@ -1,9 +1,9 @@ package com.naaturel.ANN.implementation.correction; -import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.model.training.TrainingContext; -public class SimpleCorrectionStrategy implements CorrectionStrategy { +public class SimpleCorrectionStrategy implements AlgorithmStrategy { @Override public void apply(TrainingContext context) { diff --git a/src/main/java/com/naaturel/ANN/implementation/loss/SimpleLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/loss/SimpleLossStrategy.java new file mode 100644 index 0000000..ee5d785 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/loss/SimpleLossStrategy.java @@ -0,0 +1,11 @@ +package com.naaturel.ANN.implementation.loss; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class SimpleLossStrategy implements AlgorithmStrategy { + @Override + public void apply(TrainingContext ctx) { + ctx.localLoss = Math.abs(ctx.delta); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/loss/SquareLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/loss/SquareLossStrategy.java new file mode 100644 index 0000000..8958430 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/loss/SquareLossStrategy.java @@ -0,0 +1,11 @@ +package com.naaturel.ANN.implementation.loss; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class SquareLossStrategy implements AlgorithmStrategy { + @Override + public void apply(TrainingContext ctx) { + ctx.localLoss = (float)Math.pow(ctx.delta, 2) / 2; + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java index 5861c74..16c7f85 100644 --- a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java +++ b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java @@ -1,13 +1,10 @@ package com.naaturel.ANN.implementation.neuron; import com.naaturel.ANN.domain.abstraction.ActivationFunction; -import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; import com.naaturel.ANN.domain.abstraction.Neuron; -import com.naaturel.ANN.domain.abstraction.Trainable; import com.naaturel.ANN.domain.model.neuron.Bias; import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Synapse; -import com.naaturel.ANN.domain.model.neuron.Weight; import java.util.List; import java.util.function.Consumer; diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 4aeb4e0..dc119b4 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -1,25 +1,50 @@ package com.naaturel.ANN.implementation.training; -import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.Trainable; import com.naaturel.ANN.domain.abstraction.Trainer; +import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DataSetEntry; -import com.naaturel.ANN.domain.model.neuron.Bias; -import com.naaturel.ANN.domain.model.neuron.Input; -import com.naaturel.ANN.domain.model.neuron.Synapse; -import com.naaturel.ANN.domain.model.neuron.Weight; +import com.naaturel.ANN.domain.model.training.TrainingContext; +import com.naaturel.ANN.domain.model.training.TrainingPipeline; +import com.naaturel.ANN.implementation.correction.GradientDescentCorrectionStrategy; +import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; +import com.naaturel.ANN.implementation.loss.SquareLossStrategy; +import com.naaturel.ANN.implementation.training.steps.*; -import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -/*public class GradientDescentTraining implements Trainer { +public class GradientDescentTraining implements Trainer { public GradientDescentTraining(){ } - public void train(Neuron n, float learningRate, DataSet dataSet) { + @Override + public void train(Trainable model, DataSet dataset) { + TrainingContext context = new TrainingContext(); + context.dataset = dataset; + context.model = model; + context.learningRate = 0.3F; + + List steps = List.of( + new PredictionStep(), + new DeltaStep(), + new LossStep(new SquareLossStrategy()), + new SimpleErrorDetectionStep(), + new WeightCorrectionStep(new GradientDescentCorrectionStrategy(2)) + ); + + TrainingPipeline pipeline = new TrainingPipeline(steps); + pipeline + .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100) + .beforeEpoch(ctx -> ctx.globalLoss = 0) + .afterEpoch(ctx -> ()) + .withVerbose(true) + .withTimeMeasurement(true) + .run(context); + } + + /*public void train(Neuron n, float learningRate, DataSet dataSet) { int epoch = 1; int maxEpoch = 402; float errorThreshold = 0F; @@ -120,6 +145,6 @@ import java.util.List; variance /= dataSet.size(); return variance; - } + }*/ -}*/ +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 7f8f43f..351619c 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -7,6 +7,7 @@ import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; +import com.naaturel.ANN.implementation.loss.SimpleLossStrategy; import com.naaturel.ANN.implementation.training.steps.*; import java.util.List; @@ -27,7 +28,7 @@ public class SimpleTraining implements Trainer { List steps = List.of( new PredictionStep(), new DeltaStep(), - new SimpleLossStep(), + new LossStep(new SimpleLossStrategy()), new SimpleErrorDetectionStep(), new WeightCorrectionStep(new SimpleCorrectionStrategy()) ); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java new file mode 100644 index 0000000..2c50287 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java @@ -0,0 +1,19 @@ +package com.naaturel.ANN.implementation.training.steps; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class LossStep implements TrainingStep { + + private final AlgorithmStrategy lossStrategy; + + public LossStep(AlgorithmStrategy strategy) { + this.lossStrategy = strategy; + } + + @Override + public void run(TrainingContext ctx) { + this.lossStrategy.apply(ctx); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java index a1e29b2..e14253c 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java @@ -11,11 +11,7 @@ public class PredictionStep implements TrainingStep { @Override public void run(TrainingContext ctx) { - List inputs = new ArrayList<>(); - for(Float f : ctx.currentEntry.getData()){ - inputs.add(new Input(f)); - } - List predictions = ctx.model.predict(inputs); + List predictions = ctx.model.predict(ctx.currentEntry.getData()); ctx.prediction = predictions.getFirst(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleLossStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleLossStep.java deleted file mode 100644 index f46d815..0000000 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleLossStep.java +++ /dev/null @@ -1,12 +0,0 @@ -package com.naaturel.ANN.implementation.training.steps; - -import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.training.TrainingContext; - -public class SimpleLossStep implements TrainingStep { - - @Override - public void run(TrainingContext ctx) { - ctx.localLoss = Math.abs(ctx.delta); - } -} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java index 00996ab..1c015c0 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java @@ -1,14 +1,14 @@ package com.naaturel.ANN.implementation.training.steps; -import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.training.TrainingContext; public class WeightCorrectionStep implements TrainingStep { - private final CorrectionStrategy correctionStrategy; + private final AlgorithmStrategy correctionStrategy; - public WeightCorrectionStep(CorrectionStrategy strategy) { + public WeightCorrectionStep(AlgorithmStrategy strategy) { this.correctionStrategy = strategy; }