Fix multi layer implementation

This commit is contained in:
2026-03-30 21:13:03 +02:00
parent ada01d350b
commit fd97d0853c
15 changed files with 108 additions and 71 deletions

View File

@@ -17,12 +17,12 @@ public class Main {
public static void main(String[] args){
int nbrClass = 1;
int nbrClass = 3;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
int[] neuronPerLayer = new int[]{3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 37, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs();
List<Layer> layers = new ArrayList<>();
@@ -40,7 +40,7 @@ public class Main {
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(syns, bias, new Sigmoid(2));
Neuron n = new Neuron(syns, bias, new TanH());
neurons.add(n);
}
Layer layer = new Layer(neurons);
@@ -50,7 +50,7 @@ public class Main {
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.5F, network, dataset);
trainer.train(0.001F, 1000, network, dataset);
/*GraphVisualizer visualizer = new GraphVisualizer();
@@ -59,7 +59,7 @@ public class Main {
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
}
float min = -2F;
float min = -3F;
float max = 2F;
float step = 0.01F;
for (float x = min; x < max; x+=step){

View File

@@ -10,7 +10,7 @@ import java.util.function.Consumer;
public interface Model {
int synCount();
int neuronCount();
int indexOf(Neuron n);
int indexInLayerOf(Neuron n);
void forEachNeuron(Consumer<Neuron> consumer);
void forEachSynapse(Consumer<Synapse> consumer);
void forEachOutputNeurons(Consumer<Neuron> consumer);

View File

@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
public interface Trainer {
void train(float learningRate, Model model, DataSet dataset);
void train(float learningRate, int epoch, Model model, DataSet dataset);
}

View File

@@ -6,6 +6,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
/**
@@ -13,13 +14,13 @@ import java.util.function.Consumer;
*/
public class FullyConnectedNetwork implements Model {
private final List<Layer> layers;;
private final List<Layer> layers;
private final Map<Neuron, List<Neuron>> connectionMap;
private final Map<Neuron, Integer> neuronIndex;
private final Map<Neuron, Integer> layerIndexByNeuron;
public FullyConnectedNetwork(List<Layer> layers) {
this.layers = layers;
this.connectionMap = this.createConnectionMap();
this.neuronIndex = this.createNeuronIndex();
this.layerIndexByNeuron = this.createNeuronIndex();
}
@Override
@@ -71,8 +72,9 @@ public class FullyConnectedNetwork implements Model {
}
@Override
public int indexOf(Neuron n) {
return this.neuronIndex.get(n);
public int indexInLayerOf(Neuron n) {
    // Look up which layer holds n, then let that layer report n's position inside it.
    Layer owner = this.layers.get(this.layerIndexByNeuron.get(n));
    return owner.indexInLayerOf(n);
}
private Map<Neuron, List<Neuron>> createConnectionMap() {
@@ -83,14 +85,16 @@ public class FullyConnectedNetwork implements Model {
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
}
return res;
}
private Map<Neuron, Integer> createNeuronIndex() {
Map<Neuron, Integer> res = new HashMap<>();
int[] index = {0};
this.layers.forEach(l -> l.forEachNeuron(n -> res.put(n, index[0]++)));
AtomicInteger index = new AtomicInteger(0);
this.layers.forEach(l -> {
l.forEachNeuron(n -> res.put(n, index.get()));
index.incrementAndGet();
});
return res;
}
}

View File

@@ -3,15 +3,19 @@ package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
public class Layer implements Model {
private final List<Neuron> neurons;
private final Map<Neuron, Integer> neuronIndex;
public Layer(List<Neuron> neurons) {
this.neurons = neurons;
this.neuronIndex = createNeuronIndex();
}
@Override
@@ -39,8 +43,8 @@ public class Layer implements Model {
}
@Override
public int indexOf(Neuron n) {
return this.neurons.indexOf(n);
public int indexInLayerOf(Neuron n) {
    // Answered from the lookup map instead of scanning the neuron list.
    final int position = this.neuronIndex.get(n);
    return position;
}
@Override
@@ -62,4 +66,14 @@ public class Layer implements Model {
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
}
/**
 * Builds the lookup table that gives, for each neuron of this layer, its
 * zero-based position in {@code neurons}.
 *
 * @return a new map from neuron to its position in the layer
 */
private Map<Neuron, Integer> createNeuronIndex() {
    Map<Neuron, Integer> positions = new HashMap<>();
    int position = 0;
    for (Neuron neuron : this.neurons) {
        positions.put(neuron, position);
        position++;
    }
    return positions;
}
}

View File

@@ -72,7 +72,7 @@ public class Neuron implements Model {
}
@Override
public int indexOf(Neuron n) {
public int indexInLayerOf(Neuron n) {
// Always answers 0, whatever neuron is passed in.
return 0;
}

View File

@@ -20,11 +20,13 @@ public class TrainingPipeline {
private Consumer<TrainingContext> afterEpoch;
private Predicate<TrainingContext> stopCondition;
private GraphVisualizer visualizer;
private boolean verbose;
private boolean visualization;
private boolean timeMeasurement;
private GraphVisualizer visualizer;
private int verboseDelay;
public TrainingPipeline(List<AlgorithmStep> steps) {
this.steps = new ArrayList<>(steps);
this.stopCondition = (ctx) -> false;
@@ -47,8 +49,10 @@ public class TrainingPipeline {
return this;
}
public TrainingPipeline withVerbose(boolean enabled) {
/**
 * Turns console reporting on or off and sets how often a report is printed.
 *
 * @param enabled whether reports are printed
 * @param epochDelay a report is printed when the epoch counter is a multiple of this
 *                   value; must be strictly positive
 * @return this pipeline, so calls can be chained
 * @throws IllegalArgumentException if {@code epochDelay} is zero or negative
 */
public TrainingPipeline withVerbose(boolean enabled, int epochDelay) {
    // Rejected up front: the delay is later used as a modulo divisor.
    if (epochDelay <= 0) {
        throw new IllegalArgumentException(
                "Epoch delay cannot be lower than or equal to 0 (got " + epochDelay + ")");
    }
    this.verbose = enabled;
    this.verboseDelay = epochDelay;
    return this;
}
@@ -71,9 +75,10 @@ public class TrainingPipeline {
this.beforeEpoch.accept(ctx);
this.executeSteps(ctx);
this.afterEpoch.accept(ctx);
if(this.verbose) {
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
System.out.printf("[Global error] : %f\n", ctx.globalLoss);
}
ctx.epoch += 1;
} while (!this.stopCondition.test(ctx));
if(this.timeMeasurement) {
@@ -94,7 +99,7 @@ public class TrainingPipeline {
step.run();
}
if(this.verbose) {
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
System.out.printf("Epoch : %d, ", ctx.epoch);
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
@@ -102,7 +107,6 @@ public class TrainingPipeline {
System.out.printf("loss : %.5f\n", ctx.localLoss);
}
}
ctx.epoch += 1;
}
private void visualize(TrainingContext ctx){

View File

@@ -18,5 +18,6 @@ public class SquareLossStep implements AlgorithmStep {
Stream<Float> deltaStream = this.context.deltas.stream();
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
this.context.localLoss /= 2;
this.context.globalLoss += this.context.localLoss; // NOTE: broke the MSE in gradientDescentTraining
}
}

View File

@@ -12,7 +12,7 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
@Override
public void run() {
this.context.model.forEachOutputNeurons(n -> {
this.context.model.forEachNeuron(n -> {
n.forEachSynapse(syn -> {
float lr = context.learningRate;
float signal = context.errorSignals.get(n);

View File

@@ -2,13 +2,8 @@ package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
public class ErrorSignalStep implements AlgorithmStep {
@@ -20,23 +15,19 @@ public class ErrorSignalStep implements AlgorithmStep {
@Override
public void run() {
this.context.deltas = new ArrayList<>();
this.context.errorSignals = new HashMap<>();
this.calculateOutputLayerErrorSignals();
this.context.model.forEachNeuron(n -> calculateErrorSignalRecursive(n, this.context.errorSignals));
this.context.model.forEachNeuron(n -> {
calculateErrorSignalRecursive(n, this.context.errorSignals);
});
}
private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
if (signals.containsKey(n)) return signals.get(n);
AtomicInteger connectedIndex = new AtomicInteger(0);
int neuronIndex = this.context.model.indexInLayerOf(n);
AtomicReference<Float> signalSum = new AtomicReference<>(0F);
this.context.model.forEachNeuronConnectedTo(n, connected -> {
int neuronIndex = this.context.model.indexOf(n);
float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
signalSum.set(signalSum.get() + weightedSignal);
connectedIndex.incrementAndGet();
});
float derivative = n.getActivationFunction().derivative(n.getOutput());
@@ -44,22 +35,4 @@ public class ErrorSignalStep implements AlgorithmStep {
signals.put(n, finalSignal);
return finalSignal;
}
private void calculateOutputLayerErrorSignals(){
DataSetEntry entry = this.context.currentEntry;
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
AtomicInteger index = new AtomicInteger(0);
this.context.model.forEachOutputNeurons(n -> {
float expected = expectations.get(index.get());
float predicted = n.getOutput();
float output = n.getOutput();
float delta = expected - predicted;
float signal = delta * n.getActivationFunction().derivative(output);
this.context.deltas.add(delta);
this.context.errorSignals.put(n, signal);
index.incrementAndGet();
});
}
}

View File

@@ -0,0 +1,39 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Training step that computes the delta and the error signal of every output
 * neuron for the entry currently held by the context.
 *
 * <p>Each run replaces {@code context.deltas} and {@code context.errorSignals}
 * with fresh collections before filling them.
 */
public class OutputLayerErrorStep implements AlgorithmStep {

    private final GradientBackpropagationContext context;

    /**
     * @param context shared training state this step reads from and writes to
     */
    public OutputLayerErrorStep(GradientBackpropagationContext context) {
        this.context = context;
    }

    /**
     * For each output neuron, stores {@code expected - output} in
     * {@code context.deltas} and that delta multiplied by the activation
     * derivative at the neuron's output in {@code context.errorSignals}.
     */
    @Override
    public void run() {
        // Start from empty collections so nothing from a previous run is kept.
        this.context.deltas = new ArrayList<>();
        this.context.errorSignals = new HashMap<>();

        DataSetEntry entry = this.context.currentEntry;
        List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);

        // The callback only receives the neuron, so the matching label position is
        // counted here. Assumes output neurons are visited in label order - TODO confirm.
        AtomicInteger index = new AtomicInteger(0);
        this.context.model.forEachOutputNeurons(n -> {
            float expected = expectations.get(index.getAndIncrement());
            // Read once: the prediction and the derivative input are the same value.
            float output = n.getOutput();
            float delta = expected - output;
            float signal = delta * n.getActivationFunction().derivative(output);
            this.context.deltas.add(delta);
            this.context.errorSignals.put(n, signal);
        });
    }
}

View File

@@ -23,7 +23,7 @@ public class AdalineTraining implements Trainer {
}
@Override
public void train(float learningRate, Model model, DataSet dataset) {
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
AdalineTrainingContext context = new AdalineTrainingContext();
context.dataset = dataset;
context.model = model;
@@ -38,11 +38,11 @@ public class AdalineTraining implements Trainer {
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 25)
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
.withTimeMeasurement(true)
.withVerbose(true)
.withVerbose(true, 1)
.withVisualization(true, new GraphVisualizer())
.run(context);
}

View File

@@ -8,14 +8,14 @@ import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List;
public class GradientBackpropagationTraining implements Trainer {
@Override
public void train(float learningRate, Model model, DataSet dataset) {
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
GradientBackpropagationContext context = new GradientBackpropagationContext();
context.dataset = dataset;
context.model = model;
@@ -23,18 +23,20 @@ public class GradientBackpropagationTraining implements Trainer {
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new OutputLayerErrorStep(context),
new ErrorSignalStep(context),
new BackpropagationCorrectionStep(context),
new SquareLossStep(context)
);
new TrainingPipeline(steps)
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
.afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size())
.stopCondition(ctx -> ctx.epoch > 1000000)
.withVerbose(false)
.stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch)
.beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F;
})
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
.withVerbose(true, 100)
.withTimeMeasurement(true)
.run(context);
}
}

View File

@@ -23,7 +23,7 @@ public class GradientDescentTraining implements Trainer {
}
@Override
public void train(float learningRate, Model model, DataSet dataset) {
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = model;
@@ -38,7 +38,7 @@ public class GradientDescentTraining implements Trainer {
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 150)
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > epoch)
.beforeEpoch(ctx -> {
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
gdCtx.globalLoss = 0.0F;

View File

@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
}
@Override
public void train(float learningRate, Model model, DataSet dataset) {
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
SimpleTrainingContext context = new SimpleTrainingContext();
context.dataset = dataset;
context.model = model;
@@ -32,9 +32,9 @@ public class SimpleTraining implements Trainer {
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 10)
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > epoch)
.beforeEpoch(ctx -> ctx.globalLoss = 0)
.withVerbose(true)
.withVerbose(true, 1)
.run(context);
}