Implement multi layer

This commit is contained in:
2026-03-30 13:38:44 +02:00
parent b36a900f87
commit aed78fe9d2
13 changed files with 153 additions and 51 deletions

View File

@@ -3,6 +3,7 @@ package com.naaturel.ANN;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
@@ -14,33 +15,36 @@ public class Main {
public static void main(String[] args){ public static void main(String[] args){
int nbrInput = 2; int nbrInput = 25;
int nbrClass = 3; int nbrClass = 4;
int nbrLayers = 2; int[] neuronPerLayer = new int[]{10, nbrClass};
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv", nbrClass);
List<Layer> layers = new ArrayList<>(); List<Layer> layers = new ArrayList<>();
for(int i = 0; i < nbrLayers; i++){ for (int i = 0; i < neuronPerLayer.length; i++){
List<Neuron> neurons = new ArrayList<>(); List<Neuron> neurons = new ArrayList<>();
for (int j=0; j < nbrClass; j++){ for (int j = 0; j < neuronPerLayer[i]; j++){
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
List<Synapse> syns = new ArrayList<>(); List<Synapse> syns = new ArrayList<>();
for (int k=0; k < nbrInput; k++){ for (int k=0; k < nbrSyn; k++){
syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight()));
} }
Bias bias = new Bias(new Weight(0)); Bias bias = new Bias(new Weight());
Neuron n = new Neuron(syns, bias, new Sigmoid(1)); Neuron n = new Neuron(syns, bias, new TanH());
neurons.add(n); neurons.add(n);
} }
Layer layer = new Layer(neurons); Layer layer = new Layer(neurons);
layers.add(layer); layers.add(layer);
} }
FullyConnectedNetwork network = new FullyConnectedNetwork(layers); FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();

View File

@@ -10,6 +10,7 @@ import java.util.function.Consumer;
public interface Model { public interface Model {
int synCount(); int synCount();
int neuronCount(); int neuronCount();
int indexOf(Neuron n);
void forEachNeuron(Consumer<Neuron> consumer); void forEachNeuron(Consumer<Neuron> consumer);
void forEachSynapse(Consumer<Synapse> consumer); void forEachSynapse(Consumer<Synapse> consumer);
void forEachOutputNeurons(Consumer<Neuron> consumer); void forEachOutputNeurons(Consumer<Neuron> consumer);

View File

@@ -1,7 +1,6 @@
package com.naaturel.ANN.domain.model.neuron; package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Network;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@@ -16,10 +15,11 @@ public class FullyConnectedNetwork implements Model {
private final List<Layer> layers;; private final List<Layer> layers;;
private final Map<Neuron, List<Neuron>> connectionMap; private final Map<Neuron, List<Neuron>> connectionMap;
private final Map<Neuron, Integer> neuronIndex;
public FullyConnectedNetwork(List<Layer> layers) { public FullyConnectedNetwork(List<Layer> layers) {
this.layers = layers; this.layers = layers;
this.connectionMap = this.createConnectionMap(); this.connectionMap = this.createConnectionMap();
this.neuronIndex = this.createNeuronIndex();
} }
@Override @Override
@@ -70,6 +70,11 @@ public class FullyConnectedNetwork implements Model {
this.connectionMap.get(n).forEach(consumer); this.connectionMap.get(n).forEach(consumer);
} }
@Override
public int indexOf(Neuron n) {
return this.neuronIndex.get(n);
}
private Map<Neuron, List<Neuron>> createConnectionMap() { private Map<Neuron, List<Neuron>> createConnectionMap() {
Map<Neuron, List<Neuron>> res = new HashMap<>(); Map<Neuron, List<Neuron>> res = new HashMap<>();
@@ -81,4 +86,11 @@ public class FullyConnectedNetwork implements Model {
return res; return res;
} }
private Map<Neuron, Integer> createNeuronIndex() {
Map<Neuron, Integer> res = new HashMap<>();
int[] index = {0};
this.layers.forEach(l -> l.forEachNeuron(n -> res.put(n, index[0]++)));
return res;
}
} }

View File

@@ -38,6 +38,11 @@ public class Layer implements Model {
return this.neurons.size(); return this.neurons.size();
} }
@Override
public int indexOf(Neuron n) {
return this.neurons.indexOf(n);
}
@Override @Override
public void forEachNeuron(Consumer<Neuron> consumer) { public void forEachNeuron(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer); this.neurons.forEach(consumer);

View File

@@ -11,12 +11,14 @@ public class Neuron implements Model {
protected Bias bias; protected Bias bias;
protected ActivationFunction activationFunction; protected ActivationFunction activationFunction;
protected Float output; protected Float output;
protected Float weightedSum;
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){ public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
this.synapses = synapses; this.synapses = synapses;
this.bias = bias; this.bias = bias;
this.activationFunction = func; this.activationFunction = func;
this.output = 0F; this.output = null;
this.weightedSum = null;
} }
public void updateBias(Weight weight) { public void updateBias(Weight weight) {
@@ -42,13 +44,22 @@ public class Neuron implements Model {
return this.output; return this.output;
} }
public float getWeight(int index){
return this.synapses.get(index).getWeight();
}
public float getWeightedSum(){
return this.weightedSum;
}
public float calculateWeightedSum() { public float calculateWeightedSum() {
float res = 0; float res = 0;
res += this.bias.getWeight() * this.bias.getInput(); res += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : this.synapses){ for(Synapse syn : this.synapses){
res += syn.getWeight() * syn.getInput(); res += syn.getWeight() * syn.getInput();
} }
return res; this.weightedSum = res;
return this.weightedSum;
} }
@Override @Override
@@ -61,6 +72,11 @@ public class Neuron implements Model {
return 1; return 1;
} }
@Override
public int indexOf(Neuron n) {
return 0;
}
@Override @Override
public List<Float> predict(List<Input> inputs) { public List<Float> predict(List<Input> inputs) {
this.setInputs(inputs); this.setInputs(inputs);

View File

@@ -2,7 +2,6 @@ package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;

View File

@@ -0,0 +1,24 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
public class BackpropagationCorrectionStep implements AlgorithmStep {
private GradientBackpropagationContext context;
public BackpropagationCorrectionStep(GradientBackpropagationContext context){
this.context = context;
}
@Override
public void run() {
this.context.model.forEachOutputNeurons(n -> {
n.forEachSynapse(syn -> {
float lr = context.learningRate;
float signal = context.errorSignals.get(n);
float newWeight = syn.getWeight() + (lr * signal * syn.getInput());
syn.setWeight(newWeight);
});
});
}
}

View File

@@ -0,0 +1,65 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
public class ErrorSignalStep implements AlgorithmStep {
private GradientBackpropagationContext context;
public ErrorSignalStep(GradientBackpropagationContext context) {
this.context = context;
}
@Override
public void run() {
this.context.deltas = new ArrayList<>();
this.context.errorSignals = new HashMap<>();
this.calculateOutputLayerErrorSignals();
this.context.model.forEachNeuron(n -> calculateErrorSignalRecursive(n, this.context.errorSignals));
}
private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
if (signals.containsKey(n)) return signals.get(n);
AtomicInteger connectedIndex = new AtomicInteger(0);
AtomicReference<Float> signalSum = new AtomicReference<>(0F);
this.context.model.forEachNeuronConnectedTo(n, connected -> {
int neuronIndex = this.context.model.indexOf(n);
float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
signalSum.set(signalSum.get() + weightedSignal);
connectedIndex.incrementAndGet();
});
float derivative = n.getActivationFunction().derivative(n.getOutput());
float finalSignal = derivative * signalSum.get();
signals.put(n, finalSignal);
return finalSignal;
}
private void calculateOutputLayerErrorSignals(){
DataSetEntry entry = this.context.currentEntry;
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
AtomicInteger index = new AtomicInteger(0);
this.context.model.forEachOutputNeurons(n -> {
float expected = expectations.get(index.get());
float predicted = n.getOutput();
float output = n.getOutput();
float delta = expected - predicted;
float signal = delta * n.getActivationFunction().derivative(output);
this.context.deltas.add(delta);
this.context.errorSignals.put(n, signal);
index.incrementAndGet();
});
}
}

View File

@@ -1,15 +1,14 @@
package com.naaturel.ANN.implementation.multiLayers; package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import java.util.ArrayList; import java.util.Map;
import java.util.List;
public class GradientBackpropagationContext extends TrainingContext { public class GradientBackpropagationContext extends TrainingContext {
public List<Float> hiddenDeltas; public Map<Neuron, Float> errorSignals;
public GradientBackpropagationContext(){ public GradientBackpropagationContext(){
} }
} }

View File

@@ -1,24 +0,0 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.model.neuron.Neuron;
public class GradientBackpropagationStep implements AlgorithmStep {
private GradientBackpropagationContext context;
public GradientBackpropagationStep(GradientBackpropagationContext context) {
this.context = context;
}
@Override
public void run() {
}
private float calculateDeltaRecursive(Neuron n){
}
}

View File

@@ -4,9 +4,10 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext; import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationStep; import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
@@ -19,17 +20,19 @@ public class GradientBackpropagationTraining implements Trainer {
GradientBackpropagationContext context = new GradientBackpropagationContext(); GradientBackpropagationContext context = new GradientBackpropagationContext();
context.dataset = dataset; context.dataset = dataset;
context.model = model; context.model = model;
context.learningRate = 0.001F; context.learningRate = 0.1F;
List<AlgorithmStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context), new SimplePredictionStep(context),
new SimpleDeltaStep(context), new ErrorSignalStep(context),
new GradientBackpropagationStep(context) new BackpropagationCorrectionStep(context),
new SquareLossStep(context)
); );
new TrainingPipeline(steps) new TrainingPipeline(steps)
.stopCondition(ctx -> false) .stopCondition(ctx -> ctx.epoch == 250)
.withVerbose(true) .withVerbose(true)
.withTimeMeasurement(true)
.run(context); .run(context);
} }

View File

@@ -11,7 +11,6 @@ import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrection
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep; import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.util.ArrayList; import java.util.ArrayList;

View File

@@ -6,7 +6,6 @@ import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.List; import java.util.List;