From fd97d0853c4ff1bd4443f86cf22134d44c658472 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Mon, 30 Mar 2026 21:13:03 +0200
Subject: [PATCH] Fix multi-layer implementation

---
 src/main/java/com/naaturel/ANN/Main.java      | 12 +++---
 .../ANN/domain/abstraction/Model.java         |  2 +-
 .../ANN/domain/abstraction/Trainer.java       |  2 +-
 .../model/neuron/FullyConnectedNetwork.java   | 20 ++++++----
 .../ANN/domain/model/neuron/Layer.java        | 18 ++++++++-
 .../ANN/domain/model/neuron/Neuron.java       |  2 +-
 .../model/training/TrainingPipeline.java      | 14 ++++---
 .../gradientDescent/SquareLossStep.java       |  1 +
 .../BackpropagationCorrectionStep.java        |  2 +-
 .../multiLayers/ErrorSignalStep.java          | 35 ++---------------
 .../multiLayers/OutputLayerErrorStep.java     | 39 +++++++++++++++++++
 .../training/AdalineTraining.java             |  6 +--
 .../GradientBackpropagationTraining.java      | 16 ++++----
 .../training/GradientDescentTraining.java     |  4 +-
 .../training/SimpleTraining.java              |  6 +--
 15 files changed, 108 insertions(+), 71 deletions(-)
 create mode 100644 src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java

diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java
index b345dd1..723da24 100644
--- a/src/main/java/com/naaturel/ANN/Main.java
+++ b/src/main/java/com/naaturel/ANN/Main.java
@@ -17,12 +17,12 @@ public class Main {
 
     public static void main(String[] args){
 
-        int nbrClass = 1;
+        int nbrClass = 3;
 
         DataSet dataset = new DatasetExtractor()
-                .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
+                .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
 
-        int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
+        int[] neuronPerLayer = new int[]{3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 37, dataset.getNbrLabels()};
 
         int nbrInput = dataset.getNbrInputs();
         List<Layer> layers = new ArrayList<>();
@@ -40,7 +40,7 @@ public class Main {
 
             Bias bias = new Bias(new Weight());
 
-            Neuron n = new Neuron(syns, bias, new Sigmoid(2));
+            Neuron n = new Neuron(syns, bias, new TanH());
             neurons.add(n);
         }
         Layer layer = new Layer(neurons);
@@ -50,7 +50,7 @@
         FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
 
         Trainer trainer = new GradientBackpropagationTraining();
-        trainer.train(0.5F, network, dataset);
+        trainer.train(0.001F, 1000, network, dataset);
 
         /*GraphVisualizer visualizer = new GraphVisualizer();
@@ -59,7 +59,7 @@
            visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
         }
 
-        float min = -2F;
+        float min = -3F;
         float max = 2F;
         float step = 0.01F;
         for (float x = min; x < max; x+=step){
diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
index 891f685..b9899e1 100644
--- a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
+++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
@@ -10,7 +10,7 @@
 public interface Model {
     int synCount();
     int neuronCount();
-    int indexOf(Neuron n);
+    int indexInLayerOf(Neuron n);
     void forEachNeuron(Consumer<Neuron> consumer);
     void forEachSynapse(Consumer<Synapse> consumer);
     void forEachOutputNeurons(Consumer<Neuron> consumer);
diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java
index 80b321e..4286d48 100644
--- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java
+++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java
@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
 
 public interface Trainer {
-    void train(float learningRate, Model model, DataSet dataset);
+    void train(float learningRate, int epoch, Model model, DataSet dataset);
 }
diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java
index 48cd74c..f2a63b8 100644
--- a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java
+++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java
@@ -6,6 +6,7 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 
 /**
@@ -13,13 +14,13 @@ import java.util.function.Consumer;
  */
 public class FullyConnectedNetwork implements Model {
-    private final List<Layer> layers;;
+    private final List<Layer> layers;
     private final Map<Neuron, List<Neuron>> connectionMap;
-    private final Map<Neuron, Integer> neuronIndex;
+    private final Map<Neuron, Integer> layerIndexByNeuron;
 
     public FullyConnectedNetwork(List<Layer> layers) {
         this.layers = layers;
         this.connectionMap = this.createConnectionMap();
-        this.neuronIndex = this.createNeuronIndex();
+        this.layerIndexByNeuron = this.createNeuronIndex();
     }
 
     @Override
@@ -71,8 +72,9 @@ public class FullyConnectedNetwork implements Model {
     }
 
     @Override
-    public int indexOf(Neuron n) {
-        return this.neuronIndex.get(n);
+    public int indexInLayerOf(Neuron n) {
+        int layerIndex = this.layerIndexByNeuron.get(n);
+        return this.layers.get(layerIndex).indexInLayerOf(n);
     }
 
     private Map<Neuron, List<Neuron>> createConnectionMap() {
@@ -83,14 +85,16 @@
             this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
             this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
         }
-
         return res;
     }
 
     private Map<Neuron, Integer> createNeuronIndex() {
         Map<Neuron, Integer> res = new HashMap<>();
-        int[] index = {0};
-        this.layers.forEach(l -> l.forEachNeuron(n -> res.put(n, index[0]++)));
+        AtomicInteger index = new AtomicInteger(0);
+        this.layers.forEach(l -> {
+            l.forEachNeuron(n -> res.put(n, index.get()));
+            index.incrementAndGet();
+        });
         return res;
     }
 }
diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java
index 67282be..36c487d 100644
--- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java
+++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java
@@ -3,15 +3,19 @@ package com.naaturel.ANN.domain.model.neuron;
 import com.naaturel.ANN.domain.abstraction.Model;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Consumer;
 
 public class Layer implements Model {
 
     private final List<Neuron> neurons;
+    private final Map<Neuron, Integer> neuronIndex;
 
     public Layer(List<Neuron> neurons) {
         this.neurons = neurons;
+        this.neuronIndex = createNeuronIndex();
     }
 
     @Override
@@ -39,8 +43,8 @@
     }
 
     @Override
-    public int indexOf(Neuron n) {
-        return this.neurons.indexOf(n);
+    public int indexInLayerOf(Neuron n) {
+        return this.neuronIndex.get(n);
     }
 
     @Override
@@ -62,4 +66,14 @@
     public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
         throw new UnsupportedOperationException("Neurons have no connection within the same layer");
     }
+
+    private Map<Neuron, Integer> createNeuronIndex() {
+        Map<Neuron, Integer> res = new HashMap<>();
+        int[] index = {0};
+        this.neurons.forEach(n -> {
+            res.put(n, index[0]++);
+        });
+        return res;
+    }
+
 }
diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java
index cd4272f..1ca4327 100644
--- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java
+++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java
@@ -72,7 +72,7 @@ public class Neuron implements Model {
     }
 
     @Override
-    public int indexOf(Neuron n) {
+    public int indexInLayerOf(Neuron n) {
         return 0;
     }
 
diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java
index afbecff..18921d1 100644
--- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java
+++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java
@@ -20,11 +20,13 @@ public class TrainingPipeline {
 
     private Consumer<TrainingContext> afterEpoch;
     private Predicate<TrainingContext> stopCondition;
-    private GraphVisualizer visualizer;
     private boolean verbose;
     private boolean visualization;
     private boolean timeMeasurement;
 
+    private GraphVisualizer visualizer;
+    private int verboseDelay;
+
     public TrainingPipeline(List<AlgorithmStep> steps) {
         this.steps = new ArrayList<>(steps);
         this.stopCondition = (ctx) -> false;
@@ -47,8 +49,10 @@
         return this;
     }
 
-    public TrainingPipeline withVerbose(boolean enabled) {
+    public TrainingPipeline withVerbose(boolean enabled, int epochDelay) {
+        if(epochDelay <= 0) throw new IllegalArgumentException("Epoch delay cannot be lower than or equal to 0");
         this.verbose = enabled;
+        this.verboseDelay = epochDelay;
         return this;
     }
 
@@ -71,9 +75,10 @@
             this.beforeEpoch.accept(ctx);
             this.executeSteps(ctx);
             this.afterEpoch.accept(ctx);
-            if(this.verbose) {
+            if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
                 System.out.printf("[Global error] : %f\n", ctx.globalLoss);
             }
+            ctx.epoch += 1;
         } while (!this.stopCondition.test(ctx));
 
         if(this.timeMeasurement) {
@@ -94,7 +99,7 @@
             step.run();
         }
 
-        if(this.verbose) {
+        if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
             System.out.printf("Epoch : %d, ", ctx.epoch);
             System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
             System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
@@ -102,7 +107,6 @@
             System.out.printf("loss : %.5f\n", ctx.localLoss);
         }
     }
-        ctx.epoch += 1;
 }
 
 private void visualize(TrainingContext ctx){
diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java
index cbced9e..f928f00 100644
--- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java
+++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java
@@ -18,5 +18,6 @@
         Stream<Float> deltaStream = this.context.deltas.stream();
         this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
         this.context.localLoss /= 2;
+        this.context.globalLoss += this.context.localLoss; // breaks MSE in GradientDescentTraining
     }
 }
diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java
index 85c0a3b..25ec57d 100644
--- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java
+++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java
@@ -12,7 +12,7 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
 
     @Override
     public void run() {
-        this.context.model.forEachOutputNeurons(n -> {
+        this.context.model.forEachNeuron(n -> {
             n.forEachSynapse(syn -> {
                 float lr = context.learningRate;
                 float signal = context.errorSignals.get(n);
diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java
index b19e428..cbb0880 100644
--- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java
+++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java
@@ -2,13 +2,8 @@ package com.naaturel.ANN.implementation.multiLayers;
 
 import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
 import com.naaturel.ANN.domain.model.neuron.Neuron;
-import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 public class ErrorSignalStep implements AlgorithmStep {
@@ -20,23 +15,19 @@ public class ErrorSignalStep implements AlgorithmStep {
 
     @Override
     public void run() {
-        this.context.deltas = new ArrayList<>();
-        this.context.errorSignals = new HashMap<>();
-        this.calculateOutputLayerErrorSignals();
-
-        this.context.model.forEachNeuron(n -> calculateErrorSignalRecursive(n, this.context.errorSignals));
+        this.context.model.forEachNeuron(n -> {
+            calculateErrorSignalRecursive(n, this.context.errorSignals);
+        });
     }
 
     private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
         if (signals.containsKey(n)) return signals.get(n);
 
-        AtomicInteger connectedIndex = new AtomicInteger(0);
+        int neuronIndex = this.context.model.indexInLayerOf(n);
         AtomicReference<Float> signalSum = new AtomicReference<>(0F);
         this.context.model.forEachNeuronConnectedTo(n, connected -> {
-            int neuronIndex = this.context.model.indexOf(n);
             float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
             signalSum.set(signalSum.get() + weightedSignal);
-            connectedIndex.incrementAndGet();
         });
 
         float derivative = n.getActivationFunction().derivative(n.getOutput());
@@ -44,22 +35,4 @@
         float finalSignal = signalSum.get() * derivative;
         signals.put(n, finalSignal);
         return finalSignal;
     }
-
-    private void calculateOutputLayerErrorSignals(){
-        DataSetEntry entry = this.context.currentEntry;
-        List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
-        AtomicInteger index = new AtomicInteger(0);
-
-        this.context.model.forEachOutputNeurons(n -> {
-            float expected = expectations.get(index.get());
-            float predicted = n.getOutput();
-            float output = n.getOutput();
-            float delta = expected - predicted;
-            float signal = delta * n.getActivationFunction().derivative(output);
-
-            this.context.deltas.add(delta);
-            this.context.errorSignals.put(n, signal);
-            index.incrementAndGet();
-        });
-    }
 }
diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java
new file mode 100644
index 0000000..1390d37
--- /dev/null
+++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java
@@ -0,0 +1,39 @@
+package com.naaturel.ANN.implementation.multiLayers;
+
+import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
+import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class OutputLayerErrorStep implements AlgorithmStep {
+
+    private final GradientBackpropagationContext context;
+
+    public OutputLayerErrorStep(GradientBackpropagationContext context){
+        this.context = context;
+    }
+
+    @Override
+    public void run() {
+        context.deltas = new ArrayList<>();
+        DataSetEntry entry = this.context.currentEntry;
+        List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
+        AtomicInteger index = new AtomicInteger(0);
+
+        context.errorSignals = new HashMap<>();
+        this.context.model.forEachOutputNeurons(n -> {
+            float expected = expectations.get(index.get());
+            float predicted = n.getOutput();
+            float output = n.getOutput();
+            float delta = expected - predicted;
+            float signal = delta * n.getActivationFunction().derivative(output);
+
+            this.context.deltas.add(delta);
+            this.context.errorSignals.put(n, signal);
+            index.incrementAndGet();
+        });
+    }
+}
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java
index b45c364..6a4805e 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java
@@ -23,7 +23,7 @@ public class AdalineTraining implements Trainer {
     }
 
     @Override
-    public void train(float learningRate, Model model, DataSet dataset) {
+    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
         AdalineTrainingContext context = new AdalineTrainingContext();
         context.dataset = dataset;
         context.model = model;
@@ -38,11 +38,11 @@
         );
 
         new TrainingPipeline(steps)
-                .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 25)
+                .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
                .beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
                 .afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
                 .withTimeMeasurement(true)
-                .withVerbose(true)
+                .withVerbose(true, 1)
                 .withVisualization(true, new GraphVisualizer())
                 .run(context);
     }
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
index 54a9ff5..833ae01 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
@@ -8,14 +8,14 @@ import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
 import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
 import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
 import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
+import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
-
 import java.util.List;
 
 public class GradientBackpropagationTraining implements Trainer {
 
     @Override
-    public void train(float learningRate, Model model, DataSet dataset) {
+    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
         GradientBackpropagationContext context = new GradientBackpropagationContext();
         context.dataset = dataset;
         context.model = model;
@@ -23,18 +23,20 @@ public class GradientBackpropagationTraining implements Trainer {
 
         List<AlgorithmStep> steps = List.of(
                 new SimplePredictionStep(context),
+                new OutputLayerErrorStep(context),
                 new ErrorSignalStep(context),
                 new BackpropagationCorrectionStep(context),
                 new SquareLossStep(context)
         );
 
         new TrainingPipeline(steps)
-                .beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
-                .afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size())
-                .stopCondition(ctx -> ctx.epoch > 1000000)
-                .withVerbose(false)
+                .stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch)
+                .beforeEpoch(ctx -> {
+                    ctx.globalLoss = 0.0F;
+                })
+                .afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
+                .withVerbose(true, 100)
                 .withTimeMeasurement(true)
                 .run(context);
-
     }
 }
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
index 3a849bc..993fd93 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
@@ -23,7 +23,7 @@ public class GradientDescentTraining implements Trainer {
     }
 
     @Override
-    public void train(float learningRate, Model model, DataSet dataset) {
+    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
         GradientDescentTrainingContext context = new GradientDescentTrainingContext();
         context.dataset = dataset;
         context.model = model;
@@ -38,7 +38,7 @@
         );
 
         new TrainingPipeline(steps)
-                .stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 150)
+                .stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > epoch)
                 .beforeEpoch(ctx -> {
                     GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
                     gdCtx.globalLoss = 0.0F;
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java
index b2cb32a..78fd2c7 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java
@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
     }
 
     @Override
-    public void train(float learningRate, Model model, DataSet dataset) {
+    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
         SimpleTrainingContext context = new SimpleTrainingContext();
         context.dataset = dataset;
         context.model = model;
@@ -32,9 +32,9 @@
 
         TrainingPipeline pipeline = new TrainingPipeline(steps);
 
         pipeline
-                .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 10)
+                .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > epoch)
                 .beforeEpoch(ctx -> ctx.globalLoss = 0)
-                .withVerbose(true)
+                .withVerbose(true, 1)
                 .run(context);
     }
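
[Reviewer note, added after the diff — not part of the original patch]

The patch splits backpropagation into two steps: the new OutputLayerErrorStep seeds
output-layer signals as (expected - predicted) * f'(output), and ErrorSignalStep then
pulls hidden-layer signals backwards through the weights. Below is a minimal standalone
sketch of that arithmetic on plain arrays, for sanity-checking the math only. It does
not use the project's Neuron/Model API; the layer sizes, weights, and learning rate are
illustrative assumptions, and the weight update assumes the standard delta rule
(w += learningRate * signal * input), since the body of BackpropagationCorrectionStep
is only partially visible in the diff.

// BackpropSketch.java -- hypothetical standalone check, not framework code.
public class BackpropSketch {

    // tanh' expressed in terms of the activation value y = tanh(net): 1 - y^2.
    // This matches the patch's convention of calling derivative(n.getOutput()).
    static float tanhDerivative(float y) {
        return 1 - y * y;
    }

    public static void main(String[] args) {
        // Two hidden neurons feeding one output neuron (illustrative values).
        float[] hiddenOut = {0.3F, -0.1F};    // activations of the hidden layer
        float outputOut = 0.2F;               // activation of the output neuron
        float[] wHiddenToOut = {0.5F, -0.4F}; // weight from hidden neuron i to output
        float expected = 1.0F;
        float learningRate = 0.001F;

        // OutputLayerErrorStep: delta = expected - predicted,
        // signal = delta * f'(output).
        float delta = expected - outputOut;
        float outputSignal = delta * tanhDerivative(outputOut);

        // ErrorSignalStep: each hidden neuron sums the signals of the neurons it
        // feeds, weighted by the connecting weight, times its own f'(output).
        float[] hiddenSignal = new float[hiddenOut.length];
        for (int i = 0; i < hiddenOut.length; i++) {
            hiddenSignal[i] = outputSignal * wHiddenToOut[i] * tanhDerivative(hiddenOut[i]);
        }

        // Assumed correction rule: w += learningRate * signal * input,
        // where the input of the output neuron is the hidden activation.
        for (int i = 0; i < wHiddenToOut.length; i++) {
            wHiddenToOut[i] += learningRate * outputSignal * hiddenOut[i];
        }

        System.out.printf("output signal: %f, hidden signals: %f %f%n",
                outputSignal, hiddenSignal[0], hiddenSignal[1]);
    }
}

This also illustrates why the patch renames indexOf to indexInLayerOf:
connected.getWeight(neuronIndex) must select the weight of the synapse coming from
n's position within its own layer, not a network-wide neuron index, which is what
the old flat index map produced.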