diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 8276e19..a285ce4 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -3,6 +3,7 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.implementation.multiLayers.Sigmoid; +import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; @@ -14,33 +15,36 @@ public class Main { public static void main(String[] args){ - int nbrInput = 2; - int nbrClass = 3; + int nbrInput = 25; + int nbrClass = 4; - int nbrLayers = 2; + int[] neuronPerLayer = new int[]{10, nbrClass}; DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv", nbrClass); List layers = new ArrayList<>(); - for(int i = 0; i < nbrLayers; i++){ + for (int i = 0; i < neuronPerLayer.length; i++){ List neurons = new ArrayList<>(); - for (int j=0; j < nbrClass; j++){ + for (int j = 0; j < neuronPerLayer[i]; j++){ + + int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1]; List syns = new ArrayList<>(); - for (int k=0; k < nbrInput; k++){ - syns.add(new Synapse(new Input(0), new Weight(0))); + for (int k=0; k < nbrSyn; k++){ + syns.add(new Synapse(new Input(0), new Weight())); } - Bias bias = new Bias(new Weight(0)); + Bias bias = new Bias(new Weight()); - Neuron n = new Neuron(syns, bias, new Sigmoid(1)); + Neuron n = new Neuron(syns, bias, new TanH()); neurons.add(n); } Layer layer = new Layer(neurons); layers.add(layer); } + FullyConnectedNetwork network = new FullyConnectedNetwork(layers); Trainer trainer = new GradientBackpropagationTraining(); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java index d32e093..891f685 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java @@ -10,6 +10,7 @@ import java.util.function.Consumer; public interface Model { int synCount(); int neuronCount(); + int indexOf(Neuron n); void forEachNeuron(Consumer consumer); void forEachSynapse(Consumer consumer); void forEachOutputNeurons(Consumer consumer); diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java index 8018d5a..48cd74c 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java @@ -1,7 +1,6 @@ package com.naaturel.ANN.domain.model.neuron; import com.naaturel.ANN.domain.abstraction.Model; -import com.naaturel.ANN.domain.abstraction.Network; import java.util.ArrayList; import java.util.HashMap; @@ -16,10 +15,11 @@ public class FullyConnectedNetwork implements Model { private final List layers;; private final Map> connectionMap; - + private final Map neuronIndex; public FullyConnectedNetwork(List layers) { this.layers = layers; this.connectionMap = this.createConnectionMap(); + this.neuronIndex = this.createNeuronIndex(); } @Override @@ -70,6 +70,11 @@ public class FullyConnectedNetwork implements Model { this.connectionMap.get(n).forEach(consumer); } + @Override + public int indexOf(Neuron n) { + return this.neuronIndex.get(n); + } + private Map> createConnectionMap() { Map> res = new HashMap<>(); @@ -81,4 +86,11 @@ public class FullyConnectedNetwork implements Model { return res; } + + private Map createNeuronIndex() { + Map res = new HashMap<>(); + int[] index = {0}; + this.layers.forEach(l -> l.forEachNeuron(n -> res.put(n, index[0]++))); + return res; + } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index 36340e0..67282be 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -38,6 +38,11 @@ public class Layer implements Model { return this.neurons.size(); } + @Override + public int indexOf(Neuron n) { + return this.neurons.indexOf(n); + } + @Override public void forEachNeuron(Consumer consumer) { this.neurons.forEach(consumer); diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java index 9893cf8..ea6cb97 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -11,12 +11,14 @@ public class Neuron implements Model { protected Bias bias; protected ActivationFunction activationFunction; protected Float output; + protected Float weightedSum; public Neuron(List synapses, Bias bias, ActivationFunction func){ this.synapses = synapses; this.bias = bias; this.activationFunction = func; - this.output = 0F; + this.output = null; + this.weightedSum = null; } public void updateBias(Weight weight) { @@ -42,13 +44,22 @@ public class Neuron implements Model { return this.output; } + public float getWeight(int index){ + return this.synapses.get(index).getWeight(); + } + + public float getWeightedSum(){ + return this.weightedSum; + } + public float calculateWeightedSum() { float res = 0; res += this.bias.getWeight() * this.bias.getInput(); for(Synapse syn : this.synapses){ res += syn.getWeight() * syn.getInput(); } - return res; + this.weightedSum = res; + return this.weightedSum; } @Override @@ -61,6 +72,11 @@ public class Neuron implements Model { return 1; } + @Override + public int indexOf(Neuron n) { + return 0; + } + @Override public List predict(List inputs) { this.setInputs(inputs); diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index 7f40cbc..de70dd6 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -2,7 +2,6 @@ package com.naaturel.ANN.domain.model.training; import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingContext; -import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java new file mode 100644 index 0000000..85c0a3b --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java @@ -0,0 +1,24 @@ +package com.naaturel.ANN.implementation.multiLayers; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; + +public class BackpropagationCorrectionStep implements AlgorithmStep { + + private GradientBackpropagationContext context; + + public BackpropagationCorrectionStep(GradientBackpropagationContext context){ + this.context = context; + } + + @Override + public void run() { + this.context.model.forEachOutputNeurons(n -> { + n.forEachSynapse(syn -> { + float lr = context.learningRate; + float signal = context.errorSignals.get(n); + float newWeight = syn.getWeight() + (lr * signal * syn.getInput()); + syn.setWeight(newWeight); + }); + }); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java new file mode 100644 index 0000000..b19e428 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java @@ -0,0 +1,65 @@ +package com.naaturel.ANN.implementation.multiLayers; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; +import com.naaturel.ANN.domain.model.neuron.Neuron; +import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +public class ErrorSignalStep implements AlgorithmStep { + + private GradientBackpropagationContext context; + public ErrorSignalStep(GradientBackpropagationContext context) { + this.context = context; + } + + @Override + public void run() { + this.context.deltas = new ArrayList<>(); + this.context.errorSignals = new HashMap<>(); + this.calculateOutputLayerErrorSignals(); + + this.context.model.forEachNeuron(n -> calculateErrorSignalRecursive(n, this.context.errorSignals)); + } + + private float calculateErrorSignalRecursive(Neuron n, Map signals) { + if (signals.containsKey(n)) return signals.get(n); + + AtomicInteger connectedIndex = new AtomicInteger(0); + AtomicReference signalSum = new AtomicReference<>(0F); + this.context.model.forEachNeuronConnectedTo(n, connected -> { + int neuronIndex = this.context.model.indexOf(n); + float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex); + signalSum.set(signalSum.get() + weightedSignal); + connectedIndex.incrementAndGet(); + }); + + float derivative = n.getActivationFunction().derivative(n.getOutput()); + float finalSignal = derivative * signalSum.get(); + signals.put(n, finalSignal); + return finalSignal; + } + + private void calculateOutputLayerErrorSignals(){ + DataSetEntry entry = this.context.currentEntry; + List expectations = this.context.dataset.getLabelsAsFloat(entry); + AtomicInteger index = new AtomicInteger(0); + + this.context.model.forEachOutputNeurons(n -> { + float expected = expectations.get(index.get()); + float predicted = n.getOutput(); + float output = n.getOutput(); + float delta = expected - predicted; + float signal = delta * n.getActivationFunction().derivative(output); + + this.context.deltas.add(delta); + this.context.errorSignals.put(n, signal); + index.incrementAndGet(); + }); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java index b2c22c7..f37d65a 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java @@ -1,15 +1,14 @@ package com.naaturel.ANN.implementation.multiLayers; import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.domain.model.neuron.Neuron; -import java.util.ArrayList; -import java.util.List; +import java.util.Map; public class GradientBackpropagationContext extends TrainingContext { - public List hiddenDeltas; + public Map errorSignals; public GradientBackpropagationContext(){ } - } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStep.java deleted file mode 100644 index dd0b552..0000000 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStep.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.naaturel.ANN.implementation.multiLayers; - -import com.naaturel.ANN.domain.abstraction.AlgorithmStep; -import com.naaturel.ANN.domain.model.neuron.Neuron; - -public class GradientBackpropagationStep implements AlgorithmStep { - - private GradientBackpropagationContext context; - public GradientBackpropagationStep(GradientBackpropagationContext context) { - this.context = context; - } - - @Override - public void run() { - - - } - - - private float calculateDeltaRecursive(Neuron n){ - - } - -} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index 20fc04d..0c3df4f 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -4,9 +4,10 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.model.training.TrainingPipeline; +import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep; +import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep; import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext; -import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationStep; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; +import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.infrastructure.dataset.DataSet; @@ -19,17 +20,19 @@ public class GradientBackpropagationTraining implements Trainer { GradientBackpropagationContext context = new GradientBackpropagationContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.001F; + context.learningRate = 0.1F; List steps = List.of( new SimplePredictionStep(context), - new SimpleDeltaStep(context), - new GradientBackpropagationStep(context) + new ErrorSignalStep(context), + new BackpropagationCorrectionStep(context), + new SquareLossStep(context) ); new TrainingPipeline(steps) - .stopCondition(ctx -> false) + .stopCondition(ctx -> ctx.epoch == 250) .withVerbose(true) + .withTimeMeasurement(true) .run(context); } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 9f8abab..1fc558b 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -11,7 +11,6 @@ import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrection import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; -import com.naaturel.ANN.implementation.training.steps.*; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import java.util.ArrayList; diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index ee59129..6e73eef 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -6,7 +6,6 @@ import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; -import com.naaturel.ANN.implementation.training.steps.*; import java.util.List;