Fix weights correction
This commit is contained in:
@@ -68,8 +68,9 @@ public class Main {
|
|||||||
Network network = new Network(List.of(layer));
|
Network network = new Network(List.of(layer));
|
||||||
|
|
||||||
TrainingContext context = new TrainingContext();
|
TrainingContext context = new TrainingContext();
|
||||||
context.dataset = dataSet;
|
context.dataset = orDataSet;
|
||||||
context.model = network;
|
context.model = network;
|
||||||
|
context.learningRate = 0.3F;
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
new PredictionStep(),
|
new PredictionStep(),
|
||||||
@@ -81,8 +82,8 @@ public class Main {
|
|||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
pipeline
|
pipeline
|
||||||
.stopCondition(ctx -> ctx.globalLoss == 0 && ctx.epoch >= 1000)
|
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
||||||
.afterEpoch(ctx -> ctx.globalLoss = 0)
|
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||||
.withVerbose(true)
|
.withVerbose(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,6 @@ import java.util.function.Consumer;
|
|||||||
public interface Trainable {
|
public interface Trainable {
|
||||||
List<Float> predict(List<Input> inputs);
|
List<Float> predict(List<Input> inputs);
|
||||||
|
|
||||||
void forEachSynapse(Consumer<Synapse> consumer);
|
void applyOnSynapses(Consumer<Synapse> consumer);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.naaturel.ANN.domain.model.neuron;
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||||
|
|
||||||
@@ -27,7 +26,7 @@ public class Layer implements Trainable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||||
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ public class Network implements Trainable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||||
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
this.layers.forEach(layer -> layer.applyOnSynapses(consumer));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,11 +3,13 @@ package com.naaturel.ANN.domain.model.training;
|
|||||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
|
|
||||||
public class TrainingContext {
|
public class TrainingContext {
|
||||||
public Trainable model;
|
public Trainable model;
|
||||||
public DataSet dataset;
|
public DataSet dataset;
|
||||||
public DataSetEntry currentEntry;
|
public DataSetEntry currentEntry;
|
||||||
|
public Label currentLabel;
|
||||||
|
|
||||||
public float prediction;
|
public float prediction;
|
||||||
public float delta;
|
public float delta;
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ import java.util.function.Predicate;
|
|||||||
public class TrainingPipeline {
|
public class TrainingPipeline {
|
||||||
|
|
||||||
private final List<TrainingStep> steps;
|
private final List<TrainingStep> steps;
|
||||||
private Consumer<TrainingContext> afterAll;
|
private Consumer<TrainingContext> beforeEpoch;
|
||||||
|
private Consumer<TrainingContext> afterEpoch;
|
||||||
private Predicate<TrainingContext> stopCondition;
|
private Predicate<TrainingContext> stopCondition;
|
||||||
|
|
||||||
private boolean verbose;
|
private boolean verbose;
|
||||||
@@ -19,6 +20,9 @@ public class TrainingPipeline {
|
|||||||
|
|
||||||
public TrainingPipeline(List<TrainingStep> steps) {
|
public TrainingPipeline(List<TrainingStep> steps) {
|
||||||
this.steps = new ArrayList<>(steps);
|
this.steps = new ArrayList<>(steps);
|
||||||
|
this.stopCondition = (ctx) -> false;
|
||||||
|
this.beforeEpoch = (context -> {});
|
||||||
|
this.afterEpoch = (context -> {});
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
|
public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
|
||||||
@@ -26,8 +30,13 @@ public class TrainingPipeline {
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public TrainingPipeline beforeEpoch(Consumer<TrainingContext> consumer) {
|
||||||
|
this.beforeEpoch = consumer;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
|
public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
|
||||||
this.afterAll = consumer;
|
this.afterEpoch = consumer;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,25 +52,28 @@ public class TrainingPipeline {
|
|||||||
|
|
||||||
public void run(TrainingContext ctx) {
|
public void run(TrainingContext ctx) {
|
||||||
do {
|
do {
|
||||||
|
this.beforeEpoch.accept(ctx);
|
||||||
this.executeSteps(ctx);
|
this.executeSteps(ctx);
|
||||||
if(this.afterAll != null) {
|
this.afterEpoch.accept(ctx);
|
||||||
this.afterAll.accept(ctx);
|
|
||||||
}
|
|
||||||
} while (!this.stopCondition.test(ctx));
|
} while (!this.stopCondition.test(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void executeSteps(TrainingContext ctx){
|
private void executeSteps(TrainingContext ctx){
|
||||||
for (DataSetEntry sample : ctx.dataset) {
|
for (DataSetEntry entry : ctx.dataset) {
|
||||||
ctx.currentEntry = sample;
|
ctx.currentEntry = entry;
|
||||||
|
ctx.currentLabel = ctx.dataset.getLabel(entry);
|
||||||
for (TrainingStep step : steps) {
|
for (TrainingStep step : steps) {
|
||||||
step.run(ctx);
|
step.run(ctx);
|
||||||
if(this.verbose) {
|
|
||||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
|
||||||
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
|
||||||
System.out.printf("expected : %.2f, ", ctx.dataset.getLabel(ctx.currentEntry).getValue());
|
|
||||||
System.out.printf("delta : %.2f\n", ctx.delta);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if(this.verbose) {
|
||||||
|
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||||
|
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
||||||
|
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
|
||||||
|
System.out.printf("delta : %.2f\n", ctx.delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(this.verbose) {
|
||||||
|
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
||||||
}
|
}
|
||||||
ctx.epoch += 1;
|
ctx.epoch += 1;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package com.naaturel.ANN.implementation.correction;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
public class GradientDescentCorrectionStrategy implements CorrectionStrategy {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply(TrainingContext context) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -7,7 +7,9 @@ public class SimpleCorrectionStrategy implements CorrectionStrategy {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void apply(TrainingContext context) {
|
public void apply(TrainingContext context) {
|
||||||
context.model.forEachSynapse(syn -> {
|
if(context.currentLabel.getValue() == context.prediction) return ;
|
||||||
|
|
||||||
|
context.model.applyOnSynapses(syn -> {
|
||||||
float currentW = syn.getWeight();
|
float currentW = syn.getWeight();
|
||||||
float currentInput = syn.getInput();
|
float currentInput = syn.getInput();
|
||||||
float newValue = currentW + (context.learningRate * context.delta * currentInput);
|
float newValue = currentW + (context.learningRate * context.delta * currentInput);
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ public class SimplePerceptron extends Neuron {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||||
|
consumer.accept(this.bias);
|
||||||
this.synapses.forEach(consumer);
|
this.synapses.forEach(consumer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user