Just a regular commit
This commit is contained in:
@@ -6,6 +6,9 @@ import java.util.ArrayList;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a fully connected neural network
|
||||||
|
*/
|
||||||
public class Network implements Model {
|
public class Network implements Model {
|
||||||
|
|
||||||
private final List<Layer> layers;
|
private final List<Layer> layers;
|
||||||
@@ -16,11 +19,12 @@ public class Network implements Model {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Float> predict(List<Input> inputs) {
|
public List<Float> predict(List<Input> inputs) {
|
||||||
List<Input> currentLayerOutput = new ArrayList<>(inputs);
|
List<Input> previousLayerOutput = new ArrayList<>(inputs);
|
||||||
for(Layer layer : this.layers){
|
for(Layer layer : this.layers){
|
||||||
currentLayerOutput = layer.predict(currentLayerOutput).stream().map(Input::new).toList();
|
List<Float> currentLayerOutput = layer.predict(previousLayerOutput);
|
||||||
|
previousLayerOutput = currentLayerOutput.stream().map(Input::new).toList();
|
||||||
}
|
}
|
||||||
return currentLayerOutput.stream().map(Input::getValue).toList();
|
return previousLayerOutput.stream().map(Input::getValue).toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
package com.naaturel.ANN.implementation.multiLayers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
|
public class GradientBackpropagationContext extends TrainingContext {
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.naaturel.ANN.implementation.multiLayers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
|
public class GradientBackpropagationStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private GradientBackpropagationContext context;
|
||||||
|
|
||||||
|
public GradientBackpropagationStrategy(GradientBackpropagationContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,11 +2,35 @@ package com.naaturel.ANN.implementation.training;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.training.steps.DeltaStep;
|
||||||
|
import com.naaturel.ANN.implementation.training.steps.PredictionStep;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
public class GradientBackpropagationTraining implements Trainer {
|
public class GradientBackpropagationTraining implements Trainer {
|
||||||
@Override
|
@Override
|
||||||
public void train(Model model, DataSet dataset) {
|
public void train(Model model, DataSet dataset) {
|
||||||
|
TrainingContext context = new GradientDescentTrainingContext();
|
||||||
|
context.dataset = dataset;
|
||||||
|
context.model = model;
|
||||||
|
context.learningRate = 0.0008F;
|
||||||
|
|
||||||
|
List<TrainingStep> steps = List.of(
|
||||||
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
|
new DeltaStep()
|
||||||
|
);
|
||||||
|
|
||||||
|
new TrainingPipeline(steps)
|
||||||
|
.stopCondition(ctx -> false)
|
||||||
|
.withVerbose(true)
|
||||||
|
.run(context);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ import java.util.List;
|
|||||||
|
|
||||||
public class PredictionStep implements TrainingStep {
|
public class PredictionStep implements TrainingStep {
|
||||||
|
|
||||||
private final SimplePredictionStrategy strategy;
|
private final AlgorithmStrategy strategy;
|
||||||
|
|
||||||
public PredictionStep(SimplePredictionStrategy strategy) {
|
public PredictionStep(AlgorithmStrategy strategy) {
|
||||||
this.strategy = strategy;
|
this.strategy = strategy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user