diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 7c4d2bb..6e3847b 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -2,28 +2,26 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; -import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; -import com.naaturel.ANN.domain.model.dataset.Label; import com.naaturel.ANN.domain.model.neuron.*; -import com.naaturel.ANN.domain.model.training.TrainingContext; -import com.naaturel.ANN.domain.model.training.TrainingPipeline; -import com.naaturel.ANN.implementation.activation.Heaviside; -import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; +import com.naaturel.ANN.implementation.gradientDescent.Linear; +import com.naaturel.ANN.implementation.simplePerceptron.Heaviside; import com.naaturel.ANN.implementation.neuron.SimplePerceptron; +import com.naaturel.ANN.implementation.training.GradientDescentTraining; import com.naaturel.ANN.implementation.training.SimpleTraining; -import com.naaturel.ANN.implementation.training.steps.*; -import javax.xml.crypto.Data; import java.util.*; public class Main { public static void main(String[] args){ - DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv"); + DataSet dataset = new DatasetExtractor() + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv"); + + DataSet orDataset = new DatasetExtractor() + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv"); List syns = new ArrayList<>(); syns.add(new Synapse(new Input(0), new Weight(0))); @@ -31,11 +29,11 @@ public class Main { Bias bias = new Bias(new Weight(0)); - Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside()); + Neuron neuron = new SimplePerceptron(syns, bias, new Linear()); Layer layer = new Layer(List.of(neuron)); Network network = new Network(List.of(layer)); - Trainer trainer = new SimpleTraining(); + Trainer trainer = new GradientDescentTraining(); trainer.train(network, dataset); } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java b/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java index 160f83c..39fa2dd 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java @@ -1,9 +1,7 @@ package com.naaturel.ANN.domain.abstraction; -import com.naaturel.ANN.domain.model.training.TrainingContext; - public interface AlgorithmStrategy { - void apply(TrainingContext ctx); + void apply(); } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java similarity index 87% rename from src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java rename to src/main/java/com/naaturel/ANN/domain/abstraction/Model.java index 2c0d757..3b47ed6 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java @@ -6,9 +6,9 @@ import com.naaturel.ANN.domain.model.neuron.Synapse; import java.util.List; import java.util.function.Consumer; -public interface Trainable { +public interface Model { + int synCount(); + void applyOnSynapses(Consumer consumer); List predict(List inputs); - void applyOnSynapses(Consumer consumer); - } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java index b6d35be..b8664b1 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java @@ -4,10 +4,9 @@ import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Weight; -import java.util.ArrayList; import java.util.List; -public abstract class Neuron implements Trainable { +public abstract class Neuron implements Model { protected List synapses; protected Bias bias; @@ -35,4 +34,9 @@ public abstract class Neuron implements Trainable { syn.setInput(inputs.get(i)); } } + + @Override + public int synCount() { + return this.synapses.size()+1; //take the bias in account + } } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java index a305cb1..06875a2 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java @@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction; import com.naaturel.ANN.domain.model.dataset.DataSet; public interface Trainer { - void train(Trainable model, DataSet dataset); + void train(Model model, DataSet dataset); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java similarity index 72% rename from src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java rename to src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java index 316d478..e329ccd 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java @@ -1,21 +1,21 @@ -package com.naaturel.ANN.domain.model.training; +package com.naaturel.ANN.domain.abstraction; -import com.naaturel.ANN.domain.abstraction.Trainable; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.Label; -public class TrainingContext { - public Trainable model; +public abstract class TrainingContext { + public Model model; public DataSet dataset; public DataSetEntry currentEntry; - public Label currentLabel; + public Label currentLabel; public float prediction; public float delta; - public float localLoss; - public float globalLoss; - public float learningRate; + public float globalLoss; + public float localLoss; + + public float learningRate; public int epoch; } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java index 448ddae..b0b865b 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java @@ -1,9 +1,7 @@ package com.naaturel.ANN.domain.abstraction; -import com.naaturel.ANN.domain.model.training.TrainingContext; - public interface TrainingStep { - void run(TrainingContext ctx); + void run(); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index ee7362e..a79ba9e 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -1,13 +1,13 @@ package com.naaturel.ANN.domain.model.neuron; import com.naaturel.ANN.domain.abstraction.Neuron; -import com.naaturel.ANN.domain.abstraction.Trainable; +import com.naaturel.ANN.domain.abstraction.Model; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; -public class Layer implements Trainable { +public class Layer implements Model { private final List neurons; @@ -25,6 +25,15 @@ public class Layer implements Trainable { return result; } + @Override + public int synCount() { + int res = 0; + for (Neuron neuron : this.neurons) { + res += neuron.synCount(); + } + return res; + } + @Override public void applyOnSynapses(Consumer consumer) { this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer)); diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java index f960eca..283106d 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java @@ -1,12 +1,12 @@ package com.naaturel.ANN.domain.model.neuron; -import com.naaturel.ANN.domain.abstraction.Trainable; +import com.naaturel.ANN.domain.abstraction.Model; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; -public class Network implements Trainable { +public class Network implements Model { private final List layers; @@ -24,6 +24,15 @@ public class Network implements Trainable { return result; } + @Override + public int synCount() { + int res = 0; + for(Layer layer : this.layers){ + res += layer.synCount(); + } + return res; + } + @Override public void applyOnSynapses(Consumer consumer) { this.layers.forEach(layer -> layer.applyOnSynapses(consumer)); diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index c8ff573..7f2d00c 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -1,5 +1,6 @@ package com.naaturel.ANN.domain.model.training; +import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; @@ -55,6 +56,9 @@ public class TrainingPipeline { this.beforeEpoch.accept(ctx); this.executeSteps(ctx); this.afterEpoch.accept(ctx); + if(this.verbose) { + System.out.printf("[Global error] : %.2f\n", ctx.globalLoss); + } } while (!this.stopCondition.test(ctx)); } @@ -63,18 +67,16 @@ public class TrainingPipeline { ctx.currentEntry = entry; ctx.currentLabel = ctx.dataset.getLabel(entry); for (TrainingStep step : steps) { - step.run(ctx); + step.run(); } if(this.verbose) { System.out.printf("Epoch : %d, ", ctx.epoch); System.out.printf("predicted : %.2f, ", ctx.prediction); System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue()); - System.out.printf("delta : %.2f\n", ctx.delta); + System.out.printf("delta : %.2f, ", ctx.delta); + System.out.printf("loss : %.5f\n", ctx.localLoss); } } - if(this.verbose) { - System.out.printf("[Global error] : %.2f\n", ctx.globalLoss); - } ctx.epoch += 1; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java deleted file mode 100644 index b9a1f6c..0000000 --- a/src/main/java/com/naaturel/ANN/implementation/correction/GradientDescentCorrectionStrategy.java +++ /dev/null @@ -1,25 +0,0 @@ -package com.naaturel.ANN.implementation.correction; - -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; -import com.naaturel.ANN.domain.model.training.TrainingContext; - -import java.util.ArrayList; -import java.util.List; - -public class GradientDescentCorrectionStrategy implements AlgorithmStrategy { - - List correctorTerms; - - public GradientDescentCorrectionStrategy(int nbrCorrectors){ - this.correctorTerms = new ArrayList<>(); - for (int i = 0; i < nbrCorrectors; i++){ - correctorTerms.add(0F); - } - } - - @Override - public void apply(TrainingContext context) { - - } - -} diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java new file mode 100644 index 0000000..c182e3e --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java @@ -0,0 +1,25 @@ +package com.naaturel.ANN.implementation.gradientDescent; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; + +import java.util.concurrent.atomic.AtomicInteger; + +public class GradientDescentCorrectionStrategy implements AlgorithmStrategy { + + private final GradientDescentTrainingContext context; + + public GradientDescentCorrectionStrategy(GradientDescentTrainingContext context) { + this.context = context; + } + + @Override + public void apply() { + AtomicInteger i = new AtomicInteger(0); + context.model.applyOnSynapses(syn -> { + float corrector = context.correctorTerms.get(i.get()); + float c = syn.getWeight() + corrector; + syn.setWeight(c); + i.incrementAndGet(); + }); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java new file mode 100644 index 0000000..84e01c2 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java @@ -0,0 +1,26 @@ +package com.naaturel.ANN.implementation.gradientDescent; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; + +import java.util.concurrent.atomic.AtomicInteger; + +public class GradientDescentErrorStrategy implements AlgorithmStrategy { + + private final GradientDescentTrainingContext context; + + public GradientDescentErrorStrategy(GradientDescentTrainingContext context) { + this.context = context; + } + + + @Override + public void apply() { + AtomicInteger i = new AtomicInteger(0); + context.model.applyOnSynapses(syn -> { + float corrector = context.correctorTerms.get(i.get()); + corrector += context.learningRate * context.delta * syn.getInput(); + context.correctorTerms.set(i.get(), corrector); + i.incrementAndGet(); + }); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentTrainingContext.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentTrainingContext.java new file mode 100644 index 0000000..0b1ec5f --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentTrainingContext.java @@ -0,0 +1,11 @@ +package com.naaturel.ANN.implementation.gradientDescent; + +import com.naaturel.ANN.domain.abstraction.TrainingContext; + +import java.util.List; + +public class GradientDescentTrainingContext extends TrainingContext { + + public List correctorTerms; + +} diff --git a/src/main/java/com/naaturel/ANN/implementation/activation/Linear.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java similarity index 82% rename from src/main/java/com/naaturel/ANN/implementation/activation/Linear.java rename to src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java index 8268bf5..b7bc8a7 100644 --- a/src/main/java/com/naaturel/ANN/implementation/activation/Linear.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java @@ -1,4 +1,4 @@ -package com.naaturel.ANN.implementation.activation; +package com.naaturel.ANN.implementation.gradientDescent; import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.Neuron; diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java new file mode 100644 index 0000000..0243756 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java @@ -0,0 +1,19 @@ +package com.naaturel.ANN.implementation.gradientDescent; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext; + +public class SquareLossStrategy implements AlgorithmStrategy { + + private final GradientDescentTrainingContext context; + + public SquareLossStrategy(GradientDescentTrainingContext context) { + this.context = context; + } + + @Override + public void apply() { + this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2; + this.context.globalLoss += context.localLoss; + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/loss/SimpleLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/loss/SimpleLossStrategy.java deleted file mode 100644 index ee5d785..0000000 --- a/src/main/java/com/naaturel/ANN/implementation/loss/SimpleLossStrategy.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.naaturel.ANN.implementation.loss; - -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; -import com.naaturel.ANN.domain.model.training.TrainingContext; - -public class SimpleLossStrategy implements AlgorithmStrategy { - @Override - public void apply(TrainingContext ctx) { - ctx.localLoss = Math.abs(ctx.delta); - } -} diff --git a/src/main/java/com/naaturel/ANN/implementation/loss/SquareLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/loss/SquareLossStrategy.java deleted file mode 100644 index 8958430..0000000 --- a/src/main/java/com/naaturel/ANN/implementation/loss/SquareLossStrategy.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.naaturel.ANN.implementation.loss; - -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; -import com.naaturel.ANN.domain.model.training.TrainingContext; - -public class SquareLossStrategy implements AlgorithmStrategy { - @Override - public void apply(TrainingContext ctx) { - ctx.localLoss = (float)Math.pow(ctx.delta, 2) / 2; - } -} diff --git a/src/main/java/com/naaturel/ANN/implementation/activation/Heaviside.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java similarity index 85% rename from src/main/java/com/naaturel/ANN/implementation/activation/Heaviside.java rename to src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java index baaf10f..badfc30 100644 --- a/src/main/java/com/naaturel/ANN/implementation/activation/Heaviside.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java @@ -1,4 +1,4 @@ -package com.naaturel.ANN.implementation.activation; +package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.Neuron; diff --git a/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java similarity index 68% rename from src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java rename to src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java index efabb8b..0ed8bcd 100644 --- a/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java @@ -1,14 +1,18 @@ -package com.naaturel.ANN.implementation.correction; +package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; -import com.naaturel.ANN.domain.model.training.TrainingContext; public class SimpleCorrectionStrategy implements AlgorithmStrategy { - @Override - public void apply(TrainingContext context) { - if(context.currentLabel.getValue() == context.prediction) return ; + private final SimpleTrainingContext context; + public SimpleCorrectionStrategy(SimpleTrainingContext context) { + this.context = context; + } + + @Override + public void apply() { + if(context.currentLabel.getValue() == context.prediction) return ; context.model.applyOnSynapses(syn -> { float currentW = syn.getWeight(); float currentInput = syn.getInput(); diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java new file mode 100644 index 0000000..4e7da26 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java @@ -0,0 +1,26 @@ +package com.naaturel.ANN.implementation.simplePerceptron; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.domain.model.dataset.Label; + +public class SimpleDeltaStrategy implements AlgorithmStrategy { + + private final TrainingContext context; + + public SimpleDeltaStrategy(TrainingContext context) { + this.context = context; + } + + @Override + public void apply() { + DataSet dataSet = context.dataset; + DataSetEntry entry = context.currentEntry; + Label label = dataSet.getLabel(entry); + + context.delta = label.getValue() - context.prediction; + } + +} diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java new file mode 100644 index 0000000..c6be25d --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java @@ -0,0 +1,17 @@ +package com.naaturel.ANN.implementation.simplePerceptron; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; + +public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy { + + private final SimpleTrainingContext context; + + public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) { + this.context = context; + } + + @Override + public void apply() { + context.globalLoss += context.localLoss; + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java new file mode 100644 index 0000000..145413d --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java @@ -0,0 +1,17 @@ +package com.naaturel.ANN.implementation.simplePerceptron; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; + +public class SimpleLossStrategy implements AlgorithmStrategy { + + private final SimpleTrainingContext context; + + public SimpleLossStrategy(SimpleTrainingContext context) { + this.context = context; + } + + @Override + public void apply() { + this.context.localLoss = Math.abs(this.context.delta); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java new file mode 100644 index 0000000..64b7e2e --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java @@ -0,0 +1,21 @@ +package com.naaturel.ANN.implementation.simplePerceptron; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingContext; + +import java.util.List; + +public class SimplePredictionStrategy implements AlgorithmStrategy { + + private final TrainingContext context; + + public SimplePredictionStrategy(TrainingContext context) { + this.context = context; + } + + @Override + public void apply() { + List predictions = context.model.predict(context.currentEntry.getData()); + context.prediction = predictions.getFirst(); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleTrainingContext.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleTrainingContext.java new file mode 100644 index 0000000..b804f21 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleTrainingContext.java @@ -0,0 +1,6 @@ +package com.naaturel.ANN.implementation.simplePerceptron; + +import com.naaturel.ANN.domain.abstraction.TrainingContext; + +public class SimpleTrainingContext extends TrainingContext { +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index dc119b4..42df3fa 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -1,16 +1,19 @@ package com.naaturel.ANN.implementation.training; -import com.naaturel.ANN.domain.abstraction.Trainable; +import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.training.TrainingContext; +import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy; +import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; import com.naaturel.ANN.domain.model.training.TrainingPipeline; -import com.naaturel.ANN.implementation.correction.GradientDescentCorrectionStrategy; -import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; -import com.naaturel.ANN.implementation.loss.SquareLossStrategy; +import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy; +import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; import com.naaturel.ANN.implementation.training.steps.*; +import java.util.ArrayList; import java.util.List; public class GradientDescentTraining implements Trainer { @@ -20,25 +23,31 @@ public class GradientDescentTraining implements Trainer { } @Override - public void train(Trainable model, DataSet dataset) { - TrainingContext context = new TrainingContext(); + public void train(Model model, DataSet dataset) { + GradientDescentTrainingContext context = new GradientDescentTrainingContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.3F; + context.learningRate = 0.00011F; + context.correctorTerms = new ArrayList<>(); List steps = List.of( - new PredictionStep(), - new DeltaStep(), - new LossStep(new SquareLossStrategy()), - new SimpleErrorDetectionStep(), - new WeightCorrectionStep(new GradientDescentCorrectionStrategy(2)) + new PredictionStep(new SimplePredictionStrategy(context)), + new DeltaStep(new SimpleDeltaStrategy(context)), + new LossStep(new SquareLossStrategy(context)), + new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)), + new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context)) ); TrainingPipeline pipeline = new TrainingPipeline(steps); pipeline - .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100) - .beforeEpoch(ctx -> ctx.globalLoss = 0) - .afterEpoch(ctx -> ()) + .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000) + .beforeEpoch(ctx -> { + ctx.globalLoss = 0.0F; + for (int i = 0; i < model.synCount(); i++){ + context.correctorTerms.add(0F); + } + }) + .afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size()) .withVerbose(true) .withTimeMeasurement(true) .run(context); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 351619c..2fb7d20 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -1,13 +1,11 @@ package com.naaturel.ANN.implementation.training; -import com.naaturel.ANN.domain.abstraction.Trainable; +import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.training.TrainingContext; +import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; -import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; -import com.naaturel.ANN.implementation.loss.SimpleLossStrategy; import com.naaturel.ANN.implementation.training.steps.*; import java.util.List; @@ -19,18 +17,18 @@ public class SimpleTraining implements Trainer { } @Override - public void train(Trainable model, DataSet dataset) { - TrainingContext context = new TrainingContext(); + public void train(Model model, DataSet dataset) { + SimpleTrainingContext context = new SimpleTrainingContext(); context.dataset = dataset; context.model = model; context.learningRate = 0.3F; List steps = List.of( - new PredictionStep(), - new DeltaStep(), - new LossStep(new SimpleLossStrategy()), - new SimpleErrorDetectionStep(), - new WeightCorrectionStep(new SimpleCorrectionStrategy()) + new PredictionStep(new SimplePredictionStrategy(context)), + new DeltaStep(new SimpleDeltaStrategy(context)), + new LossStep(new SimpleLossStrategy(context)), + new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), + new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) ); TrainingPipeline pipeline = new TrainingPipeline(steps); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java index 4938f40..d05c9df 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java @@ -1,19 +1,22 @@ package com.naaturel.ANN.implementation.training.steps; +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.Label; -import com.naaturel.ANN.domain.model.training.TrainingContext; public class DeltaStep implements TrainingStep { - @Override - public void run(TrainingContext ctx) { - DataSet dataSet = ctx.dataset; - DataSetEntry entry = ctx.currentEntry; - Label label = dataSet.getLabel(entry); + private final AlgorithmStrategy strategy; - ctx.delta = label.getValue() - ctx.prediction; + public DeltaStep(AlgorithmStrategy strategy) { + this.strategy = strategy; + } + + @Override + public void run() { + this.strategy.apply(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/ErrorRegistrationStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/ErrorRegistrationStep.java new file mode 100644 index 0000000..cd32511 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/ErrorRegistrationStep.java @@ -0,0 +1,18 @@ +package com.naaturel.ANN.implementation.training.steps; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingStep; + +public class ErrorRegistrationStep implements TrainingStep { + + private final AlgorithmStrategy strategy; + + public ErrorRegistrationStep(AlgorithmStrategy strategy) { + this.strategy = strategy; + } + + @Override + public void run() { + this.strategy.apply(); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java index 2c50287..b047c34 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java @@ -1,11 +1,12 @@ package com.naaturel.ANN.implementation.training.steps; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.training.TrainingContext; public class LossStep implements TrainingStep { + private final AlgorithmStrategy lossStrategy; public LossStep(AlgorithmStrategy strategy) { @@ -13,7 +14,7 @@ public class LossStep implements TrainingStep { } @Override - public void run(TrainingContext ctx) { - this.lossStrategy.apply(ctx); + public void run() { + this.lossStrategy.apply(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java index e14253c..b598a15 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java @@ -1,17 +1,23 @@ package com.naaturel.ANN.implementation.training.steps; +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.neuron.Input; -import com.naaturel.ANN.domain.model.training.TrainingContext; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext; -import java.util.ArrayList; import java.util.List; public class PredictionStep implements TrainingStep { + private final SimplePredictionStrategy strategy; + + public PredictionStep(SimplePredictionStrategy strategy) { + this.strategy = strategy; + } + @Override - public void run(TrainingContext ctx) { - List predictions = ctx.model.predict(ctx.currentEntry.getData()); - ctx.prediction = predictions.getFirst(); + public void run() { + this.strategy.apply(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleErrorDetectionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleErrorDetectionStep.java deleted file mode 100644 index 0d411b4..0000000 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleErrorDetectionStep.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.naaturel.ANN.implementation.training.steps; - -import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.training.TrainingContext; - -public class SimpleErrorDetectionStep implements TrainingStep { - - @Override - public void run(TrainingContext ctx) { - ctx.globalLoss += ctx.localLoss; - } - -} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java index 1c015c0..0db68f6 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java @@ -2,7 +2,6 @@ package com.naaturel.ANN.implementation.training.steps; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.training.TrainingContext; public class WeightCorrectionStep implements TrainingStep { @@ -13,7 +12,7 @@ public class WeightCorrectionStep implements TrainingStep { } @Override - public void run(TrainingContext ctx) { - this.correctionStrategy.apply(ctx); + public void run() { + this.correctionStrategy.apply(); } }