diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 08052bc..16afb02 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -68,8 +68,9 @@ public class Main { Network network = new Network(List.of(layer)); TrainingContext context = new TrainingContext(); - context.dataset = dataSet; + context.dataset = orDataSet; context.model = network; + context.learningRate = 0.3F; List steps = List.of( new PredictionStep(), @@ -81,8 +82,8 @@ public class Main { TrainingPipeline pipeline = new TrainingPipeline(steps); pipeline - .stopCondition(ctx -> ctx.globalLoss == 0 && ctx.epoch >= 1000) - .afterEpoch(ctx -> ctx.globalLoss = 0) + .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100) + .beforeEpoch(ctx -> ctx.globalLoss = 0) .withVerbose(true) .run(context); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java index 76e7f83..2c0d757 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java @@ -9,6 +9,6 @@ import java.util.function.Consumer; public interface Trainable { List predict(List inputs); - void forEachSynapse(Consumer consumer); + void applyOnSynapses(Consumer consumer); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index 21f3ea3..ee7362e 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -1,6 +1,5 @@ package com.naaturel.ANN.domain.model.neuron; -import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Trainable; @@ -27,7 +26,7 @@ public class Layer implements Trainable { } @Override - public void forEachSynapse(Consumer consumer) { - this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); + public void applyOnSynapses(Consumer consumer) { + this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer)); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java index 6f7697f..f960eca 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java @@ -25,7 +25,7 @@ public class Network implements Trainable { } @Override - public void forEachSynapse(Consumer consumer) { - this.layers.forEach(layer -> layer.forEachSynapse(consumer)); + public void applyOnSynapses(Consumer consumer) { + this.layers.forEach(layer -> layer.applyOnSynapses(consumer)); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java index 989c731..316d478 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java @@ -3,11 +3,13 @@ package com.naaturel.ANN.domain.model.training; import com.naaturel.ANN.domain.abstraction.Trainable; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.domain.model.dataset.Label; public class TrainingContext { public Trainable model; public DataSet dataset; public DataSetEntry currentEntry; + public Label currentLabel; public float prediction; public float delta; diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index 3c5f372..c8ff573 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -11,7 +11,8 @@ import java.util.function.Predicate; public class TrainingPipeline { private final List steps; - private Consumer afterAll; + private Consumer beforeEpoch; + private Consumer afterEpoch; private Predicate stopCondition; private boolean verbose; @@ -19,6 +20,9 @@ public class TrainingPipeline { public TrainingPipeline(List steps) { this.steps = new ArrayList<>(steps); + this.stopCondition = (ctx) -> false; + this.beforeEpoch = (context -> {}); + this.afterEpoch = (context -> {}); } public TrainingPipeline stopCondition(Predicate predicate) { @@ -26,8 +30,13 @@ public class TrainingPipeline { return this; } + public TrainingPipeline beforeEpoch(Consumer consumer) { + this.beforeEpoch = consumer; + return this; + } + public TrainingPipeline afterEpoch(Consumer consumer) { - this.afterAll = consumer; + this.afterEpoch = consumer; return this; } @@ -43,25 +52,28 @@ public class TrainingPipeline { public void run(TrainingContext ctx) { do { + this.beforeEpoch.accept(ctx); this.executeSteps(ctx); - if(this.afterAll != null) { - this.afterAll.accept(ctx); - } + this.afterEpoch.accept(ctx); } while (!this.stopCondition.test(ctx)); } private void executeSteps(TrainingContext ctx){ - for (DataSetEntry sample : ctx.dataset) { - ctx.currentEntry = sample; + for (DataSetEntry entry : ctx.dataset) { + ctx.currentEntry = entry; + ctx.currentLabel = ctx.dataset.getLabel(entry); for (TrainingStep step : steps) { step.run(ctx); - if(this.verbose) { - System.out.printf("Epoch : %d, ", ctx.epoch); - System.out.printf("predicted : %.2f, ", ctx.prediction); - System.out.printf("expected : %.2f, ", ctx.dataset.getLabel(ctx.currentEntry).getValue()); - System.out.printf("delta : %.2f\n", ctx.delta); - } } + if(this.verbose) { + System.out.printf("Epoch : %d, ", ctx.epoch); + System.out.printf("predicted : %.2f, ", ctx.prediction); + System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue()); + System.out.printf("delta : %.2f\n", ctx.delta); + } + } + if(this.verbose) { + System.out.printf("[Global error] : %.2f\n", ctx.globalLoss); } ctx.epoch += 1; } diff --git a/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java new file mode 100644 index 0000000..bfaa79b --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java @@ -0,0 +1,13 @@ +package com.naaturel.ANN.implementation.correction; + +import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class GradientDescentCorrectionStrategy implements CorrectionStrategy { + + @Override + public void apply(TrainingContext context) { + + } + +} diff --git a/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java index 47e9fa9..0efcd53 100644 --- a/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java @@ -7,7 +7,9 @@ public class SimpleCorrectionStrategy implements CorrectionStrategy { @Override public void apply(TrainingContext context) { - context.model.forEachSynapse(syn -> { + if(context.currentLabel.getValue() == context.prediction) return ; + + context.model.applyOnSynapses(syn -> { float currentW = syn.getWeight(); float currentInput = syn.getInput(); float newValue = currentW + (context.learningRate * context.delta * currentInput); diff --git a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java index b628a33..5861c74 100644 --- a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java +++ b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java @@ -25,7 +25,8 @@ public class SimplePerceptron extends Neuron { } @Override - public void forEachSynapse(Consumer consumer) { + public void applyOnSynapses(Consumer consumer) { + consumer.accept(this.bias); this.synapses.forEach(consumer); }