diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java index 7210e5d..7423776 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java @@ -6,6 +6,9 @@ import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; +/** + * Represents a fully connected neural network + */ public class Network implements Model { private final List layers; @@ -16,11 +19,12 @@ public class Network implements Model { @Override public List predict(List inputs) { - List currentLayerOutput = new ArrayList<>(inputs); + List previousLayerOutput = new ArrayList<>(inputs); for(Layer layer : this.layers){ - currentLayerOutput = layer.predict(currentLayerOutput).stream().map(Input::new).toList(); + List currentLayerOutput = layer.predict(previousLayerOutput); + previousLayerOutput = currentLayerOutput.stream().map(Input::new).toList(); } - return currentLayerOutput.stream().map(Input::getValue).toList(); + return previousLayerOutput.stream().map(Input::getValue).toList(); } @Override diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java new file mode 100644 index 0000000..b559692 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java @@ -0,0 +1,6 @@ +package com.naaturel.ANN.implementation.multiLayers; + +import com.naaturel.ANN.domain.abstraction.TrainingContext; + +public class GradientBackpropagationContext extends TrainingContext { +} diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStrategy.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStrategy.java new file mode 100644 index 0000000..11a4574 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationStrategy.java @@ -0,0 +1,17 @@ +package com.naaturel.ANN.implementation.multiLayers; + +import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; + +public class GradientBackpropagationStrategy implements AlgorithmStrategy { + + private GradientBackpropagationContext context; + + public GradientBackpropagationStrategy(GradientBackpropagationContext context) { + this.context = context; + } + + @Override + public void apply() { + + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index 45452df..3670ecd 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -2,11 +2,35 @@ package com.naaturel.ANN.implementation.training; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; +import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.model.training.TrainingPipeline; +import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; +import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; +import com.naaturel.ANN.implementation.training.steps.DeltaStep; +import com.naaturel.ANN.implementation.training.steps.PredictionStep; import com.naaturel.ANN.infrastructure.dataset.DataSet; +import java.util.List; + + public class GradientBackpropagationTraining implements Trainer { @Override public void train(Model model, DataSet dataset) { + TrainingContext context = new GradientDescentTrainingContext(); + context.dataset = dataset; + context.model = model; + context.learningRate = 0.0008F; + + List steps = List.of( + new PredictionStep(new SimplePredictionStrategy(context)), + new DeltaStep() + ); + + new TrainingPipeline(steps) + .stopCondition(ctx -> false) + .withVerbose(true) + .run(context); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java index b598a15..43a179a 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/PredictionStep.java @@ -10,9 +10,9 @@ import java.util.List; public class PredictionStep implements TrainingStep { - private final SimplePredictionStrategy strategy; + private final AlgorithmStrategy strategy; - public PredictionStep(SimplePredictionStrategy strategy) { + public PredictionStep(AlgorithmStrategy strategy) { this.strategy = strategy; }