From 89d9abe329add4a2d63554737ee02e54293247b8 Mon Sep 17 00:00:00 2001 From: Laurent <2-naaturel@users.noreply.gitlab.naaturel.be> Date: Mon, 23 Mar 2026 16:39:12 +0100 Subject: [PATCH] Implement main structure of framework --- src/main/java/com/naaturel/ANN/Main.java | 41 +++++++---- .../abstraction/CorrectionStrategy.java | 9 +++ .../ANN/domain/abstraction/Neuron.java | 33 +++------ .../ANN/domain/abstraction/NeuronTrainer.java | 13 ---- .../ANN/domain/abstraction/Trainable.java | 9 ++- .../ANN/domain/abstraction/Trainer.java | 6 +- .../ANN/domain/abstraction/TrainingStep.java | 9 +++ .../ANN/domain/model/neuron/Layer.java | 33 +++++++++ .../ANN/domain/model/neuron/Network.java | 31 +++++++++ .../ANN/domain/model/neuron/Synapse.java | 4 +- .../model/training/TrainingContext.java | 19 ++++++ .../model/training/TrainingPipeline.java | 68 +++++++++++++++++++ .../Heaviside.java | 2 +- .../Linear.java | 2 +- .../correction/SimpleCorrectionStrategy.java | 17 +++++ .../neuron/SimplePerceptron.java | 16 ++++- .../training/AdalineTraining.java | 4 +- .../training/GradientDescentTraining.java | 4 +- .../training/SimpleTraining.java | 22 ++++-- .../training/steps/DeltaStep.java | 19 ++++++ .../training/steps/PredictionStep.java | 21 ++++++ .../steps/SimpleErrorDetectionStep.java | 13 ++++ .../training/steps/SimpleLossStep.java | 12 ++++ .../training/steps/WeightCorrectionStep.java | 19 ++++++ 24 files changed, 353 insertions(+), 73 deletions(-) create mode 100644 src/main/java/com/naaturel/ANN/domain/abstraction/CorrectionStrategy.java delete mode 100644 src/main/java/com/naaturel/ANN/domain/abstraction/NeuronTrainer.java create mode 100644 src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java create mode 100644 src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java create mode 100644 src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java create mode 100644 src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java create mode 100644 src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java rename src/main/java/com/naaturel/ANN/implementation/{activationFunction => activation}/Heaviside.java (85%) rename src/main/java/com/naaturel/ANN/implementation/{activationFunction => activation}/Linear.java (81%) create mode 100644 src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleErrorDetectionStep.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleLossStep.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 218d3d8..08052bc 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -1,18 +1,17 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.abstraction.Neuron; -import com.naaturel.ANN.domain.abstraction.Trainer; +import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.Label; -import com.naaturel.ANN.domain.model.neuron.Bias; -import com.naaturel.ANN.domain.model.neuron.Input; -import com.naaturel.ANN.domain.model.neuron.Synapse; -import com.naaturel.ANN.domain.model.neuron.Weight; -import com.naaturel.ANN.implementation.activationFunction.Linear; +import com.naaturel.ANN.domain.model.neuron.*; +import com.naaturel.ANN.domain.model.training.TrainingContext; +import com.naaturel.ANN.domain.model.training.TrainingPipeline; +import com.naaturel.ANN.implementation.activation.Heaviside; +import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; import com.naaturel.ANN.implementation.neuron.SimplePerceptron; -import com.naaturel.ANN.implementation.training.AdalineTraining; -import com.naaturel.ANN.implementation.training.GradientDescentTraining; +import com.naaturel.ANN.implementation.training.steps.*; import java.util.*; @@ -64,14 +63,28 @@ public class Main { Bias bias = new Bias(new Weight(0)); - Neuron n = new SimplePerceptron(syns, bias, new Linear()); - Trainer trainer = new AdalineTraining(); + Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside()); + Layer layer = new Layer(List.of(neuron)); + Network network = new Network(List.of(layer)); - long start = System.currentTimeMillis(); + TrainingContext context = new TrainingContext(); + context.dataset = dataSet; + context.model = network; - trainer.train(n, 0.03F, andDataSet); + List steps = List.of( + new PredictionStep(), + new DeltaStep(), + new SimpleLossStep(), + new SimpleErrorDetectionStep(), + new WeightCorrectionStep(new SimpleCorrectionStrategy()) + ); + + TrainingPipeline pipeline = new TrainingPipeline(steps); + pipeline + .stopCondition(ctx -> ctx.globalLoss == 0 && ctx.epoch >= 1000) + .afterEpoch(ctx -> ctx.globalLoss = 0) + .withVerbose(true) + .run(context); - long end = System.currentTimeMillis(); - System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0); } } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/CorrectionStrategy.java b/src/main/java/com/naaturel/ANN/domain/abstraction/CorrectionStrategy.java new file mode 100644 index 0000000..58dd0ae --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/CorrectionStrategy.java @@ -0,0 +1,9 @@ +package com.naaturel.ANN.domain.abstraction; + +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public interface CorrectionStrategy { + + void apply(TrainingContext context); + +} diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java index dd74a52..b6d35be 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java @@ -7,7 +7,7 @@ import com.naaturel.ANN.domain.model.neuron.Weight; import java.util.ArrayList; import java.util.List; -public abstract class Neuron { +public abstract class Neuron implements Trainable { protected List synapses; protected Bias bias; @@ -19,37 +19,20 @@ public abstract class Neuron { this.activationFunction = func; } - public abstract float predict(); public abstract float calculateWeightedSum(); - public int getSynCount(){ - return this.synapses.size(); - } - - public void setInput(int index, Input input){ - Synapse syn = this.synapses.get(index); - syn.setInput(input.getValue()); - } - - public Bias getBias(){ - return this.bias; - } - public void updateBias(Weight weight) { this.bias.setWeight(weight.getValue()); } - public Synapse getSynapse(int index){ - return this.synapses.get(index); + public void updateWeight(int index, Weight weight) { + this.synapses.get(index).setWeight(weight.getValue()); } - public List getSynapses() { - return new ArrayList<>(this.synapses); + protected void setInputs(List inputs){ + for(int i = 0; i < inputs.size() && i < synapses.size(); i++){ + Synapse syn = this.synapses.get(i); + syn.setInput(inputs.get(i)); + } } - - public void setWeight(int index, Weight weight){ - Synapse syn = this.synapses.get(index); - syn.setWeight(weight.getValue()); - } - } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/NeuronTrainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/NeuronTrainer.java deleted file mode 100644 index 7127ba0..0000000 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/NeuronTrainer.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.naaturel.ANN.domain.abstraction; - -public abstract class NeuronTrainer { - - private Trainable trainable; - - public NeuronTrainer(Trainable trainable){ - this.trainable = trainable; - } - - public abstract void train(); - -} diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java index 78827f4..76e7f83 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainable.java @@ -1,7 +1,14 @@ package com.naaturel.ANN.domain.abstraction; +import com.naaturel.ANN.domain.model.neuron.Input; +import com.naaturel.ANN.domain.model.neuron.Synapse; + +import java.util.List; +import java.util.function.Consumer; + public interface Trainable { + List predict(List inputs); - + void forEachSynapse(Consumer consumer); } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java index f8c44dc..eec3555 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java @@ -1,8 +1,10 @@ package com.naaturel.ANN.domain.abstraction; -import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +import java.util.List; public interface Trainer { - void train(Neuron n, float learningRate, DataSet dataSet); + void train(TrainingContext context, List steps); } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java new file mode 100644 index 0000000..448ddae --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java @@ -0,0 +1,9 @@ +package com.naaturel.ANN.domain.abstraction; + +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public interface TrainingStep { + + void run(TrainingContext ctx); + +} diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java new file mode 100644 index 0000000..21f3ea3 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -0,0 +1,33 @@ +package com.naaturel.ANN.domain.model.neuron; + +import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; +import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.Trainable; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +public class Layer implements Trainable { + + private final List neurons; + + public Layer(List neurons) { + this.neurons = neurons; + } + + @Override + public List predict(List inputs) { + List result = new ArrayList<>(); + for(Neuron neuron : this.neurons){ + List res = neuron.predict(inputs); + result.addAll(res); + } + return result; + } + + @Override + public void forEachSynapse(Consumer consumer) { + this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); + } +} diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java new file mode 100644 index 0000000..6f7697f --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java @@ -0,0 +1,31 @@ +package com.naaturel.ANN.domain.model.neuron; + +import com.naaturel.ANN.domain.abstraction.Trainable; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +public class Network implements Trainable { + + private final List layers; + + public Network(List layers) { + this.layers = layers; + } + + @Override + public List predict(List inputs) { + List result = new ArrayList<>(); + for(Layer layer : this.layers){ + List res = layer.predict(inputs); + result.addAll(res); + } + return result; + } + + @Override + public void forEachSynapse(Consumer consumer) { + this.layers.forEach(layer -> layer.forEachSynapse(consumer)); + } +} diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java index e6ed930..88a017c 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Synapse.java @@ -14,8 +14,8 @@ public class Synapse { return this.input.getValue(); } - public void setInput(float value){ - this.input.setValue(value); + public void setInput(Input input){ + this.input.setValue(input.getValue()); } public float getWeight() { diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java new file mode 100644 index 0000000..989c731 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingContext.java @@ -0,0 +1,19 @@ +package com.naaturel.ANN.domain.model.training; + +import com.naaturel.ANN.domain.abstraction.Trainable; +import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.domain.model.dataset.DataSetEntry; + +public class TrainingContext { + public Trainable model; + public DataSet dataset; + public DataSetEntry currentEntry; + + public float prediction; + public float delta; + public float localLoss; + public float globalLoss; + public float learningRate; + + public int epoch; +} diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java new file mode 100644 index 0000000..3c5f372 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -0,0 +1,68 @@ +package com.naaturel.ANN.domain.model.training; + +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.dataset.DataSetEntry; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Predicate; + +public class TrainingPipeline { + + private final List steps; + private Consumer afterAll; + private Predicate stopCondition; + + private boolean verbose; + private boolean timeMeasurement; + + public TrainingPipeline(List steps) { + this.steps = new ArrayList<>(steps); + } + + public TrainingPipeline stopCondition(Predicate predicate) { + this.stopCondition = predicate; + return this; + } + + public TrainingPipeline afterEpoch(Consumer consumer) { + this.afterAll = consumer; + return this; + } + + public TrainingPipeline withVerbose(boolean enabled) { + this.verbose = enabled; + return this; + } + + public TrainingPipeline withTimeMeasurement(boolean enabled) { + this.timeMeasurement = enabled; + return this; + } + + public void run(TrainingContext ctx) { + do { + this.executeSteps(ctx); + if(this.afterAll != null) { + this.afterAll.accept(ctx); + } + } while (!this.stopCondition.test(ctx)); + } + + private void executeSteps(TrainingContext ctx){ + for (DataSetEntry sample : ctx.dataset) { + ctx.currentEntry = sample; + for (TrainingStep step : steps) { + step.run(ctx); + if(this.verbose) { + System.out.printf("Epoch : %d, ", ctx.epoch); + System.out.printf("predicted : %.2f, ", ctx.prediction); + System.out.printf("expected : %.2f, ", ctx.dataset.getLabel(ctx.currentEntry).getValue()); + System.out.printf("delta : %.2f\n", ctx.delta); + } + } + } + ctx.epoch += 1; + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/activationFunction/Heaviside.java b/src/main/java/com/naaturel/ANN/implementation/activation/Heaviside.java similarity index 85% rename from src/main/java/com/naaturel/ANN/implementation/activationFunction/Heaviside.java rename to src/main/java/com/naaturel/ANN/implementation/activation/Heaviside.java index aae8b52..baaf10f 100644 --- a/src/main/java/com/naaturel/ANN/implementation/activationFunction/Heaviside.java +++ b/src/main/java/com/naaturel/ANN/implementation/activation/Heaviside.java @@ -1,4 +1,4 @@ -package com.naaturel.ANN.implementation.activationFunction; +package com.naaturel.ANN.implementation.activation; import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.Neuron; diff --git a/src/main/java/com/naaturel/ANN/implementation/activationFunction/Linear.java b/src/main/java/com/naaturel/ANN/implementation/activation/Linear.java similarity index 81% rename from src/main/java/com/naaturel/ANN/implementation/activationFunction/Linear.java rename to src/main/java/com/naaturel/ANN/implementation/activation/Linear.java index c280dac..8268bf5 100644 --- a/src/main/java/com/naaturel/ANN/implementation/activationFunction/Linear.java +++ b/src/main/java/com/naaturel/ANN/implementation/activation/Linear.java @@ -1,4 +1,4 @@ -package com.naaturel.ANN.implementation.activationFunction; +package com.naaturel.ANN.implementation.activation; import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.Neuron; diff --git a/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java new file mode 100644 index 0000000..47e9fa9 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/correction/SimpleCorrectionStrategy.java @@ -0,0 +1,17 @@ +package com.naaturel.ANN.implementation.correction; + +import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class SimpleCorrectionStrategy implements CorrectionStrategy { + + @Override + public void apply(TrainingContext context) { + context.model.forEachSynapse(syn -> { + float currentW = syn.getWeight(); + float currentInput = syn.getInput(); + float newValue = currentW + (context.learningRate * context.delta * currentInput); + syn.setWeight(newValue); + }); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java index 6f6dae0..b628a33 100644 --- a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java +++ b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java @@ -1,22 +1,32 @@ package com.naaturel.ANN.implementation.neuron; import com.naaturel.ANN.domain.abstraction.ActivationFunction; +import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Trainable; import com.naaturel.ANN.domain.model.neuron.Bias; +import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Synapse; +import com.naaturel.ANN.domain.model.neuron.Weight; import java.util.List; +import java.util.function.Consumer; -public class SimplePerceptron extends Neuron implements Trainable { +public class SimplePerceptron extends Neuron { public SimplePerceptron(List synapses, Bias b, ActivationFunction func) { super(synapses, b, func); } @Override - public float predict() { - return activationFunction.accept(this); + public List predict(List inputs) { + super.setInputs(inputs); + return List.of(activationFunction.accept(this)); + } + + @Override + public void forEachSynapse(Consumer consumer) { + this.synapses.forEach(consumer); } @Override diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 6c22a5d..5a7fb70 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -9,7 +9,7 @@ import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Weight; -public class AdalineTraining implements Trainer { +/*public class AdalineTraining implements Trainer { public AdalineTraining(){ @@ -78,4 +78,4 @@ public class AdalineTraining implements Trainer { return (float) Math.pow(delta, 2)/2; } -} +}*/ diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index fce18a2..4aeb4e0 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -13,7 +13,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -public class GradientDescentTraining implements Trainer { +/*public class GradientDescentTraining implements Trainer { public GradientDescentTraining(){ @@ -122,4 +122,4 @@ public class GradientDescentTraining implements Trainer { return variance; } -} +}*/ diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 1fea010..357c7a9 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -1,12 +1,15 @@ package com.naaturel.ANN.implementation.training; -import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; +import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DataSetEntry; -import com.naaturel.ANN.domain.model.neuron.Input; -import com.naaturel.ANN.domain.model.neuron.Synapse; -import com.naaturel.ANN.domain.model.neuron.Weight; +import com.naaturel.ANN.domain.model.neuron.Network; +import com.naaturel.ANN.domain.model.training.TrainingContext; +import com.naaturel.ANN.domain.model.training.TrainingPipeline; +import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; +import com.naaturel.ANN.implementation.training.steps.*; + +import java.util.List; public class SimpleTraining implements Trainer { @@ -14,7 +17,12 @@ public class SimpleTraining implements Trainer { } - public void train(Neuron n, float learningRate, DataSet dataSet) { + @Override + public void train(TrainingContext context, List steps) { + + } + + /*public void train(Neuron n, float learningRate, DataSet dataSet) { int epoch = 1; int errorCount; @@ -65,5 +73,5 @@ public class SimpleTraining implements Trainer { private float calculateLoss(float delta){ return Math.abs(delta); } - +*/ } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java new file mode 100644 index 0000000..4938f40 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java @@ -0,0 +1,19 @@ +package com.naaturel.ANN.implementation.training.steps; + +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.domain.model.dataset.Label; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class DeltaStep implements TrainingStep { + + @Override + public void run(TrainingContext ctx) { + DataSet dataSet = ctx.dataset; + DataSetEntry entry = ctx.currentEntry; + Label label = dataSet.getLabel(entry); + + ctx.delta = label.getValue() - ctx.prediction; + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java new file mode 100644 index 0000000..a1e29b2 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java @@ -0,0 +1,21 @@ +package com.naaturel.ANN.implementation.training.steps; + +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.neuron.Input; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +import java.util.ArrayList; +import java.util.List; + +public class PredictionStep implements TrainingStep { + + @Override + public void run(TrainingContext ctx) { + List inputs = new ArrayList<>(); + for(Float f : ctx.currentEntry.getData()){ + inputs.add(new Input(f)); + } + List predictions = ctx.model.predict(inputs); + ctx.prediction = predictions.getFirst(); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleErrorDetectionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleErrorDetectionStep.java new file mode 100644 index 0000000..0d411b4 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleErrorDetectionStep.java @@ -0,0 +1,13 @@ +package com.naaturel.ANN.implementation.training.steps; + +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class SimpleErrorDetectionStep implements TrainingStep { + + @Override + public void run(TrainingContext ctx) { + ctx.globalLoss += ctx.localLoss; + } + +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleLossStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleLossStep.java new file mode 100644 index 0000000..f46d815 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/SimpleLossStep.java @@ -0,0 +1,12 @@ +package com.naaturel.ANN.implementation.training.steps; + +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class SimpleLossStep implements TrainingStep { + + @Override + public void run(TrainingContext ctx) { + ctx.localLoss = Math.abs(ctx.delta); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java new file mode 100644 index 0000000..00996ab --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java @@ -0,0 +1,19 @@ +package com.naaturel.ANN.implementation.training.steps; + +import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.training.TrainingContext; + +public class WeightCorrectionStep implements TrainingStep { + + private final CorrectionStrategy correctionStrategy; + + public WeightCorrectionStep(CorrectionStrategy strategy) { + this.correctionStrategy = strategy; + } + + @Override + public void run(TrainingContext ctx) { + this.correctionStrategy.apply(ctx); + } +}