Just a regular commit

This commit is contained in:
2026-03-28 17:53:21 +01:00
parent 17cff89b44
commit 83526b72d4
5 changed files with 56 additions and 5 deletions

View File

@@ -6,6 +6,9 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.function.Consumer; import java.util.function.Consumer;
/**
* Represents a fully connected neural network
*/
public class Network implements Model { public class Network implements Model {
private final List<Layer> layers; private final List<Layer> layers;
@@ -16,11 +19,12 @@ public class Network implements Model {
@Override @Override
public List<Float> predict(List<Input> inputs) { public List<Float> predict(List<Input> inputs) {
List<Input> currentLayerOutput = new ArrayList<>(inputs); List<Input> previousLayerOutput = new ArrayList<>(inputs);
for(Layer layer : this.layers){ for(Layer layer : this.layers){
currentLayerOutput = layer.predict(currentLayerOutput).stream().map(Input::new).toList(); List<Float> currentLayerOutput = layer.predict(previousLayerOutput);
previousLayerOutput = currentLayerOutput.stream().map(Input::new).toList();
} }
return currentLayerOutput.stream().map(Input::getValue).toList(); return previousLayerOutput.stream().map(Input::getValue).toList();
} }
@Override @Override

View File

@@ -0,0 +1,6 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
public class GradientBackpropagationContext extends TrainingContext {
}

View File

@@ -0,0 +1,17 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
public class GradientBackpropagationStrategy implements AlgorithmStrategy {
private GradientBackpropagationContext context;
public GradientBackpropagationStrategy(GradientBackpropagationContext context) {
this.context = context;
}
@Override
public void apply() {
}
}

View File

@@ -2,11 +2,35 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.training.steps.DeltaStep;
import com.naaturel.ANN.implementation.training.steps.PredictionStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List;
public class GradientBackpropagationTraining implements Trainer { public class GradientBackpropagationTraining implements Trainer {
@Override @Override
public void train(Model model, DataSet dataset) { public void train(Model model, DataSet dataset) {
TrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.0008F;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep()
);
new TrainingPipeline(steps)
.stopCondition(ctx -> false)
.withVerbose(true)
.run(context);
} }
} }

View File

@@ -10,9 +10,9 @@ import java.util.List;
public class PredictionStep implements TrainingStep { public class PredictionStep implements TrainingStep {
private final SimplePredictionStrategy strategy; private final AlgorithmStrategy strategy;
public PredictionStep(SimplePredictionStrategy strategy) { public PredictionStep(AlgorithmStrategy strategy) {
this.strategy = strategy; this.strategy = strategy;
} }