From 0d3ab0de8d91e40567687ce00e2ed78761fa4ada Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 26 Mar 2026 11:27:10 +0100 Subject: [PATCH] Reimplement Adaline --- src/main/java/com/naaturel/ANN/Main.java | 9 +- .../model/training/TrainingPipeline.java | 16 +++- .../adaline/AdalineTrainingContext.java | 6 ++ .../GradientDescentErrorStrategy.java | 1 + .../gradientDescent/SquareLossStrategy.java | 6 +- .../SimpleCorrectionStrategy.java | 5 +- .../SimpleErrorRegistrationStrategy.java | 5 +- .../training/AdalineTraining.java | 52 ++++++++++- .../training/GradientDescentTraining.java | 13 ++- src/test/java/adaline/AdalineTest.java | 93 +++++++++++++++++++ .../gradientDescent/GradientDescentTest.java | 6 +- 11 files changed, 187 insertions(+), 25 deletions(-) create mode 100644 src/main/java/com/naaturel/ANN/implementation/adaline/AdalineTrainingContext.java create mode 100644 src/test/java/adaline/AdalineTest.java diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 0d92b13..9c25e08 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -8,6 +8,7 @@ import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.implementation.gradientDescent.Linear; import com.naaturel.ANN.implementation.simplePerceptron.Heaviside; import com.naaturel.ANN.implementation.neuron.SimplePerceptron; +import com.naaturel.ANN.implementation.training.AdalineTraining; import com.naaturel.ANN.implementation.training.GradientDescentTraining; import com.naaturel.ANN.implementation.training.SimpleTraining; @@ -21,7 +22,7 @@ public class Main { .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv"); DataSet andDataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv"); List syns = new ArrayList<>(); syns.add(new Synapse(new Input(0), new Weight(0))); @@ -29,12 +30,12 @@ public class Main { Bias bias = new Bias(new Weight(0)); - Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside()); + Neuron neuron = new SimplePerceptron(syns, bias, new Linear()); Layer layer = new Layer(List.of(neuron)); Network network = new Network(List.of(layer)); - Trainer trainer = new SimpleTraining(); - trainer.train(network, andDataset); + Trainer trainer = new AdalineTraining(); + trainer.train(network, dataset); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index ff4c738..d99ef80 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -4,6 +4,7 @@ import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import java.sql.Time; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; @@ -52,11 +53,23 @@ public class TrainingPipeline { } public void run(TrainingContext ctx) { + + long start = this.timeMeasurement ? System.currentTimeMillis() : 0; + do { this.beforeEpoch.accept(ctx); this.executeSteps(ctx); this.afterEpoch.accept(ctx); + if(this.verbose) { + System.out.printf("[Global error] : %f\n", ctx.globalLoss); + } } while (!this.stopCondition.test(ctx)); + + if(this.timeMeasurement) { + long end = System.currentTimeMillis(); + System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0); + } + } private void executeSteps(TrainingContext ctx){ @@ -74,9 +87,6 @@ public class TrainingPipeline { System.out.printf("loss : %.5f\n", ctx.localLoss); } } - if(this.verbose) { - System.out.printf("[Global error] : %.2f\n", ctx.globalLoss); - } ctx.epoch += 1; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/adaline/AdalineTrainingContext.java b/src/main/java/com/naaturel/ANN/implementation/adaline/AdalineTrainingContext.java new file mode 100644 index 0000000..3b0b623 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/adaline/AdalineTrainingContext.java @@ -0,0 +1,6 @@ +package com.naaturel.ANN.implementation.adaline; + +import com.naaturel.ANN.domain.abstraction.TrainingContext; + +public class AdalineTrainingContext extends TrainingContext { +} diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java index fff984e..e326eb7 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java @@ -22,5 +22,6 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy { context.correctorTerms.set(i.get(), corrector); i.incrementAndGet(); }); + context.globalLoss += context.localLoss; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java index 0243756..2aa2cd4 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java @@ -1,19 +1,19 @@ package com.naaturel.ANN.implementation.gradientDescent; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext; public class SquareLossStrategy implements AlgorithmStrategy { - private final GradientDescentTrainingContext context; + private final TrainingContext context; - public SquareLossStrategy(GradientDescentTrainingContext context) { + public SquareLossStrategy(TrainingContext context) { this.context = context; } @Override public void apply() { this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2; - this.context.globalLoss += context.localLoss; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java index d36cc95..05671f9 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java @@ -1,13 +1,14 @@ package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingContext; public class SimpleCorrectionStrategy implements AlgorithmStrategy { - private final SimpleTrainingContext context; + private final TrainingContext context; - public SimpleCorrectionStrategy(SimpleTrainingContext context) { + public SimpleCorrectionStrategy(TrainingContext context) { this.context = context; } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java index c6be25d..c526435 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java @@ -1,12 +1,13 @@ package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingContext; public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy { - private final SimpleTrainingContext context; + private final TrainingContext context; - public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) { + public SimpleErrorRegistrationStrategy(TrainingContext context) { this.context = context; } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 5a7fb70..fbd8219 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -1,21 +1,65 @@ package com.naaturel.ANN.implementation.training; +import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; +import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Weight; +import com.naaturel.ANN.domain.model.training.TrainingPipeline; +import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; +import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy; +import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy; +import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; +import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; +import com.naaturel.ANN.implementation.training.steps.*; + +import java.util.ArrayList; +import java.util.List; -/*public class AdalineTraining implements Trainer { +public class AdalineTraining implements Trainer { public AdalineTraining(){ } - public void train(Neuron n, float learningRate, DataSet dataSet) { + @Override + public void train(Model model, DataSet dataset) { + AdalineTrainingContext context = new AdalineTrainingContext(); + context.dataset = dataset; + context.model = model; + context.learningRate = 0.003F; + + List steps = List.of( + new PredictionStep(new SimplePredictionStrategy(context)), + new DeltaStep(new SimpleDeltaStrategy(context)), + new LossStep(new SquareLossStrategy(context)), + new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), + new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) + ); + + new TrainingPipeline(steps) + .stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 10000) + .beforeEpoch(ctx -> { + ctx.globalLoss = 0.0F; + }) + .afterEpoch(ctx -> { + ctx.globalLoss /= context.dataset.size(); + }) + .withVerbose(true) + .withTimeMeasurement(true) + .run(context); + } + + /*public void train(Neuron n, float learningRate, DataSet dataSet) { int epoch = 1; int maxEpoch = 202; float errorThreshold = 0.0F; @@ -76,6 +120,6 @@ import com.naaturel.ANN.domain.model.neuron.Weight; private float calculateLoss(float delta){ return (float) Math.pow(delta, 2)/2; - } + }*/ -}*/ +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 09804f4..8c2f975 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -27,6 +27,7 @@ public class GradientDescentTraining implements Trainer { GradientDescentTrainingContext context = new GradientDescentTrainingContext(); context.dataset = dataset; context.model = model; + context.learningRate = 0.0011F; context.correctorTerms = new ArrayList<>(); List steps = List.of( @@ -37,17 +38,19 @@ public class GradientDescentTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100) + .stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 5000) .beforeEpoch(ctx -> { GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx; gdCtx.globalLoss = 0.0F; gdCtx.correctorTerms.clear(); - for (int i = 0; i < ctx.model.synCount(); i++){ - gdCtx.correctorTerms.add(0F); - } + gdCtx.model.forEachSynapse(syn -> gdCtx.correctorTerms.add(0F)); + }) + .afterEpoch(ctx -> { + context.globalLoss /= context.dataset.size(); + new GradientDescentCorrectionStrategy(context).apply(); }) - .afterEpoch(ctx -> new GradientDescentCorrectionStrategy(context).apply()) .withVerbose(true) + .withTimeMeasurement(true) .run(context); } diff --git a/src/test/java/adaline/AdalineTest.java b/src/test/java/adaline/AdalineTest.java new file mode 100644 index 0000000..75c4f30 --- /dev/null +++ b/src/test/java/adaline/AdalineTest.java @@ -0,0 +1,93 @@ +package adaline; + + +import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; +import com.naaturel.ANN.domain.model.neuron.*; +import com.naaturel.ANN.domain.model.training.TrainingPipeline; +import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; +import com.naaturel.ANN.implementation.gradientDescent.*; +import com.naaturel.ANN.implementation.neuron.SimplePerceptron; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; +import com.naaturel.ANN.implementation.training.steps.*; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AdalineTest { + + private DataSet dataset; + private AdalineTrainingContext context; + + private List synapses; + private Bias bias; + private Network network; + + private TrainingPipeline pipeline; + + @BeforeEach + public void init(){ + dataset = new DatasetExtractor() + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv"); + + List syns = new ArrayList<>(); + syns.add(new Synapse(new Input(0), new Weight(0))); + syns.add(new Synapse(new Input(0), new Weight(0))); + + bias = new Bias(new Weight(0)); + + Neuron neuron = new SimplePerceptron(syns, bias, new Linear()); + Layer layer = new Layer(List.of(neuron)); + network = new Network(List.of(layer)); + + context = new AdalineTrainingContext(); + context.dataset = dataset; + context.model = network; + + List steps = List.of( + new PredictionStep(new SimplePredictionStrategy(context)), + new DeltaStep(new SimpleDeltaStrategy(context)), + new LossStep(new SquareLossStrategy(context)), + new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), + new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) + ); + + pipeline = new TrainingPipeline(steps) + .stopCondition(ctx -> ctx.globalLoss <= 0.1329F || ctx.epoch > 10000) + .beforeEpoch(ctx -> { + ctx.globalLoss = 0.0F; + }); + } + + @Test + public void test_the_whole_algorithm(){ + + List expectedGlobalLosses = List.of( + 0.501522F, + 0.498601F + ); + + context.learningRate = 0.03F; + pipeline.afterEpoch(ctx -> { + ctx.globalLoss /= context.dataset.size(); + + int index = ctx.epoch-1; + if(index >= expectedGlobalLosses.size()) return; + + //assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f); + }); + + pipeline.run(context); + assertEquals(214, context.epoch); + } +} + diff --git a/src/test/java/gradientDescent/GradientDescentTest.java b/src/test/java/gradientDescent/GradientDescentTest.java index 948025f..6a33af6 100644 --- a/src/test/java/gradientDescent/GradientDescentTest.java +++ b/src/test/java/gradientDescent/GradientDescentTest.java @@ -49,7 +49,7 @@ public class GradientDescentTest { context = new GradientDescentTrainingContext(); context.dataset = dataset; context.model = network; - context.correctorTerms = new ArrayList<>(); + context.correctorTerms = new ArrayList<>(); List steps = List.of( new PredictionStep(new SimplePredictionStrategy(context)), @@ -92,7 +92,9 @@ public class GradientDescentTest { assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f); }); - pipeline.run(context); + pipeline + .withVerbose(true) + .run(context); assertEquals(67, context.epoch); } }