Fix error signal computation
This commit is contained in:
12
config.json
12
config.json
@@ -1,14 +1,14 @@
|
||||
{
|
||||
"model": {
|
||||
"new": false,
|
||||
"parameters": [10, 10, 5],
|
||||
"path": "C:/HEPL/RNA2/ANN-framework/src/main/resources/snapshots/snapshot-low-learning.json"
|
||||
"new": true,
|
||||
"parameters": [1, 2, 3, 4],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||
},
|
||||
"training" : {
|
||||
"learning_rate" : 0.003,
|
||||
"max_epoch" : 5000
|
||||
"learning_rate" : 0.01,
|
||||
"max_epoch" : 500
|
||||
},
|
||||
"dataset" : {
|
||||
"path" : "C:/HEPL/RNA2/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted_Low_learning.csv"
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv"
|
||||
}
|
||||
}
|
||||
@@ -13,17 +13,27 @@ public class ErrorSignalStep implements AlgorithmStep {
|
||||
@Override
|
||||
public void run() {
|
||||
|
||||
context.model.forEachNeuron(n -> {
|
||||
if (context.errorSignalsComputed[n.getId()]) return;
|
||||
boolean[] isTotallyComputed = new boolean[1];
|
||||
|
||||
int neuronIndex = context.model.indexInLayerOf(n);
|
||||
float[] signalSum = {0f};
|
||||
context.model.forEachNeuronConnectedTo(n, connected -> {
|
||||
signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
|
||||
do{
|
||||
isTotallyComputed[0] = true;
|
||||
context.model.forEachNeuron(n -> {
|
||||
if (context.errorSignalsComputed[n.getId()]) return;
|
||||
|
||||
int neuronIndex = context.model.indexInLayerOf(n);
|
||||
float[] signalSum = {0f};
|
||||
boolean[] canBeComputed = {true};
|
||||
context.model.forEachNeuronConnectedTo(n, connected -> {
|
||||
signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
|
||||
canBeComputed[0] &= context.errorSignalsComputed[connected.getId()];
|
||||
});
|
||||
|
||||
if(canBeComputed[0]) {
|
||||
context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
|
||||
context.errorSignalsComputed[n.getId()] = true;
|
||||
}
|
||||
isTotallyComputed[0] &= canBeComputed[0];
|
||||
});
|
||||
|
||||
context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
|
||||
context.errorSignalsComputed[n.getId()] = true;
|
||||
});
|
||||
} while(!isTotallyComputed[0]);
|
||||
}
|
||||
}
|
||||
@@ -33,7 +33,8 @@ public class GradientDescentTraining implements Trainer {
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SquareLossStep(context),
|
||||
new GradientDescentErrorStrategy(context)
|
||||
new GradientDescentErrorStrategy(context),
|
||||
new GradientDescentCorrectionStrategy(context)
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
@@ -48,7 +49,6 @@ public class GradientDescentTraining implements Trainer {
|
||||
})
|
||||
.afterEpoch(ctx -> {
|
||||
context.globalLoss /= context.dataset.size();
|
||||
new GradientDescentCorrectionStrategy(context).run();
|
||||
})
|
||||
.withVerbose(true, 5)
|
||||
.withTimeMeasurement(true)
|
||||
|
||||
Reference in New Issue
Block a user