From 5ddf6dc580bf24ad6ea4341f554501ec74d561c3 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Wed, 1 Apr 2026 22:48:06 +0200
Subject: [PATCH] Reworked synapses data structure

---
 src/main/java/com/naaturel/ANN/Main.java           |  4 +-
 .../ANN/domain/abstraction/Model.java              |  2 +-
 .../model/neuron/FullyConnectedNetwork.java        |  7 --
 .../ANN/domain/model/neuron/Layer.java             |  4 +-
 .../ANN/domain/model/neuron/Neuron.java            | 86 ++++++++-----------
 .../model/training/TrainingPipeline.java           |  6 +-
 .../GradientDescentCorrectionStrategy.java         | 14 +--
 .../GradientDescentErrorStrategy.java              |  7 +-
 .../BackpropagationCorrectionStep.java             | 18 ++--
 .../multiLayers/OutputLayerErrorStep.java          |  2 +-
 .../SimpleCorrectionStep.java                      | 13 ++-
 .../GradientBackpropagationTraining.java           |  4 +-
 .../training/GradientDescentTraining.java          |  4 +-
 13 files changed, 77 insertions(+), 94 deletions(-)

diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java
index 9cc16bc..85c5fa0 100644
--- a/src/main/java/com/naaturel/ANN/Main.java
+++ b/src/main/java/com/naaturel/ANN/Main.java
@@ -26,7 +26,7 @@ public class Main {
         DataSet dataset = new DatasetExtractor()
                 .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
 
-        int[] neuronPerLayer = new int[]{10, 5, 10, dataset.getNbrLabels()};
+        int[] neuronPerLayer = new int[]{100, 100, 50, dataset.getNbrLabels()};
         int nbrInput = dataset.getNbrInputs();
 
         FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
@@ -34,7 +34,7 @@ public class Main {
         System.out.println(network.synCount());
 
         Trainer trainer = new GradientBackpropagationTraining();
-        trainer.train(0.01F, 2000, network, dataset);
+        trainer.train(0.001F, 2000, network, dataset);
 
         plotGraph(dataset, network);
     }
diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
index b9899e1..9d13be9 100644
--- a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
+++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
@@ -12,7 +12,7 @@ public interface Model {
     int neuronCount();
     int indexInLayerOf(Neuron n);
     void forEachNeuron(Consumer<Neuron> consumer);
-    void forEachSynapse(Consumer<Synapse> consumer);
+    //void forEachSynapse(Consumer<Synapse> consumer);
     void forEachOutputNeurons(Consumer<Neuron> consumer);
     void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
     List<Float> predict(List<Float> inputs);
diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java
index dfd51ef..d7d9407 100644
--- a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java
+++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java
@@ -51,13 +51,6 @@ public class FullyConnectedNetwork implements Model {
         return res;
     }
 
-    @Override
-    public void forEachSynapse(Consumer<Synapse> consumer) {
-        for(Layer l : this.layers){
-            l.forEachSynapse(consumer);
-        }
-    }
-
     @Override
     public void forEachNeuron(Consumer<Neuron> consumer) {
         for(Layer l : this.layers){
diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java
index 3d3039f..a214296 100644
--- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java
+++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java
@@ -54,12 +54,12 @@ public class Layer implements Model {
         }
     }
 
-    @Override
+    /*@Override
     public void forEachSynapse(Consumer<Synapse> consumer) {
         for (Neuron n : this.neurons){
             n.forEachSynapse(consumer);
         }
-    }
+    }*/
 
     @Override
     public void forEachOutputNeurons(Consumer<Neuron> consumer) {
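With forEachSynapse removed from Model above, callers that used to stream Synapse objects now walk neurons and index flat per-neuron arrays. A minimal sketch of the replacement pattern, assuming only Model.forEachNeuron and the per-neuron accessors introduced below (synCount, getWeight, setWeight); applyCorrections is a hypothetical helper, not part of this commit:

    // Old style (removed): model.forEachSynapse(syn -> syn.setWeight(...));
    // New style: one flat index across all neurons; each neuron's slot 0
    // is its bias, so the bias gets updated like any other weight.
    static void applyCorrections(Model model, float[] corrections) {
        int[] idx = {0}; // mutable counter usable from inside the lambda
        model.forEachNeuron(n -> {
            for (int i = 0; i < n.synCount(); i++) {
                n.setWeight(i, n.getWeight(i) + corrections[idx[0]++]);
            }
        });
    }

The flat index only stays aligned with buffers such as correctorTerms if every pass enumerates neurons in the same forEachNeuron order.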
diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java
index f5356a7..297f6dd 100644
--- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java
+++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java
@@ -7,42 +7,39 @@ import java.util.function.Consumer;
 
 public class Neuron implements Model {
 
-    protected int id;
-    protected final Synapse[] synapses;
-    protected Bias bias;
-    protected ActivationFunction activationFunction;
-    protected Float output;
-    protected Float weightedSum;
-    protected final float[] weights;
-    protected final float[] inputs;
+    private final int id;
+    private float output;
+    private final float[] weights;
+    private final float[] inputs;
+    private final ActivationFunction activationFunction;
 
     public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
         this.id = id;
-        this.synapses = synapses;
-        this.bias = bias;
         this.activationFunction = func;
-        this.output = null;
-        this.weightedSum = null;
-        weights = new float[synapses.length];
-        inputs = new float[synapses.length];
-    }
+        weights = new float[synapses.length+1]; //takes the bias into account
+        inputs = new float[synapses.length+1]; //takes the bias into account
 
-    public void updateBias(Weight weight) {
-        this.bias.setWeight(weight.getValue());
-    }
-
-    public void updateWeight(int index, Weight weight) {
-        this.synapses[index].setWeight(weight.getValue());
-    }
-
-    protected void setInputs(List inputs){
-        for(int i = 0; i < inputs.size() && i < synapses.length; i++){
-            Synapse syn = this.synapses[i];
-            syn.setInput(inputs.get(i));
+        weights[0] = bias.getWeight();
+        inputs[0] = bias.getInput();
+        for (int i = 0; i < synapses.length; i++){
+            weights[i+1] = synapses[i].getWeight();
+            inputs[i+1] = synapses[i].getInput();
         }
     }
 
+    public void setWeight(int index, float value) {
+        this.weights[index] = value;
+    }
+
+    public float getWeight(int index) {
+        return this.weights[index];
+    }
+
+    public float getInput(int index) {
+        return this.inputs[index];
+    }
+
     public ActivationFunction getActivationFunction(){
         return this.activationFunction;
     }
@@ -51,21 +48,13 @@ public class Neuron implements Model {
         return this.output;
     }
 
-    public float getWeight(int index){
-        return this.synapses[index].getWeight();
-    }
-
-    public float getWeightedSum(){
-        return this.weightedSum;
-    }
-
     public float calculateWeightedSum() {
-        this.weightedSum = 0F;
-        this.weightedSum += this.bias.getWeight() * this.bias.getInput();
-        for(Synapse syn : this.synapses){
-            this.weightedSum += syn.getWeight() * syn.getInput();
+        int count = synCount();
+        float weightedSum = 0F;
+        for (int i = 0; i < count; i++){
+            weightedSum += weights[i] * inputs[i];
         }
-        return this.weightedSum;
+        return weightedSum;
     }
 
     public int getId(){
@@ -74,7 +63,7 @@ public class Neuron implements Model {
 
     @Override
     public int synCount() {
-        return this.synapses.length+1; //take the bias into account
+        return this.weights.length;
     }
 
     @Override
@@ -99,14 +88,6 @@ public class Neuron implements Model {
         consumer.accept(this);
     }
 
-    @Override
-    public void forEachSynapse(Consumer<Synapse> consumer) {
-        consumer.accept(this.bias);
-        for (Synapse syn : this.synapses){
-            consumer.accept(syn);
-        }
-    }
-
     @Override
     public void forEachOutputNeurons(Consumer<Neuron> consumer) {
         consumer.accept(this);
@@ -116,4 +97,11 @@ public class Neuron implements Model {
     public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
         throw new UnsupportedOperationException("Neurons have no connection with themselves");
     }
+
+    private void setInputs(List values){
+        for(int i = 0; i < values.size(); i++){
+            inputs[i+1] = values.get(i).getValue();
+        }
+    }
+
 }
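The reworked Neuron is easiest to read as two parallel arrays: slot 0 carries the bias weight and bias input, slots 1..n carry the former Synapse values, and calculateWeightedSum() reduces to one dot product. A standalone sketch with made-up values (the bias input of 1.0F is an assumption; the patch copies whatever bias.getInput() returns):

    public class NeuronLayoutSketch {
        public static void main(String[] args) {
            // Slot 0 is the bias; slots 1..n mirror the old Synapse objects.
            float[] weights = {0.5F, 0.2F, -0.3F};
            float[] inputs  = {1.0F, 0.7F,  0.4F};

            float weightedSum = 0F;
            for (int i = 0; i < weights.length; i++) {
                weightedSum += weights[i] * inputs[i];
            }
            // 0.5*1.0 + 0.2*0.7 + (-0.3)*0.4 = 0.52
            System.out.println(weightedSum);
        }
    }

Note that synCount() now returns weights.length, which already includes the bias slot, where the old version returned synapses.length+1.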
diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java
index c8b9b81..9e7a367 100644
--- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java
+++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java
@@ -86,7 +86,7 @@ public class TrainingPipeline {
             System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
         }
 
-        if(this.visualization) this.visualize(ctx);
+        //if(this.visualization) this.visualize(ctx);
     }
 
     private void executeSteps(TrainingContext ctx){
@@ -109,7 +109,7 @@ public class TrainingPipeline {
         }
     }
 
-    private void visualize(TrainingContext ctx){
+    /*private void visualize(TrainingContext ctx){
         AtomicInteger neuronIndex = new AtomicInteger(0);
         ctx.model.forEachNeuron(n -> {
             List<Float> weights = new ArrayList<>();
@@ -129,6 +129,6 @@ public class TrainingPipeline {
             i++;
         }
         this.visualizer.buildLineGraph();
-    }
+    }*/
 
 }
diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java
index f76726a..4b58a68 100644
--- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java
+++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java
@@ -14,12 +14,14 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStep {
 
     @Override
     public void run() {
-        AtomicInteger i = new AtomicInteger(0);
-        context.model.forEachSynapse(syn -> {
-            float corrector = context.correctorTerms.get(i.get());
-            float c = syn.getWeight() + corrector;
-            syn.setWeight(c);
-            i.incrementAndGet();
+        int[] globalSynIndex = {0};
+        context.model.forEachNeuron(n -> {
+            for(int i = 0; i < n.synCount(); i++){
+                float corrector = context.correctorTerms.get(globalSynIndex[0]);
+                float c = n.getWeight(i) + corrector;
+                n.setWeight(i, c);
+                globalSynIndex[0]++;
+            }
         });
     }
 }
diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java
index 5b36a4d..2119421 100644
--- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java
+++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java
@@ -22,13 +22,12 @@ public class GradientDescentErrorStrategy implements AlgorithmStep {
 
         context.model.forEachNeuron(neuron -> {
             float correspondingDelta = context.deltas[neuronIndex.get()];
-            neuron.forEachSynapse(syn -> {
+            for(int i = 0; i < neuron.synCount(); i++){
                 float corrector = context.correctorTerms.get(synIndex.get());
-                corrector += context.learningRate * correspondingDelta * syn.getInput();
+                corrector += context.learningRate * correspondingDelta * neuron.getInput(i);
                 context.correctorTerms.set(synIndex.get(), corrector);
                 synIndex.incrementAndGet();
-            });
-
+            }
             neuronIndex.incrementAndGet();
         });
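The two gradient-descent steps above share an implicit contract: the error step accumulates into correctorTerms with a running synapse index, and the correction step later consumes the list with its own running index, so both must enumerate weights in identical forEachNeuron order. A condensed sketch of the accumulation side (hypothetical signature; the patch keeps these values in a List<Float> on the context rather than a float array):

    // correctors[j] accumulates lr * delta(neuron) * input(slot) for the
    // j-th weight in global enumeration order, bias slots included.
    static void accumulate(Model model, float[] deltas, float[] correctors, float lr) {
        int[] neuron = {0};
        int[] syn = {0};
        model.forEachNeuron(n -> {
            float delta = deltas[neuron[0]++];
            for (int i = 0; i < n.synCount(); i++) {
                correctors[syn[0]++] += lr * delta * n.getInput(i);
            }
        });
    }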
diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java
index d8034d8..584330c 100644
--- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java
+++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java
@@ -21,11 +21,11 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
         int[] synIndex = {0};
         context.model.forEachNeuron(n -> {
             float signal = context.errorSignals[n.getId()];
-            n.forEachSynapse(syn -> {
-                inputs[synIndex[0]] = syn.getInput();
+            for (int i = 0; i < n.synCount(); i++){
+                inputs[synIndex[0]] = n.getInput(i);
                 signals[synIndex[0]] = signal;
                 synIndex[0]++;
-            });
+            }
         });
 
         float lr = context.learningRate;
@@ -44,13 +44,13 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
     }
 
     private void syncWeights() {
-        int[] i = {0};
+        int[] synIndex = {0};
         context.model.forEachNeuron(n -> {
-            n.forEachSynapse(syn -> {
-                syn.setWeight(syn.getWeight() + context.correctionBuffer[i[0]]);
-                context.correctionBuffer[i[0]] = 0f;
-                i[0]++;
-            });
+            for (int i = 0; i < n.synCount(); i++) {
+                n.setWeight(i, n.getWeight(i) + context.correctionBuffer[synIndex[0]]);
+                context.correctionBuffer[synIndex[0]] = 0f;
+                synIndex[0]++;
+            }
         });
     }
 }
\ No newline at end of file
diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java
index a51f0aa..06b2234 100644
--- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java
+++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java
@@ -13,7 +13,7 @@ public class OutputLayerErrorStep implements AlgorithmStep {
 
     public OutputLayerErrorStep(GradientBackpropagationContext context){
         this.context = context;
-        this.expectations = new float[context.model.neuronCount()];
+        this.expectations = new float[context.dataset.getNbrLabels()];
     }
 
     @Override
diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java
index a726318..1d5848b 100644
--- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java
+++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java
@@ -18,17 +18,16 @@ public class SimpleCorrectionStep implements AlgorithmStep {
     public void run() {
         if(context.expectations.equals(context.predictions)) return;
         AtomicInteger neuronIndex = new AtomicInteger(0);
-        AtomicInteger synIndex = new AtomicInteger(0);
 
         context.model.forEachNeuron(neuron -> {
             float correspondingDelta = context.deltas[neuronIndex.get()];
-            neuron.forEachSynapse(syn -> {
-                float currentW = syn.getWeight();
-                float currentInput = syn.getInput();
+
+            for(int i = 0; i < neuron.synCount(); i++){
+                float currentW = neuron.getWeight(i);
+                float currentInput = neuron.getInput(i);
                 float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
-                syn.setWeight(newValue);
-                synIndex.incrementAndGet();
-            });
+                neuron.setWeight(i, newValue);
+            }
             neuronIndex.incrementAndGet();
         });
     }
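SimpleCorrectionStep applies the classic perceptron delta rule, now through the indexed accessors instead of Synapse objects. A worked single-weight example with assumed values:

    public class DeltaRuleSketch {
        public static void main(String[] args) {
            float learningRate = 0.01F; // assumed
            float delta  = 1.0F;        // expected - predicted for the neuron
            float weight = 0.25F;       // neuron.getWeight(i)
            float input  = 0.4F;        // neuron.getInput(i)

            // Same update the step performs via neuron.setWeight(i, ...):
            weight += learningRate * delta * input;
            System.out.println(weight); // 0.25 + 0.01 * 1.0 * 0.4 = 0.254
        }
    }

Dropping the unused synIndex counter is safe here because this update indexes each neuron locally and never needs a global weight position.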
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
index 2f7c80c..8452321 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
@@ -17,7 +17,7 @@ public class GradientBackpropagationTraining implements Trainer {
     @Override
     public void train(float learningRate, int epoch, Model model, DataSet dataset) {
         GradientBackpropagationContext context =
-                new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()/3);
+                new GradientBackpropagationContext(model, dataset, learningRate, dataset.size());
 
         List<AlgorithmStep> steps = List.of(
                 new SimplePredictionStep(context),
@@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer {
                 .afterEpoch(ctx -> {
                     ctx.globalLoss /= dataset.size();
                 })
-                .withVerbose(true,epoch/10)
+                .withVerbose(false,epoch/10)
                 .withTimeMeasurement(true)
                 .run(context);
     }
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
index 5296936..e2d94e2 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
@@ -41,7 +41,9 @@ public class GradientDescentTraining implements Trainer {
                     GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
                     gdCtx.globalLoss = 0.0F;
                     gdCtx.correctorTerms.clear();
-                    gdCtx.model.forEachSynapse(syn -> gdCtx.correctorTerms.add(0F));
+                    for(int i = 0; i < gdCtx.model.synCount(); i++){
+                        gdCtx.correctorTerms.add(0F);
+                    }
                 })
                 .afterEpoch(ctx -> {
                     context.globalLoss /= context.dataset.size();
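The beforeEpoch reset above now sizes the corrector list from Model.synCount(), which after this patch already counts one slot per weight, each neuron's bias slot included. A sketch of the same reset as a hypothetical standalone helper (uses java.util.ArrayList / java.util.List):

    // One zeroed corrector per weight, in the same global order the
    // error and correction steps use when they walk forEachNeuron.
    static List<Float> freshCorrectors(Model model) {
        List<Float> correctors = new ArrayList<>(model.synCount());
        for (int i = 0; i < model.synCount(); i++) {
            correctors.add(0F);
        }
        return correctors;
    }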