diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 348e93a..966a1e9 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -1,5 +1,7 @@ package com.naaturel.ANN; +import com.naaturel.ANN.domain.abstraction.Model; +import com.naaturel.ANN.domain.abstraction.Network; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.implementation.gradientDescent.Linear; @@ -22,11 +24,21 @@ public class Main { int nbrClass = 1; DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass); - int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()}; + int[] neuronPerLayer = new int[]{50, 50, 50, dataset.getNbrLabels()}; int nbrInput = dataset.getNbrInputs(); + FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput); + + Trainer trainer = new GradientBackpropagationTraining(); + trainer.train(0.01F, 2000, network, dataset); + + //plotGraph(dataset, network); + + } + + private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ List layers = new ArrayList<>(); for (int i = 0; i < neuronPerLayer.length; i++){ @@ -49,26 +61,27 @@ public class Main { layers.add(layer); } - FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0])); - - Trainer trainer = new GradientBackpropagationTraining(); - trainer.train(0.0005F, 15000, network, dataset); + return new FullyConnectedNetwork(layers.toArray(new Layer[0])); + } + private static void plotGraph(DataSet dataset, Model network){ GraphVisualizer visualizer = new GraphVisualizer(); for (DataSetEntry entry : dataset) { List label = dataset.getLabelsAsFloat(entry); - visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); + label.forEach(l -> { + visualizer.addPoint("Label " + l, + entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); + }); } - float min = 0F; - float max = 15F; - float step = 0.03F; + float min = -5F; + float max = 5F; + float step = 0.01F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ - float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); - float predSeries = prediction > 0.5F ? 1 : -1; - visualizer.addPoint(Float.toString(predSeries), x, y); + List predictions = network.predict(List.of(new Input(x), new Input(y))); + visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java index 88a017c..dcedbd5 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java @@ -25,8 +25,4 @@ public class Synapse { public void setWeight(float value){ this.weight.setValue(value); } - - - - } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java index ed999d9..980aaa3 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java @@ -2,6 +2,8 @@ package com.naaturel.ANN.implementation.multiLayers; import com.naaturel.ANN.domain.abstraction.AlgorithmStep; +import java.util.concurrent.atomic.AtomicInteger; + public class BackpropagationCorrectionStep implements AlgorithmStep { private GradientBackpropagationContext context; @@ -12,13 +14,29 @@ public class BackpropagationCorrectionStep implements AlgorithmStep { @Override public void run() { + + AtomicInteger synIndex = new AtomicInteger(0); this.context.model.forEachNeuron(n -> { + float signal = context.errorSignals.get(n); n.forEachSynapse(syn -> { float lr = context.learningRate; - float signal = context.errorSignals.get(n); - float newWeight = syn.getWeight() + (lr * signal * syn.getInput()); - syn.setWeight(newWeight); + float corrector = lr * signal * syn.getInput(); + float existingCorrector = context.correctionBuffer[synIndex.get()]; + float newCorrector = existingCorrector + corrector; + + if(context.currentSample >= context.batchSize){ + float newWeight = syn.getWeight() + newCorrector; + syn.setWeight(newWeight); + context.correctionBuffer[synIndex.get()] = 0; + } else { + context.correctionBuffer[synIndex.get()] = newCorrector; + } + synIndex.incrementAndGet(); }); }); + if(context.currentSample >= context.batchSize) { + context.currentSample = 0; + } + context.currentSample += 1; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BatchAccumulatorStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BatchAccumulatorStep.java new file mode 100644 index 0000000..1038118 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BatchAccumulatorStep.java @@ -0,0 +1,11 @@ +package com.naaturel.ANN.implementation.multiLayers; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; + +public class BatchAccumulatorStep implements AlgorithmStep { + + @Override + public void run() { + + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java index f37d65a..11d1cd2 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java @@ -1,14 +1,29 @@ package com.naaturel.ANN.implementation.multiLayers; +import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.model.neuron.Neuron; +import com.naaturel.ANN.infrastructure.dataset.DataSet; +import java.util.HashMap; import java.util.Map; public class GradientBackpropagationContext extends TrainingContext { - public Map errorSignals; + public final Map errorSignals; + public final float[] correctionBuffer; - public GradientBackpropagationContext(){ + public int currentSample; + public int batchSize; + + public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){ + this.model = model; + this.dataset = dataSet; + this.learningRate = learningRate; + this.batchSize = batchSize; + + this.errorSignals = new HashMap<>(); + this.correctionBuffer = new float[model.synCount()]; + this.currentSample = 1; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java index 1390d37..c90f765 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java @@ -23,7 +23,7 @@ public class OutputLayerErrorStep implements AlgorithmStep { List expectations = this.context.dataset.getLabelsAsFloat(entry); AtomicInteger index = new AtomicInteger(0); - context.errorSignals = new HashMap<>(); + context.errorSignals.clear(); this.context.model.forEachOutputNeurons(n -> { float expected = expectations.get(index.get()); float predicted = n.getOutput(); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index a6608c1..84e45cb 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -16,10 +16,8 @@ import java.util.List; public class GradientBackpropagationTraining implements Trainer { @Override public void train(float learningRate, int epoch, Model model, DataSet dataset) { - GradientBackpropagationContext context = new GradientBackpropagationContext(); - context.dataset = dataset; - context.model = model; - context.learningRate = learningRate; + GradientBackpropagationContext context = + new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()); List steps = List.of( new SimplePredictionStep(context), @@ -34,7 +32,9 @@ public class GradientBackpropagationTraining implements Trainer { .beforeEpoch(ctx -> { ctx.globalLoss = 0.0F; }) - .afterEpoch(ctx -> ctx.globalLoss /= dataset.size()) + .afterEpoch(ctx -> { + ctx.globalLoss /= dataset.size(); + }) .withVerbose(true,epoch/10) .withTimeMeasurement(true) .run(context); diff --git a/src/main/resources/assets/table_2_9.csv b/src/main/resources/assets/table_2_9.csv index ed8c03f..69450e4 100644 --- a/src/main/resources/assets/table_2_9.csv +++ b/src/main/resources/assets/table_2_9.csv @@ -18,11 +18,4 @@ 4,6,-1 4,7,-1 4,9,1 -4,10,1 -2,6,-1 -7,7,-1 -5,9,1 -9,10,1 -7,1,-1 -5,0,1 -9,5,1 \ No newline at end of file +4,10,1 \ No newline at end of file