Fix error signal computation

This commit is contained in:
2026-05-12 22:24:25 +02:00
parent dc895f5d2e
commit 312d0847ae
3 changed files with 28 additions and 18 deletions

View File

@@ -1,14 +1,14 @@
{
"model": {
"new": false,
"parameters": [10, 10, 5],
"path": "C:/HEPL/RNA2/ANN-framework/src/main/resources/snapshots/snapshot-low-learning.json"
"new": true,
"parameters": [1, 2, 3, 4],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
},
"training" : {
"learning_rate" : 0.003,
"max_epoch" : 5000
"learning_rate" : 0.01,
"max_epoch" : 500
},
"dataset" : {
"path" : "C:/HEPL/RNA2/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted_Low_learning.csv"
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv"
}
}

View File

@@ -13,17 +13,27 @@ public class ErrorSignalStep implements AlgorithmStep {
@Override
public void run() {
boolean[] isTotallyComputed = new boolean[1];
do{
isTotallyComputed[0] = true;
context.model.forEachNeuron(n -> {
if (context.errorSignalsComputed[n.getId()]) return;
int neuronIndex = context.model.indexInLayerOf(n);
float[] signalSum = {0f};
boolean[] canBeComputed = {true};
context.model.forEachNeuronConnectedTo(n, connected -> {
signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
canBeComputed[0] &= context.errorSignalsComputed[connected.getId()];
});
if(canBeComputed[0]) {
context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
context.errorSignalsComputed[n.getId()] = true;
}
isTotallyComputed[0] &= canBeComputed[0];
});
} while(!isTotallyComputed[0]);
}
}

View File

@@ -33,7 +33,8 @@ public class GradientDescentTraining implements Trainer {
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SquareLossStep(context),
new GradientDescentErrorStrategy(context)
new GradientDescentErrorStrategy(context),
new GradientDescentCorrectionStrategy(context)
);
new TrainingPipeline(steps)
@@ -48,7 +49,6 @@ public class GradientDescentTraining implements Trainer {
})
.afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).run();
})
.withVerbose(true, 5)
.withTimeMeasurement(true)