diff --git a/config.json b/config.json index f436ca6..b7824f4 100644 --- a/config.json +++ b/config.json @@ -1,14 +1,14 @@ { "model": { - "new": false, - "parameters": [10, 10, 5], - "path": "C:/HEPL/RNA2/ANN-framework/src/main/resources/snapshots/snapshot-low-learning.json" + "new": true, + "parameters": [1, 2, 3, 4], + "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json" }, "training" : { - "learning_rate" : 0.003, - "max_epoch" : 5000 + "learning_rate" : 0.01, + "max_epoch" : 500 }, "dataset" : { - "path" : "C:/HEPL/RNA2/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted_Low_learning.csv" + "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv" } } \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java index d36bf2b..c8f83e8 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java @@ -13,17 +13,27 @@ public class ErrorSignalStep implements AlgorithmStep { @Override public void run() { - context.model.forEachNeuron(n -> { - if (context.errorSignalsComputed[n.getId()]) return; + boolean[] isTotallyComputed = new boolean[1]; - int neuronIndex = context.model.indexInLayerOf(n); - float[] signalSum = {0f}; - context.model.forEachNeuronConnectedTo(n, connected -> { - signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex); + do{ + isTotallyComputed[0] = true; + context.model.forEachNeuron(n -> { + if (context.errorSignalsComputed[n.getId()]) return; + + int neuronIndex = context.model.indexInLayerOf(n); + float[] signalSum = {0f}; + boolean[] canBeComputed = {true}; + context.model.forEachNeuronConnectedTo(n, connected -> { + signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex); + canBeComputed[0] &= context.errorSignalsComputed[connected.getId()]; + }); + + if(canBeComputed[0]) { + context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0]; + context.errorSignalsComputed[n.getId()] = true; + } + isTotallyComputed[0] &= canBeComputed[0]; }); - - context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0]; - context.errorSignalsComputed[n.getId()] = true; - }); + } while(!isTotallyComputed[0]); } } \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index d7af095..9842746 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -33,7 +33,8 @@ public class GradientDescentTraining implements Trainer { new SimplePredictionStep(context), new SimpleDeltaStep(context), new SquareLossStep(context), - new GradientDescentErrorStrategy(context) + new GradientDescentErrorStrategy(context), + new GradientDescentCorrectionStrategy(context) ); new TrainingPipeline(steps) @@ -48,7 +49,6 @@ public class GradientDescentTraining implements Trainer { }) .afterEpoch(ctx -> { context.globalLoss /= context.dataset.size(); - new GradientDescentCorrectionStrategy(context).run(); }) .withVerbose(true, 5) .withTimeMeasurement(true)