Just a regular commit
This commit is contained in:
@@ -6,6 +6,9 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* Represents a fully connected neural network
|
||||
*/
|
||||
public class Network implements Model {
|
||||
|
||||
private final List<Layer> layers;
|
||||
@@ -16,11 +19,12 @@ public class Network implements Model {
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
List<Input> currentLayerOutput = new ArrayList<>(inputs);
|
||||
List<Input> previousLayerOutput = new ArrayList<>(inputs);
|
||||
for(Layer layer : this.layers){
|
||||
currentLayerOutput = layer.predict(currentLayerOutput).stream().map(Input::new).toList();
|
||||
List<Float> currentLayerOutput = layer.predict(previousLayerOutput);
|
||||
previousLayerOutput = currentLayerOutput.stream().map(Input::new).toList();
|
||||
}
|
||||
return currentLayerOutput.stream().map(Input::getValue).toList();
|
||||
return previousLayerOutput.stream().map(Input::getValue).toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
public class GradientBackpropagationContext extends TrainingContext {
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
public class GradientBackpropagationStrategy implements AlgorithmStrategy {
|
||||
|
||||
private GradientBackpropagationContext context;
|
||||
|
||||
public GradientBackpropagationStrategy(GradientBackpropagationContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,35 @@ package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.DeltaStep;
|
||||
import com.naaturel.ANN.implementation.training.steps.PredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class GradientBackpropagationTraining implements Trainer {
|
||||
@Override
|
||||
public void train(Model model, DataSet dataset) {
|
||||
TrainingContext context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.0008F;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep()
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> false)
|
||||
.withVerbose(true)
|
||||
.run(context);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ import java.util.List;
|
||||
|
||||
public class PredictionStep implements TrainingStep {
|
||||
|
||||
private final SimplePredictionStrategy strategy;
|
||||
private final AlgorithmStrategy strategy;
|
||||
|
||||
public PredictionStep(SimplePredictionStrategy strategy) {
|
||||
public PredictionStep(AlgorithmStrategy strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user