From 0fe309cd4e75fa2fbb238ddc6b4557f41572746c Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 29 Mar 2026 21:32:08 +0200 Subject: [PATCH] Rename some stuff --- src/main/java/com/naaturel/ANN/Main.java | 7 +- .../abstraction/ActivationFunction.java | 1 + ...orithmStrategy.java => AlgorithmStep.java} | 4 +- .../ANN/domain/abstraction/Model.java | 5 +- .../ANN/domain/abstraction/Network.java | 10 +++ .../ANN/domain/abstraction/TrainingStep.java | 4 +- .../model/neuron/FullyConnectedNetwork.java | 84 +++++++++++++++++++ .../ANN/domain/model/neuron/Layer.java | 15 ++++ .../ANN/domain/model/neuron/Network.java | 48 ----------- .../ANN/domain/model/neuron/Neuron.java | 44 +++++++--- .../model/training/TrainingPipeline.java | 7 +- .../GradientDescentCorrectionStrategy.java | 6 +- .../GradientDescentErrorStrategy.java | 6 +- .../gradientDescent/Linear.java | 15 +++- ...eLossStrategy.java => SquareLossStep.java} | 9 +- .../GradientBackpropagationContext.java | 9 ++ .../GradientBackpropagationStep.java | 24 ++++++ .../GradientBackpropagationStrategy.java | 17 ---- .../implementation/multiLayers/Sigmoid.java | 5 ++ .../ANN/implementation/multiLayers/TanH.java | 4 + .../simplePerceptron/Heaviside.java | 7 ++ ...trategy.java => SimpleCorrectionStep.java} | 8 +- ...eltaStrategy.java => SimpleDeltaStep.java} | 9 +- ....java => SimpleErrorRegistrationStep.java} | 8 +- .../simplePerceptron/SimpleLossStrategy.java | 6 +- ...trategy.java => SimplePredictionStep.java} | 10 +-- .../training/AdalineTraining.java | 25 +++--- .../GradientBackpropagationTraining.java | 22 ++--- .../training/GradientDescentTraining.java | 20 ++--- .../training/SimpleTraining.java | 14 ++-- .../training/steps/DeltaStep.java | 8 +- .../training/steps/ErrorRegistrationStep.java | 8 +- .../training/steps/LossStep.java | 9 +- .../training/steps/PredictionStep.java | 13 +-- .../training/steps/WeightCorrectionStep.java | 8 +- src/test/java/adaline/AdalineTest.java | 24 +++--- .../gradientDescent/GradientDescentTest.java | 14 ++-- .../java/perceptron/SimplePerceptronTest.java | 12 +-- 38 files changed, 334 insertions(+), 215 deletions(-) rename src/main/java/com/naaturel/ANN/domain/abstraction/{AlgorithmStrategy.java => AlgorithmStep.java} (56%) create mode 100644 src/main/java/com/naaturel/ANN/domain/abstraction/Network.java create mode 100644 src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java delete mode 100644 src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java rename src/main/java/com/naaturel/ANN/implementation/gradientDescent/{SquareLossStrategy.java => SquareLossStep.java} (61%) create mode 100644 src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStep.java delete mode 100644 src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStrategy.java rename src/main/java/com/naaturel/ANN/implementation/simplePerceptron/{SimpleCorrectionStrategy.java => SimpleCorrectionStep.java} (81%) rename src/main/java/com/naaturel/ANN/implementation/simplePerceptron/{SimpleDeltaStrategy.java => SimpleDeltaStep.java} (78%) rename src/main/java/com/naaturel/ANN/implementation/simplePerceptron/{SimpleErrorRegistrationStrategy.java => SimpleErrorRegistrationStep.java} (54%) rename src/main/java/com/naaturel/ANN/implementation/simplePerceptron/{SimplePredictionStrategy.java => SimplePredictionStep.java} (56%) diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index d1735b0..8276e19 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -3,13 +3,10 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.implementation.multiLayers.Sigmoid; -import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; -import com.naaturel.ANN.implementation.gradientDescent.Linear; -import com.naaturel.ANN.implementation.training.GradientDescentTraining; import java.util.*; @@ -20,7 +17,7 @@ public class Main { int nbrInput = 2; int nbrClass = 3; - int nbrLayers = 1; + int nbrLayers = 2; DataSet dataset = new DatasetExtractor() .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass); @@ -44,7 +41,7 @@ public class Main { Layer layer = new Layer(neurons); layers.add(layer); } - Network network = new Network(layers); + FullyConnectedNetwork network = new FullyConnectedNetwork(layers); Trainer trainer = new GradientBackpropagationTraining(); trainer.train(network, dataset); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/ActivationFunction.java b/src/main/java/com/naaturel/ANN/domain/abstraction/ActivationFunction.java index d47ac8a..bf4491b 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/ActivationFunction.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/ActivationFunction.java @@ -5,5 +5,6 @@ import com.naaturel.ANN.domain.model.neuron.Neuron; public interface ActivationFunction { float accept(Neuron n); + float derivative(float value); } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java b/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStep.java similarity index 56% rename from src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java rename to src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStep.java index 2f79614..9e3e00d 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStrategy.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/AlgorithmStep.java @@ -1,8 +1,8 @@ package com.naaturel.ANN.domain.abstraction; @FunctionalInterface -public interface AlgorithmStrategy { +public interface AlgorithmStep { - void apply(); + void run(); } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java index 011618c..d32e093 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java @@ -5,13 +5,14 @@ import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Synapse; import java.util.List; -import java.util.function.BiConsumer; import java.util.function.Consumer; public interface Model { int synCount(); + int neuronCount(); void forEachNeuron(Consumer consumer); void forEachSynapse(Consumer consumer); + void forEachOutputNeurons(Consumer consumer); + void forEachNeuronConnectedTo(Neuron n, Consumer consumer); List predict(List inputs); - } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Network.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Network.java new file mode 100644 index 0000000..d4145a6 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Network.java @@ -0,0 +1,10 @@ +package com.naaturel.ANN.domain.abstraction; + +import com.naaturel.ANN.domain.model.neuron.Neuron; + +import java.util.function.Consumer; + +public interface Network { + + +} diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java index b0b865b..735d4cf 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingStep.java @@ -1,7 +1,7 @@ package com.naaturel.ANN.domain.abstraction; -public interface TrainingStep { +/*public interface TrainingStep { void run(); -} +}*/ diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java new file mode 100644 index 0000000..8018d5a --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java @@ -0,0 +1,84 @@ +package com.naaturel.ANN.domain.model.neuron; + +import com.naaturel.ANN.domain.abstraction.Model; +import com.naaturel.ANN.domain.abstraction.Network; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +/** + * Represents a fully connected neural network + */ +public class FullyConnectedNetwork implements Model { + + private final List layers;; + private final Map> connectionMap; + + public FullyConnectedNetwork(List layers) { + this.layers = layers; + this.connectionMap = this.createConnectionMap(); + } + + @Override + public List predict(List inputs) { + List previousLayerOutputs = new ArrayList<>(inputs); + for(Layer layer : this.layers){ + List currentLayerOutputs = layer.predict(previousLayerOutputs); + previousLayerOutputs = currentLayerOutputs.stream().map(Input::new).toList(); + } + return previousLayerOutputs.stream().map(Input::getValue).toList(); + } + + @Override + public int synCount() { + int res = 0; + for(Layer layer : this.layers){ + res += layer.synCount(); + } + return res; + } + + @Override + public int neuronCount() { + int res = 0; + for(Layer layer : this.layers){ + res += layer.neuronCount(); + } + return res; + } + + @Override + public void forEachSynapse(Consumer consumer) { + this.layers.forEach(layer -> layer.forEachSynapse(consumer)); + } + + @Override + public void forEachNeuron(Consumer consumer) { + this.layers.forEach(layer -> layer.forEachNeuron(consumer)); + } + + @Override + public void forEachOutputNeurons(Consumer consumer) { + this.layers.getLast().forEachNeuron(consumer); + } + + @Override + public void forEachNeuronConnectedTo(Neuron n, Consumer consumer) { + this.connectionMap.get(n).forEach(consumer); + } + + private Map> createConnectionMap() { + Map> res = new HashMap<>(); + + for (int i = 0; i < this.layers.size() - 1; i++) { + List nextLayerNeurons = new ArrayList<>(); + this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add); + this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons)); + } + + return res; + } +} diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index 6768ecd..36340e0 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -33,6 +33,11 @@ public class Layer implements Model { return res; } + @Override + public int neuronCount() { + return this.neurons.size(); + } + @Override public void forEachNeuron(Consumer consumer) { this.neurons.forEach(consumer); @@ -42,4 +47,14 @@ public class Layer implements Model { public void forEachSynapse(Consumer consumer) { this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); } + + @Override + public void forEachOutputNeurons(Consumer consumer) { + this.neurons.forEach(consumer); + } + + @Override + public void forEachNeuronConnectedTo(Neuron n, Consumer consumer) { + throw new UnsupportedOperationException("Neurons have no connection within the same layer"); + } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java deleted file mode 100644 index 7423776..0000000 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java +++ /dev/null @@ -1,48 +0,0 @@ -package com.naaturel.ANN.domain.model.neuron; - -import com.naaturel.ANN.domain.abstraction.Model; - -import java.util.ArrayList; -import java.util.List; -import java.util.function.Consumer; - -/** - * Represents a fully connected neural network - */ -public class Network implements Model { - - private final List layers; - - public Network(List layers) { - this.layers = layers; - } - - @Override - public List predict(List inputs) { - List previousLayerOutput = new ArrayList<>(inputs); - for(Layer layer : this.layers){ - List currentLayerOutput = layer.predict(previousLayerOutput); - previousLayerOutput = currentLayerOutput.stream().map(Input::new).toList(); - } - return previousLayerOutput.stream().map(Input::getValue).toList(); - } - - @Override - public int synCount() { - int res = 0; - for(Layer layer : this.layers){ - res += layer.synCount(); - } - return res; - } - - @Override - public void forEachNeuron(Consumer consumer) { - this.layers.forEach(layer -> layer.forEachNeuron(consumer)); - } - - @Override - public void forEachSynapse(Consumer consumer) { - this.layers.forEach(layer -> layer.forEachSynapse(consumer)); - } -} diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java index e85174c..9893cf8 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -3,7 +3,6 @@ import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.Model; import java.util.List; -import java.util.function.BiConsumer; import java.util.function.Consumer; public class Neuron implements Model { @@ -11,11 +10,13 @@ public class Neuron implements Model { protected List synapses; protected Bias bias; protected ActivationFunction activationFunction; + protected Float output; public Neuron(List synapses, Bias bias, ActivationFunction func){ this.synapses = synapses; this.bias = bias; this.activationFunction = func; + this.output = 0F; } public void updateBias(Weight weight) { @@ -33,15 +34,38 @@ public class Neuron implements Model { } } + public ActivationFunction getActivationFunction(){ + return this.activationFunction; + } + + public float getOutput(){ + return this.output; + } + + public float calculateWeightedSum() { + float res = 0; + res += this.bias.getWeight() * this.bias.getInput(); + for(Synapse syn : this.synapses){ + res += syn.getWeight() * syn.getInput(); + } + return res; + } + @Override public int synCount() { - return this.synapses.size()+1; //take the bias in account + return this.synapses.size()+1; //take the bias into account + } + + @Override + public int neuronCount() { + return 1; } @Override public List predict(List inputs) { this.setInputs(inputs); - return List.of(activationFunction.accept(this)); + this.output = activationFunction.accept(this); + return List.of(output); } @Override @@ -55,13 +79,13 @@ public class Neuron implements Model { this.synapses.forEach(consumer); } - public float calculateWeightedSum() { - float res = 0; - res += this.bias.getWeight() * this.bias.getInput(); - for(Synapse syn : this.synapses){ - res += syn.getWeight() * syn.getInput(); - } - return res; + @Override + public void forEachOutputNeurons(Consumer consumer) { + consumer.accept(this); } + @Override + public void forEachNeuronConnectedTo(Neuron n, Consumer consumer) { + throw new UnsupportedOperationException("Neurons have no connection with themselves"); + } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index 8975d94..7f40cbc 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -1,5 +1,6 @@ package com.naaturel.ANN.domain.model.training; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; @@ -15,7 +16,7 @@ import java.util.function.Predicate; public class TrainingPipeline { - private final List steps; + private final List steps; private Consumer beforeEpoch; private Consumer afterEpoch; private Predicate stopCondition; @@ -25,7 +26,7 @@ public class TrainingPipeline { private boolean visualization; private boolean timeMeasurement; - public TrainingPipeline(List steps) { + public TrainingPipeline(List steps) { this.steps = new ArrayList<>(steps); this.stopCondition = (ctx) -> false; this.beforeEpoch = (context -> {}); @@ -90,7 +91,7 @@ public class TrainingPipeline { ctx.currentEntry = entry; ctx.expectations = ctx.dataset.getLabelsAsFloat(entry); - for (TrainingStep step : steps) { + for (AlgorithmStep step : steps) { step.run(); } diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java index ed11af3..f76726a 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentCorrectionStrategy.java @@ -1,10 +1,10 @@ package com.naaturel.ANN.implementation.gradientDescent; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import java.util.concurrent.atomic.AtomicInteger; -public class GradientDescentCorrectionStrategy implements AlgorithmStrategy { +public class GradientDescentCorrectionStrategy implements AlgorithmStep { private final GradientDescentTrainingContext context; @@ -13,7 +13,7 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStrategy { } @Override - public void apply() { + public void run() { AtomicInteger i = new AtomicInteger(0); context.model.forEachSynapse(syn -> { float corrector = context.correctorTerms.get(i.get()); diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java index 411cb2a..101bbef 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java @@ -1,10 +1,10 @@ package com.naaturel.ANN.implementation.gradientDescent; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import java.util.concurrent.atomic.AtomicInteger; -public class GradientDescentErrorStrategy implements AlgorithmStrategy { +public class GradientDescentErrorStrategy implements AlgorithmStep { private final GradientDescentTrainingContext context; @@ -14,7 +14,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy { @Override - public void apply() { + public void run() { AtomicInteger neuronIndex = new AtomicInteger(0); AtomicInteger synIndex = new AtomicInteger(0); diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java index 18c7acc..2caff73 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java @@ -5,9 +5,22 @@ import com.naaturel.ANN.domain.model.neuron.Neuron; public class Linear implements ActivationFunction { + private final float slope; + private final float intercept; + + public Linear(float slope, float intercept) { + this.slope = slope; + this.intercept = intercept; + } + @Override public float accept(Neuron n) { - return n.calculateWeightedSum(); + return slope * n.calculateWeightedSum() + intercept; + } + + @Override + public float derivative(float value) { + return this.slope; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java similarity index 61% rename from src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java rename to src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java index fa15fd7..cbced9e 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java @@ -1,21 +1,20 @@ package com.naaturel.ANN.implementation.gradientDescent; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingContext; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext; import java.util.stream.Stream; -public class SquareLossStrategy implements AlgorithmStrategy { +public class SquareLossStep implements AlgorithmStep { private final TrainingContext context; - public SquareLossStrategy(TrainingContext context) { + public SquareLossStep(TrainingContext context) { this.context = context; } @Override - public void apply() { + public void run() { Stream deltaStream = this.context.deltas.stream(); this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2))); this.context.localLoss /= 2; diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java index b559692..b2c22c7 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java @@ -2,5 +2,14 @@ package com.naaturel.ANN.implementation.multiLayers; import com.naaturel.ANN.domain.abstraction.TrainingContext; +import java.util.ArrayList; +import java.util.List; + public class GradientBackpropagationContext extends TrainingContext { + + public List hiddenDeltas; + + public GradientBackpropagationContext(){ + } + } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStep.java new file mode 100644 index 0000000..dd0b552 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStep.java @@ -0,0 +1,24 @@ +package com.naaturel.ANN.implementation.multiLayers; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; +import com.naaturel.ANN.domain.model.neuron.Neuron; + +public class GradientBackpropagationStep implements AlgorithmStep { + + private GradientBackpropagationContext context; + public GradientBackpropagationStep(GradientBackpropagationContext context) { + this.context = context; + } + + @Override + public void run() { + + + } + + + private float calculateDeltaRecursive(Neuron n){ + + } + +} diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStrategy.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStrategy.java deleted file mode 100644 index 11a4574..0000000 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStrategy.java +++ /dev/null @@ -1,17 +0,0 @@ -package com.naaturel.ANN.implementation.multiLayers; - -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; - -public class GradientBackpropagationStrategy implements AlgorithmStrategy { - - private GradientBackpropagationContext context; - - public GradientBackpropagationStrategy(GradientBackpropagationContext context) { - this.context = context; - } - - @Override - public void apply() { - - } -} diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java index 01149cd..95ae036 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java @@ -15,4 +15,9 @@ public class Sigmoid implements ActivationFunction { public float accept(Neuron n) { return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum()))); } + + @Override + public float derivative(float value) { + return steepness * value * (1 - value); + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java index a3b8508..59e7b16 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java @@ -14,4 +14,8 @@ public class TanH implements ActivationFunction { return (float)(res); } + @Override + public float derivative(float value) { + return 1 - value * value; + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java index af0df44..5f97847 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java @@ -3,6 +3,8 @@ package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.model.neuron.Neuron; +import javax.naming.OperationNotSupportedException; + public class Heaviside implements ActivationFunction { public Heaviside(){ @@ -14,4 +16,9 @@ public class Heaviside implements ActivationFunction { float weightedSum = n.calculateWeightedSum(); return weightedSum < 0 ? 0:1; } + + @Override + public float derivative(float value) { + throw new UnsupportedOperationException("Heaviside is not differentiable"); + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java similarity index 81% rename from src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java rename to src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java index 26441a7..81e8f3c 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java @@ -1,21 +1,21 @@ package com.naaturel.ANN.implementation.simplePerceptron; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingContext; import java.util.concurrent.atomic.AtomicInteger; -public class SimpleCorrectionStrategy implements AlgorithmStrategy { +public class SimpleCorrectionStep implements AlgorithmStep { private final TrainingContext context; - public SimpleCorrectionStrategy(TrainingContext context) { + public SimpleCorrectionStep(TrainingContext context) { this.context = context; } @Override - public void apply() { + public void run() { if(context.expectations.equals(context.predictions)) return; AtomicInteger neuronIndex = new AtomicInteger(0); AtomicInteger synIndex = new AtomicInteger(0); diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java similarity index 78% rename from src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java rename to src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java index 108798b..10b0300 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java @@ -1,6 +1,6 @@ package com.naaturel.ANN.implementation.simplePerceptron; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; @@ -9,16 +9,16 @@ import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; -public class SimpleDeltaStrategy implements AlgorithmStrategy { +public class SimpleDeltaStep implements AlgorithmStep { private final TrainingContext context; - public SimpleDeltaStrategy(TrainingContext context) { + public SimpleDeltaStep(TrainingContext context) { this.context = context; } @Override - public void apply() { + public void run() { DataSet dataSet = context.dataset; DataSetEntry entry = context.currentEntry; List predicted = context.predictions; @@ -28,7 +28,6 @@ public class SimpleDeltaStrategy implements AlgorithmStrategy { context.deltas = IntStream.range(0, predicted.size()) .mapToObj(i -> expected.get(i) - predicted.get(i)) .collect(Collectors.toList()); - System.out.printf(""); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStep.java similarity index 54% rename from src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java rename to src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStep.java index c526435..07f5295 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStep.java @@ -1,18 +1,18 @@ package com.naaturel.ANN.implementation.simplePerceptron; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingContext; -public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy { +public class SimpleErrorRegistrationStep implements AlgorithmStep { private final TrainingContext context; - public SimpleErrorRegistrationStrategy(TrainingContext context) { + public SimpleErrorRegistrationStep(TrainingContext context) { this.context = context; } @Override - public void apply() { + public void run() { context.globalLoss += context.localLoss; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java index 4e4247f..562eec6 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java @@ -1,8 +1,8 @@ package com.naaturel.ANN.implementation.simplePerceptron; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; -public class SimpleLossStrategy implements AlgorithmStrategy { +public class SimpleLossStrategy implements AlgorithmStep { private final SimpleTrainingContext context; @@ -11,7 +11,7 @@ public class SimpleLossStrategy implements AlgorithmStrategy { } @Override - public void apply() { + public void run() { this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStep.java similarity index 56% rename from src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java rename to src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStep.java index bf234ca..ab25f23 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStep.java @@ -1,20 +1,18 @@ package com.naaturel.ANN.implementation.simplePerceptron; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingContext; -import java.util.List; - -public class SimplePredictionStrategy implements AlgorithmStrategy { +public class SimplePredictionStep implements AlgorithmStep { private final TrainingContext context; - public SimplePredictionStrategy(TrainingContext context) { + public SimplePredictionStep(TrainingContext context) { this.context = context; } @Override - public void apply() { + public void run() { context.predictions = context.model.predict(context.currentEntry.getData()); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 58d434d..b641df8 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -1,17 +1,16 @@ package com.naaturel.ANN.implementation.training; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; -import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; -import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; -import com.naaturel.ANN.implementation.training.steps.*; +import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import java.util.List; @@ -30,12 +29,12 @@ public class AdalineTraining implements Trainer { context.model = model; context.learningRate = 0.003F; - List steps = List.of( - new PredictionStep(new SimplePredictionStrategy(context)), - new DeltaStep(new SimpleDeltaStrategy(context)), - new LossStep(new SquareLossStrategy(context)), - new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), - new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) + List steps = List.of( + new SimplePredictionStep(context), + new SimpleDeltaStep(context), + new SquareLossStep(context), + new SimpleErrorRegistrationStep(context), + new SimpleCorrectionStep(context) ); new TrainingPipeline(steps) diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index 3670ecd..20fc04d 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -1,14 +1,13 @@ package com.naaturel.ANN.implementation.training; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; -import com.naaturel.ANN.domain.abstraction.TrainingContext; -import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.training.TrainingPipeline; -import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; -import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; -import com.naaturel.ANN.implementation.training.steps.DeltaStep; -import com.naaturel.ANN.implementation.training.steps.PredictionStep; +import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext; +import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.infrastructure.dataset.DataSet; import java.util.List; @@ -17,14 +16,15 @@ import java.util.List; public class GradientBackpropagationTraining implements Trainer { @Override public void train(Model model, DataSet dataset) { - TrainingContext context = new GradientDescentTrainingContext(); + GradientBackpropagationContext context = new GradientBackpropagationContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.0008F; + context.learningRate = 0.001F; - List steps = List.of( - new PredictionStep(new SimplePredictionStrategy(context)), - new DeltaStep() + List steps = List.of( + new SimplePredictionStep(context), + new SimpleDeltaStep(context), + new GradientBackpropagationStep(context) ); new TrainingPipeline(steps) diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index bbcf03f..9f8abab 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -1,16 +1,16 @@ package com.naaturel.ANN.implementation.training; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; -import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy; -import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; +import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.training.steps.*; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; @@ -31,11 +31,11 @@ public class GradientDescentTraining implements Trainer { context.learningRate = 0.0008F; context.correctorTerms = new ArrayList<>(); - List steps = List.of( - new PredictionStep(new SimplePredictionStrategy(context)), - new DeltaStep(new SimpleDeltaStrategy(context)), - new LossStep(new SquareLossStrategy(context)), - new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)) + List steps = List.of( + new SimplePredictionStep(context), + new SimpleDeltaStep(context), + new SquareLossStep(context), + new GradientDescentErrorStrategy(context) ); new TrainingPipeline(steps) @@ -48,7 +48,7 @@ public class GradientDescentTraining implements Trainer { }) .afterEpoch(ctx -> { context.globalLoss /= context.dataset.size(); - new GradientDescentCorrectionStrategy(context).apply(); + new GradientDescentCorrectionStrategy(context).run(); }) //.withVerbose(true) .withTimeMeasurement(true) diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index a68ec32..ee59129 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -1,8 +1,8 @@ package com.naaturel.ANN.implementation.training; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; -import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; @@ -23,12 +23,12 @@ public class SimpleTraining implements Trainer { context.model = model; context.learningRate = 0.3F; - List steps = List.of( - new PredictionStep(new SimplePredictionStrategy(context)), - new DeltaStep(new SimpleDeltaStrategy(context)), - new LossStep(new SimpleLossStrategy(context)), - new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), - new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) + List steps = List.of( + new SimplePredictionStep(context), + new SimpleDeltaStep(context), + new SimpleLossStrategy(context), + new SimpleErrorRegistrationStep(context), + new SimpleCorrectionStep(context) ); TrainingPipeline pipeline = new TrainingPipeline(steps); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java index c5c377b..619e6f6 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java @@ -1,18 +1,18 @@ package com.naaturel.ANN.implementation.training.steps; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingStep; public class DeltaStep implements TrainingStep { - private final AlgorithmStrategy strategy; + private final AlgorithmStep strategy; - public DeltaStep(AlgorithmStrategy strategy) { + public DeltaStep(AlgorithmStep strategy) { this.strategy = strategy; } @Override public void run() { - this.strategy.apply(); + this.strategy.run(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/ErrorRegistrationStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/ErrorRegistrationStep.java index cd32511..1db2a49 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/ErrorRegistrationStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/ErrorRegistrationStep.java @@ -1,18 +1,18 @@ package com.naaturel.ANN.implementation.training.steps; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingStep; public class ErrorRegistrationStep implements TrainingStep { - private final AlgorithmStrategy strategy; + private final AlgorithmStep strategy; - public ErrorRegistrationStep(AlgorithmStrategy strategy) { + public ErrorRegistrationStep(AlgorithmStep strategy) { this.strategy = strategy; } @Override public void run() { - this.strategy.apply(); + this.strategy.run(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java index b047c34..62ca5d9 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/LossStep.java @@ -1,20 +1,19 @@ package com.naaturel.ANN.implementation.training.steps; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; -import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingStep; public class LossStep implements TrainingStep { - private final AlgorithmStrategy lossStrategy; + private final AlgorithmStep lossStrategy; - public LossStep(AlgorithmStrategy strategy) { + public LossStep(AlgorithmStep strategy) { this.lossStrategy = strategy; } @Override public void run() { - this.lossStrategy.apply(); + this.lossStrategy.run(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java index 43a179a..06a1601 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java @@ -1,23 +1,18 @@ package com.naaturel.ANN.implementation.training.steps; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; -import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext; - -import java.util.List; public class PredictionStep implements TrainingStep { - private final AlgorithmStrategy strategy; + private final AlgorithmStep strategy; - public PredictionStep(AlgorithmStrategy strategy) { + public PredictionStep(AlgorithmStep strategy) { this.strategy = strategy; } @Override public void run() { - this.strategy.apply(); + this.strategy.run(); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java index 0db68f6..515fc47 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/WeightCorrectionStep.java @@ -1,18 +1,18 @@ package com.naaturel.ANN.implementation.training.steps; -import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingStep; public class WeightCorrectionStep implements TrainingStep { - private final AlgorithmStrategy correctionStrategy; + private final AlgorithmStep correctionStrategy; - public WeightCorrectionStep(AlgorithmStrategy strategy) { + public WeightCorrectionStep(AlgorithmStep strategy) { this.correctionStrategy = strategy; } @Override public void run() { - this.correctionStrategy.apply(); + this.correctionStrategy.run(); } } diff --git a/src/test/java/adaline/AdalineTest.java b/src/test/java/adaline/AdalineTest.java index 1cbe109..82b2c0c 100644 --- a/src/test/java/adaline/AdalineTest.java +++ b/src/test/java/adaline/AdalineTest.java @@ -9,10 +9,10 @@ import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; import com.naaturel.ANN.implementation.gradientDescent.*; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy; -import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.training.steps.*; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -29,7 +29,7 @@ public class AdalineTest { private List synapses; private Bias bias; - private Network network; + private FullyConnectedNetwork network; private TrainingPipeline pipeline; @@ -44,20 +44,20 @@ public class AdalineTest { bias = new Bias(new Weight(0)); - Neuron neuron = new Neuron(syns, bias, new Linear()); + Neuron neuron = new Neuron(syns, bias, new Linear(1, 0)); Layer layer = new Layer(List.of(neuron)); - network = new Network(List.of(layer)); + network = new FullyConnectedNetwork(List.of(layer)); context = new AdalineTrainingContext(); context.dataset = dataset; context.model = network; List steps = List.of( - new PredictionStep(new SimplePredictionStrategy(context)), - new DeltaStep(new SimpleDeltaStrategy(context)), - new LossStep(new SquareLossStrategy(context)), - new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), - new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) + new PredictionStep(new SimplePredictionStep(context)), + new DeltaStep(new SimpleDeltaStep(context)), + new LossStep(new SquareLossStep(context)), + new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)), + new WeightCorrectionStep(new SimpleCorrectionStep(context)) ); pipeline = new TrainingPipeline(steps) diff --git a/src/test/java/gradientDescent/GradientDescentTest.java b/src/test/java/gradientDescent/GradientDescentTest.java index 9952431..0312862 100644 --- a/src/test/java/gradientDescent/GradientDescentTest.java +++ b/src/test/java/gradientDescent/GradientDescentTest.java @@ -25,7 +25,7 @@ public class GradientDescentTest { private List synapses; private Bias bias; - private Network network; + private FullyConnectedNetwork network; private TrainingPipeline pipeline; @@ -40,9 +40,9 @@ public class GradientDescentTest { bias = new Bias(new Weight(0)); - Neuron neuron = new Neuron(syns, bias, new Linear()); + Neuron neuron = new Neuron(syns, bias, new Linear(1, 0)); Layer layer = new Layer(List.of(neuron)); - network = new Network(List.of(layer)); + network = new FullyConnectedNetwork(List.of(layer)); context = new GradientDescentTrainingContext(); context.dataset = dataset; @@ -50,9 +50,9 @@ public class GradientDescentTest { context.correctorTerms = new ArrayList<>(); List steps = List.of( - new PredictionStep(new SimplePredictionStrategy(context)), - new DeltaStep(new SimpleDeltaStrategy(context)), - new LossStep(new SquareLossStrategy(context)), + new PredictionStep(new SimplePredictionStep(context)), + new DeltaStep(new SimpleDeltaStep(context)), + new LossStep(new SquareLossStep(context)), new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)) ); @@ -82,7 +82,7 @@ public class GradientDescentTest { context.learningRate = 0.2F; pipeline.afterEpoch(ctx -> { context.globalLoss /= context.dataset.size(); - new GradientDescentCorrectionStrategy(context).apply(); + new GradientDescentCorrectionStrategy(context).run(); int index = ctx.epoch-1; if(index >= expectedGlobalLosses.size()) return; diff --git a/src/test/java/perceptron/SimplePerceptronTest.java b/src/test/java/perceptron/SimplePerceptronTest.java index 2251988..cabd4e1 100644 --- a/src/test/java/perceptron/SimplePerceptronTest.java +++ b/src/test/java/perceptron/SimplePerceptronTest.java @@ -24,7 +24,7 @@ public class SimplePerceptronTest { private List synapses; private Bias bias; - private Network network; + private FullyConnectedNetwork network; private TrainingPipeline pipeline; @@ -41,18 +41,18 @@ public class SimplePerceptronTest { Neuron neuron = new Neuron(syns, bias, new Heaviside()); Layer layer = new Layer(List.of(neuron)); - network = new Network(List.of(layer)); + network = new FullyConnectedNetwork(List.of(layer)); context = new SimpleTrainingContext(); context.dataset = dataset; context.model = network; List steps = List.of( - new PredictionStep(new SimplePredictionStrategy(context)), - new DeltaStep(new SimpleDeltaStrategy(context)), + new PredictionStep(new SimplePredictionStep(context)), + new DeltaStep(new SimpleDeltaStep(context)), new LossStep(new SimpleLossStrategy(context)), - new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), - new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) + new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)), + new WeightCorrectionStep(new SimpleCorrectionStep(context)) ); pipeline = new TrainingPipeline(steps);