Fix error signal computation

This commit is contained in:
2026-05-12 22:24:25 +02:00
parent dc895f5d2e
commit 312d0847ae
3 changed files with 28 additions and 18 deletions

View File

@@ -1,14 +1,14 @@
{ {
"model": { "model": {
"new": false, "new": true,
"parameters": [10, 10, 5], "parameters": [1, 2, 3, 4],
"path": "C:/HEPL/RNA2/ANN-framework/src/main/resources/snapshots/snapshot-low-learning.json" "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
}, },
"training" : { "training" : {
"learning_rate" : 0.003, "learning_rate" : 0.01,
"max_epoch" : 5000 "max_epoch" : 500
}, },
"dataset" : { "dataset" : {
"path" : "C:/HEPL/RNA2/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted_Low_learning.csv" "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv"
} }
} }

View File

@@ -13,17 +13,27 @@ public class ErrorSignalStep implements AlgorithmStep {
@Override @Override
public void run() { public void run() {
context.model.forEachNeuron(n -> { boolean[] isTotallyComputed = new boolean[1];
if (context.errorSignalsComputed[n.getId()]) return;
int neuronIndex = context.model.indexInLayerOf(n); do{
float[] signalSum = {0f}; isTotallyComputed[0] = true;
context.model.forEachNeuronConnectedTo(n, connected -> { context.model.forEachNeuron(n -> {
signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex); if (context.errorSignalsComputed[n.getId()]) return;
int neuronIndex = context.model.indexInLayerOf(n);
float[] signalSum = {0f};
boolean[] canBeComputed = {true};
context.model.forEachNeuronConnectedTo(n, connected -> {
signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
canBeComputed[0] &= context.errorSignalsComputed[connected.getId()];
});
if(canBeComputed[0]) {
context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
context.errorSignalsComputed[n.getId()] = true;
}
isTotallyComputed[0] &= canBeComputed[0];
}); });
} while(!isTotallyComputed[0]);
context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
context.errorSignalsComputed[n.getId()] = true;
});
} }
} }

View File

@@ -33,7 +33,8 @@ public class GradientDescentTraining implements Trainer {
new SimplePredictionStep(context), new SimplePredictionStep(context),
new SimpleDeltaStep(context), new SimpleDeltaStep(context),
new SquareLossStep(context), new SquareLossStep(context),
new GradientDescentErrorStrategy(context) new GradientDescentErrorStrategy(context),
new GradientDescentCorrectionStrategy(context)
); );
new TrainingPipeline(steps) new TrainingPipeline(steps)
@@ -48,7 +49,6 @@ public class GradientDescentTraining implements Trainer {
}) })
.afterEpoch(ctx -> { .afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size(); context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).run();
}) })
.withVerbose(true, 5) .withVerbose(true, 5)
.withTimeMeasurement(true) .withTimeMeasurement(true)