From 76465ab6eefb006340836fd8003af62b8be6f0eb Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 25 Mar 2026 22:36:26 +0100 Subject: [PATCH] Start to add test coverage --- src/main/java/com/naaturel/ANN/Main.java | 10 +-- .../domain/abstraction/AlgorithmStrategy.java | 1 + .../ANN/domain/abstraction/Model.java | 2 +- .../ANN/domain/model/dataset/DataSet.java | 2 +- .../model/dataset/DatasetExtractor.java | 7 +- .../ANN/domain/model/neuron/Layer.java | 4 +- .../ANN/domain/model/neuron/Network.java | 4 +- .../model/training/TrainingPipeline.java | 6 +- .../GradientDescentCorrectionStrategy.java | 2 +- .../GradientDescentErrorStrategy.java | 2 +- .../neuron/SimplePerceptron.java | 4 +- .../simplePerceptron/Heaviside.java | 2 +- .../SimpleCorrectionStrategy.java | 3 +- .../training/GradientDescentTraining.java | 4 +- .../training/SimpleTraining.java | 2 +- .../java/perceptron/simplePerceptronTest.java | 85 +++++++++++++++++++ 16 files changed, 112 insertions(+), 28 deletions(-) create mode 100644 src/test/java/perceptron/simplePerceptronTest.java diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 6e3847b..0d92b13 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -20,8 +20,8 @@ public class Main { DataSet dataset = new DatasetExtractor() .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv"); - DataSet orDataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv"); + DataSet andDataset = new DatasetExtractor() + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"); List syns = new ArrayList<>(); syns.add(new Synapse(new Input(0), new Weight(0))); @@ -29,12 +29,12 @@ public class Main { Bias bias = new Bias(new Weight(0)); - Neuron neuron = new SimplePerceptron(syns, bias, new Linear()); + Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside()); Layer layer = new Layer(List.of(neuron)); Network network = new Network(List.of(layer)); - Trainer trainer = new GradientDescentTraining(); - trainer.train(network, dataset); + Trainer trainer = new SimpleTraining(); + trainer.train(network, andDataset); } } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java b/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java index 39fa2dd..2f79614 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java @@ -1,5 +1,6 @@ package com.naaturel.ANN.domain.abstraction; +@FunctionalInterface public interface AlgorithmStrategy { void apply(); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java index 3b47ed6..d850f9c 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java @@ -8,7 +8,7 @@ import java.util.function.Consumer; public interface Model { int synCount(); - void applyOnSynapses(Consumer consumer); + void forEachSynapse(Consumer consumer); List predict(List inputs); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java index fbcd459..dd10ca4 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java @@ -9,7 +9,7 @@ public class DataSet implements Iterable{ private Map data; public DataSet() { - this(new HashMap<>()); + this(new LinkedHashMap<>()); } public DataSet(Map data){ diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java index 66c744c..4f2688f 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java @@ -5,15 +5,12 @@ import com.naaturel.ANN.domain.model.neuron.Input; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; public class DatasetExtractor { public DataSet extract(String path) { - Map data = new HashMap<>(); + Map data = new LinkedHashMap<>(); try (BufferedReader reader = new BufferedReader(new FileReader(path))) { String line; diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index a79ba9e..e366374 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -35,7 +35,7 @@ public class Layer implements Model { } @Override - public void applyOnSynapses(Consumer consumer) { - this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer)); + public void forEachSynapse(Consumer consumer) { + this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java index 283106d..769f9fc 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java @@ -34,7 +34,7 @@ public class Network implements Model { } @Override - public void applyOnSynapses(Consumer consumer) { - this.layers.forEach(layer -> layer.applyOnSynapses(consumer)); + public void forEachSynapse(Consumer consumer) { + this.layers.forEach(layer -> layer.forEachSynapse(consumer)); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index 7f2d00c..ff4c738 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -56,9 +56,6 @@ public class TrainingPipeline { this.beforeEpoch.accept(ctx); this.executeSteps(ctx); this.afterEpoch.accept(ctx); - if(this.verbose) { - System.out.printf("[Global error] : %.2f\n", ctx.globalLoss); - } } while (!this.stopCondition.test(ctx)); } @@ -77,6 +74,9 @@ public class TrainingPipeline { System.out.printf("loss : %.5f\n", ctx.localLoss); } } + if(this.verbose) { + System.out.printf("[Global error] : %.2f\n", ctx.globalLoss); + } ctx.epoch += 1; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java index c182e3e..ed11af3 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java @@ -15,7 +15,7 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStrategy { @Override public void apply() { AtomicInteger i = new AtomicInteger(0); - context.model.applyOnSynapses(syn -> { + context.model.forEachSynapse(syn -> { float corrector = context.correctorTerms.get(i.get()); float c = syn.getWeight() + corrector; syn.setWeight(c); diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java index 84e01c2..fff984e 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java @@ -16,7 +16,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy { @Override public void apply() { AtomicInteger i = new AtomicInteger(0); - context.model.applyOnSynapses(syn -> { + context.model.forEachSynapse(syn -> { float corrector = context.correctorTerms.get(i.get()); corrector += context.learningRate * context.delta * syn.getInput(); context.correctorTerms.set(i.get(), corrector); diff --git a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java index 16c7f85..0c7686f 100644 --- a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java +++ b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java @@ -22,7 +22,7 @@ public class SimplePerceptron extends Neuron { } @Override - public void applyOnSynapses(Consumer consumer) { + public void forEachSynapse(Consumer consumer) { consumer.accept(this.bias); this.synapses.forEach(consumer); } @@ -30,10 +30,10 @@ public class SimplePerceptron extends Neuron { @Override public float calculateWeightedSum() { float res = 0; + res += this.bias.getWeight() * this.bias.getInput(); for(Synapse syn : super.synapses){ res += syn.getWeight() * syn.getInput(); } - res += this.bias.getWeight() * this.bias.getInput(); return res; } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java index badfc30..b3e2ec4 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java @@ -12,6 +12,6 @@ public class Heaviside implements ActivationFunction { @Override public float accept(Neuron n) { float weightedSum = n.calculateWeightedSum(); - return weightedSum <= 0 ? 0:1; + return weightedSum < 0 ? 0:1; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java index 0ed8bcd..d36cc95 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java @@ -2,6 +2,7 @@ package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; + public class SimpleCorrectionStrategy implements AlgorithmStrategy { private final SimpleTrainingContext context; @@ -13,7 +14,7 @@ public class SimpleCorrectionStrategy implements AlgorithmStrategy { @Override public void apply() { if(context.currentLabel.getValue() == context.prediction) return ; - context.model.applyOnSynapses(syn -> { + context.model.forEachSynapse(syn -> { float currentW = syn.getWeight(); float currentInput = syn.getInput(); float newValue = currentW + (context.learningRate * context.delta * currentInput); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 42df3fa..fa1893d 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -27,7 +27,7 @@ public class GradientDescentTraining implements Trainer { GradientDescentTrainingContext context = new GradientDescentTrainingContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.00011F; + context.learningRate = 0.2F; context.correctorTerms = new ArrayList<>(); List steps = List.of( @@ -40,7 +40,7 @@ public class GradientDescentTraining implements Trainer { TrainingPipeline pipeline = new TrainingPipeline(steps); pipeline - .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000) + .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 50) .beforeEpoch(ctx -> { ctx.globalLoss = 0.0F; for (int i = 0; i < model.synCount(); i++){ diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 2fb7d20..3c7b97b 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -33,7 +33,7 @@ public class SimpleTraining implements Trainer { TrainingPipeline pipeline = new TrainingPipeline(steps); pipeline - .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100) + .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 10) .beforeEpoch(ctx -> ctx.globalLoss = 0) .withVerbose(true) .run(context); diff --git a/src/test/java/perceptron/simplePerceptronTest.java b/src/test/java/perceptron/simplePerceptronTest.java new file mode 100644 index 0000000..4b8fd79 --- /dev/null +++ b/src/test/java/perceptron/simplePerceptronTest.java @@ -0,0 +1,85 @@ +package perceptron; + +import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; +import com.naaturel.ANN.domain.model.neuron.*; +import com.naaturel.ANN.domain.model.training.TrainingPipeline; +import com.naaturel.ANN.implementation.neuron.SimplePerceptron; +import com.naaturel.ANN.implementation.simplePerceptron.*; +import com.naaturel.ANN.implementation.training.steps.*; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + + +public class simplePerceptronTest { + + private DataSet dataset; + private SimpleTrainingContext context; + + private List synapses; + private Bias bias; + private Network network; + + private TrainingPipeline pipeline; + + @BeforeEach + public void init(){ + dataset = new DatasetExtractor() + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"); + + List syns = new ArrayList<>(); + syns.add(new Synapse(new Input(0), new Weight(0))); + syns.add(new Synapse(new Input(0), new Weight(0))); + + bias = new Bias(new Weight(0)); + + Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside()); + Layer layer = new Layer(List.of(neuron)); + network = new Network(List.of(layer)); + + context = new SimpleTrainingContext(); + context.dataset = dataset; + context.model = network; + + List steps = List.of( + new PredictionStep(new SimplePredictionStrategy(context)), + new DeltaStep(new SimpleDeltaStrategy(context)), + new LossStep(new SimpleLossStrategy(context)), + new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), + new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) + ); + + pipeline = new TrainingPipeline(steps); + pipeline.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100); + pipeline.beforeEpoch(ctx -> ctx.globalLoss = 0); + } + + @Test + public void test_the_whole_algorithm(){ + + List expectedGlobalLosses = List.of( + 2.0F, + 3.0F, + 3.0F, + 2.0F, + 1.0F, + 0.0F + ); + + context.learningRate = 1F; + pipeline.afterEpoch(ctx -> { + int index = ctx.epoch-1; + assertEquals(expectedGlobalLosses.get(index), context.globalLoss); + }); + + pipeline.run(context); + assertEquals(6, context.epoch); + } +}