Fix weights correction

This commit is contained in:
Laurent
2026-03-23 17:13:53 +01:00
parent 89d9abe329
commit fbf2a571ef
9 changed files with 54 additions and 24 deletions

View File

@@ -68,8 +68,9 @@ public class Main {
Network network = new Network(List.of(layer)); Network network = new Network(List.of(layer));
TrainingContext context = new TrainingContext(); TrainingContext context = new TrainingContext();
context.dataset = dataSet; context.dataset = orDataSet;
context.model = network; context.model = network;
context.learningRate = 0.3F;
List<TrainingStep> steps = List.of( List<TrainingStep> steps = List.of(
new PredictionStep(), new PredictionStep(),
@@ -81,8 +82,8 @@ public class Main {
TrainingPipeline pipeline = new TrainingPipeline(steps); TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline pipeline
.stopCondition(ctx -> ctx.globalLoss == 0 && ctx.epoch >= 1000) .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
.afterEpoch(ctx -> ctx.globalLoss = 0) .beforeEpoch(ctx -> ctx.globalLoss = 0)
.withVerbose(true) .withVerbose(true)
.run(context); .run(context);

View File

@@ -9,6 +9,6 @@ import java.util.function.Consumer;
public interface Trainable { public interface Trainable {
List<Float> predict(List<Input> inputs); List<Float> predict(List<Input> inputs);
void forEachSynapse(Consumer<Synapse> consumer); void applyOnSynapses(Consumer<Synapse> consumer);
} }

View File

@@ -1,6 +1,5 @@
package com.naaturel.ANN.domain.model.neuron; package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable; import com.naaturel.ANN.domain.abstraction.Trainable;
@@ -27,7 +26,7 @@ public class Layer implements Trainable {
} }
@Override @Override
public void forEachSynapse(Consumer<Synapse> consumer) { public void applyOnSynapses(Consumer<Synapse> consumer) {
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer));
} }
} }

View File

@@ -25,7 +25,7 @@ public class Network implements Trainable {
} }
@Override @Override
public void forEachSynapse(Consumer<Synapse> consumer) { public void applyOnSynapses(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.forEachSynapse(consumer)); this.layers.forEach(layer -> layer.applyOnSynapses(consumer));
} }
} }

View File

@@ -3,11 +3,13 @@ package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.Trainable; import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
public class TrainingContext { public class TrainingContext {
public Trainable model; public Trainable model;
public DataSet dataset; public DataSet dataset;
public DataSetEntry currentEntry; public DataSetEntry currentEntry;
public Label currentLabel;
public float prediction; public float prediction;
public float delta; public float delta;

View File

@@ -11,7 +11,8 @@ import java.util.function.Predicate;
public class TrainingPipeline { public class TrainingPipeline {
private final List<TrainingStep> steps; private final List<TrainingStep> steps;
private Consumer<TrainingContext> afterAll; private Consumer<TrainingContext> beforeEpoch;
private Consumer<TrainingContext> afterEpoch;
private Predicate<TrainingContext> stopCondition; private Predicate<TrainingContext> stopCondition;
private boolean verbose; private boolean verbose;
@@ -19,6 +20,9 @@ public class TrainingPipeline {
public TrainingPipeline(List<TrainingStep> steps) { public TrainingPipeline(List<TrainingStep> steps) {
this.steps = new ArrayList<>(steps); this.steps = new ArrayList<>(steps);
this.stopCondition = (ctx) -> false;
this.beforeEpoch = (context -> {});
this.afterEpoch = (context -> {});
} }
public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) { public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
@@ -26,8 +30,13 @@ public class TrainingPipeline {
return this; return this;
} }
public TrainingPipeline beforeEpoch(Consumer<TrainingContext> consumer) {
this.beforeEpoch = consumer;
return this;
}
public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) { public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
this.afterAll = consumer; this.afterEpoch = consumer;
return this; return this;
} }
@@ -43,25 +52,28 @@ public class TrainingPipeline {
public void run(TrainingContext ctx) { public void run(TrainingContext ctx) {
do { do {
this.beforeEpoch.accept(ctx);
this.executeSteps(ctx); this.executeSteps(ctx);
if(this.afterAll != null) { this.afterEpoch.accept(ctx);
this.afterAll.accept(ctx);
}
} while (!this.stopCondition.test(ctx)); } while (!this.stopCondition.test(ctx));
} }
private void executeSteps(TrainingContext ctx){ private void executeSteps(TrainingContext ctx){
for (DataSetEntry sample : ctx.dataset) { for (DataSetEntry entry : ctx.dataset) {
ctx.currentEntry = sample; ctx.currentEntry = entry;
ctx.currentLabel = ctx.dataset.getLabel(entry);
for (TrainingStep step : steps) { for (TrainingStep step : steps) {
step.run(ctx); step.run(ctx);
if(this.verbose) {
System.out.printf("Epoch : %d, ", ctx.epoch);
System.out.printf("predicted : %.2f, ", ctx.prediction);
System.out.printf("expected : %.2f, ", ctx.dataset.getLabel(ctx.currentEntry).getValue());
System.out.printf("delta : %.2f\n", ctx.delta);
}
} }
if(this.verbose) {
System.out.printf("Epoch : %d, ", ctx.epoch);
System.out.printf("predicted : %.2f, ", ctx.prediction);
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
System.out.printf("delta : %.2f\n", ctx.delta);
}
}
if(this.verbose) {
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
} }
ctx.epoch += 1; ctx.epoch += 1;
} }

View File

@@ -0,0 +1,13 @@
package com.naaturel.ANN.implementation.correction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Weight-correction strategy intended to implement gradient descent.
 *
 * <p>NOTE(review): this is currently a no-op stub — {@link #apply} has an empty
 * body, so using this strategy leaves all synapse weights unchanged. The actual
 * gradient-descent update (cf. SimpleCorrectionStrategy, which adjusts each
 * synapse by learningRate * delta * input) still needs to be written.
 */
public class GradientDescentCorrectionStrategy implements CorrectionStrategy {
@Override
public void apply(TrainingContext context) {
// TODO: implement gradient-descent weight update using context.delta,
// context.learningRate and context.model.applyOnSynapses(...).
}
}

View File

@@ -7,7 +7,9 @@ public class SimpleCorrectionStrategy implements CorrectionStrategy {
@Override @Override
public void apply(TrainingContext context) { public void apply(TrainingContext context) {
context.model.forEachSynapse(syn -> { if(context.currentLabel.getValue() == context.prediction) return ;
context.model.applyOnSynapses(syn -> {
float currentW = syn.getWeight(); float currentW = syn.getWeight();
float currentInput = syn.getInput(); float currentInput = syn.getInput();
float newValue = currentW + (context.learningRate * context.delta * currentInput); float newValue = currentW + (context.learningRate * context.delta * currentInput);

View File

@@ -25,7 +25,8 @@ public class SimplePerceptron extends Neuron {
} }
@Override @Override
public void forEachSynapse(Consumer<Synapse> consumer) { public void applyOnSynapses(Consumer<Synapse> consumer) {
consumer.accept(this.bias);
this.synapses.forEach(consumer); this.synapses.forEach(consumer);
} }