Fix implementation
This commit is contained in:
@@ -2,28 +2,26 @@ package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.activation.Heaviside;
|
||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import javax.xml.crypto.Data;
|
||||
import java.util.*;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
||||
|
||||
DataSet orDataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
@@ -31,11 +29,11 @@ public class Main {
|
||||
|
||||
Bias bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
||||
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
Network network = new Network(List.of(layer));
|
||||
|
||||
Trainer trainer = new SimpleTraining();
|
||||
Trainer trainer = new GradientDescentTraining();
|
||||
trainer.train(network, dataset);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public interface AlgorithmStrategy {
|
||||
|
||||
void apply(TrainingContext ctx);
|
||||
void apply();
|
||||
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface Trainable {
|
||||
public interface Model {
|
||||
int synCount();
|
||||
void applyOnSynapses(Consumer<Synapse> consumer);
|
||||
List<Float> predict(List<Input> inputs);
|
||||
|
||||
void applyOnSynapses(Consumer<Synapse> consumer);
|
||||
|
||||
}
|
||||
@@ -4,10 +4,9 @@ import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class Neuron implements Trainable {
|
||||
public abstract class Neuron implements Model {
|
||||
|
||||
protected List<Synapse> synapses;
|
||||
protected Bias bias;
|
||||
@@ -35,4 +34,9 @@ public abstract class Neuron implements Trainable {
|
||||
syn.setInput(inputs.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
return this.synapses.size()+1; //take the bias in account
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
|
||||
public interface Trainer {
|
||||
void train(Trainable model, DataSet dataset);
|
||||
void train(Model model, DataSet dataset);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package com.naaturel.ANN.domain.model.training;
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
|
||||
public class TrainingContext {
|
||||
public Trainable model;
|
||||
public abstract class TrainingContext {
|
||||
public Model model;
|
||||
public DataSet dataset;
|
||||
public DataSetEntry currentEntry;
|
||||
public Label currentLabel;
|
||||
|
||||
public Label currentLabel;
|
||||
public float prediction;
|
||||
public float delta;
|
||||
public float localLoss;
|
||||
public float globalLoss;
|
||||
public float learningRate;
|
||||
|
||||
public float globalLoss;
|
||||
public float localLoss;
|
||||
|
||||
public float learningRate;
|
||||
public int epoch;
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public interface TrainingStep {
|
||||
|
||||
void run(TrainingContext ctx);
|
||||
void run();
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class Layer implements Trainable {
|
||||
public class Layer implements Model {
|
||||
|
||||
private final List<Neuron> neurons;
|
||||
|
||||
@@ -25,6 +25,15 @@ public class Layer implements Trainable {
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
int res = 0;
|
||||
for (Neuron neuron : this.neurons) {
|
||||
res += neuron.synCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||
this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer));
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class Network implements Trainable {
|
||||
public class Network implements Model {
|
||||
|
||||
private final List<Layer> layers;
|
||||
|
||||
@@ -24,6 +24,15 @@ public class Network implements Trainable {
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
int res = 0;
|
||||
for(Layer layer : this.layers){
|
||||
res += layer.synCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||
this.layers.forEach(layer -> layer.applyOnSynapses(consumer));
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.naaturel.ANN.domain.model.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
|
||||
@@ -55,6 +56,9 @@ public class TrainingPipeline {
|
||||
this.beforeEpoch.accept(ctx);
|
||||
this.executeSteps(ctx);
|
||||
this.afterEpoch.accept(ctx);
|
||||
if(this.verbose) {
|
||||
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
||||
}
|
||||
} while (!this.stopCondition.test(ctx));
|
||||
}
|
||||
|
||||
@@ -63,18 +67,16 @@ public class TrainingPipeline {
|
||||
ctx.currentEntry = entry;
|
||||
ctx.currentLabel = ctx.dataset.getLabel(entry);
|
||||
for (TrainingStep step : steps) {
|
||||
step.run(ctx);
|
||||
step.run();
|
||||
}
|
||||
if(this.verbose) {
|
||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
||||
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
|
||||
System.out.printf("delta : %.2f\n", ctx.delta);
|
||||
System.out.printf("delta : %.2f, ", ctx.delta);
|
||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||
}
|
||||
}
|
||||
if(this.verbose) {
|
||||
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
||||
}
|
||||
ctx.epoch += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.correction;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
||||
|
||||
List<Float> correctorTerms;
|
||||
|
||||
public GradientDescentCorrectionStrategy(int nbrCorrectors){
|
||||
this.correctorTerms = new ArrayList<>();
|
||||
for (int i = 0; i < nbrCorrectors; i++){
|
||||
correctorTerms.add(0F);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply(TrainingContext context) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
public GradientDescentCorrectionStrategy(GradientDescentTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
AtomicInteger i = new AtomicInteger(0);
|
||||
context.model.applyOnSynapses(syn -> {
|
||||
float corrector = context.correctorTerms.get(i.get());
|
||||
float c = syn.getWeight() + corrector;
|
||||
syn.setWeight(c);
|
||||
i.incrementAndGet();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
public GradientDescentErrorStrategy(GradientDescentTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
AtomicInteger i = new AtomicInteger(0);
|
||||
context.model.applyOnSynapses(syn -> {
|
||||
float corrector = context.correctorTerms.get(i.get());
|
||||
corrector += context.learningRate * context.delta * syn.getInput();
|
||||
context.correctorTerms.set(i.get(), corrector);
|
||||
i.incrementAndGet();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class GradientDescentTrainingContext extends TrainingContext {
|
||||
|
||||
public List<Float> correctorTerms;
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.naaturel.ANN.implementation.activation;
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||
|
||||
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
public SquareLossStrategy(GradientDescentTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
|
||||
this.context.globalLoss += context.localLoss;
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.loss;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class SimpleLossStrategy implements AlgorithmStrategy {
|
||||
@Override
|
||||
public void apply(TrainingContext ctx) {
|
||||
ctx.localLoss = Math.abs(ctx.delta);
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.loss;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||
@Override
|
||||
public void apply(TrainingContext ctx) {
|
||||
ctx.localLoss = (float)Math.pow(ctx.delta, 2) / 2;
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.naaturel.ANN.implementation.activation;
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
@@ -1,14 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.correction;
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||
|
||||
@Override
|
||||
public void apply(TrainingContext context) {
|
||||
if(context.currentLabel.getValue() == context.prediction) return ;
|
||||
private final SimpleTrainingContext context;
|
||||
|
||||
public SimpleCorrectionStrategy(SimpleTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
if(context.currentLabel.getValue() == context.prediction) return ;
|
||||
context.model.applyOnSynapses(syn -> {
|
||||
float currentW = syn.getWeight();
|
||||
float currentInput = syn.getInput();
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
|
||||
public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimpleDeltaStrategy(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
DataSet dataSet = context.dataset;
|
||||
DataSetEntry entry = context.currentEntry;
|
||||
Label label = dataSet.getLabel(entry);
|
||||
|
||||
context.delta = label.getValue() - context.prediction;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final SimpleTrainingContext context;
|
||||
|
||||
public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
context.globalLoss += context.localLoss;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
public class SimpleLossStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final SimpleTrainingContext context;
|
||||
|
||||
public SimpleLossStrategy(SimpleTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
this.context.localLoss = Math.abs(this.context.delta);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SimplePredictionStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimplePredictionStrategy(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
List<Float> predictions = context.model.predict(context.currentEntry.getData());
|
||||
context.prediction = predictions.getFirst();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
public class SimpleTrainingContext extends TrainingContext {
|
||||
}
|
||||
@@ -1,16 +1,19 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.correction.GradientDescentCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.loss.SquareLossStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class GradientDescentTraining implements Trainer {
|
||||
@@ -20,25 +23,31 @@ public class GradientDescentTraining implements Trainer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(Trainable model, DataSet dataset) {
|
||||
TrainingContext context = new TrainingContext();
|
||||
public void train(Model model, DataSet dataset) {
|
||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.3F;
|
||||
context.learningRate = 0.00011F;
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(),
|
||||
new DeltaStep(),
|
||||
new LossStep(new SquareLossStrategy()),
|
||||
new SimpleErrorDetectionStep(),
|
||||
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(2))
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SquareLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
|
||||
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
|
||||
);
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
pipeline
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||
.afterEpoch(ctx -> ())
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000)
|
||||
.beforeEpoch(ctx -> {
|
||||
ctx.globalLoss = 0.0F;
|
||||
for (int i = 0; i < model.synCount(); i++){
|
||||
context.correctorTerms.add(0F);
|
||||
}
|
||||
})
|
||||
.afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
|
||||
.withVerbose(true)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.loss.SimpleLossStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import java.util.List;
|
||||
@@ -19,18 +17,18 @@ public class SimpleTraining implements Trainer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(Trainable model, DataSet dataset) {
|
||||
TrainingContext context = new TrainingContext();
|
||||
public void train(Model model, DataSet dataset) {
|
||||
SimpleTrainingContext context = new SimpleTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.3F;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(),
|
||||
new DeltaStep(),
|
||||
new LossStep(new SimpleLossStrategy()),
|
||||
new SimpleErrorDetectionStep(),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy())
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SimpleLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||
);
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class DeltaStep implements TrainingStep {
|
||||
|
||||
@Override
|
||||
public void run(TrainingContext ctx) {
|
||||
DataSet dataSet = ctx.dataset;
|
||||
DataSetEntry entry = ctx.currentEntry;
|
||||
Label label = dataSet.getLabel(entry);
|
||||
private final AlgorithmStrategy strategy;
|
||||
|
||||
ctx.delta = label.getValue() - ctx.prediction;
|
||||
public DeltaStep(AlgorithmStrategy strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
|
||||
public class ErrorRegistrationStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy strategy;
|
||||
|
||||
public ErrorRegistrationStep(AlgorithmStrategy strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class LossStep implements TrainingStep {
|
||||
|
||||
|
||||
private final AlgorithmStrategy lossStrategy;
|
||||
|
||||
public LossStep(AlgorithmStrategy strategy) {
|
||||
@@ -13,7 +14,7 @@ public class LossStep implements TrainingStep {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(TrainingContext ctx) {
|
||||
this.lossStrategy.apply(ctx);
|
||||
public void run() {
|
||||
this.lossStrategy.apply();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class PredictionStep implements TrainingStep {
|
||||
|
||||
private final SimplePredictionStrategy strategy;
|
||||
|
||||
public PredictionStep(SimplePredictionStrategy strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(TrainingContext ctx) {
|
||||
List<Float> predictions = ctx.model.predict(ctx.currentEntry.getData());
|
||||
ctx.prediction = predictions.getFirst();
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class SimpleErrorDetectionStep implements TrainingStep {
|
||||
|
||||
@Override
|
||||
public void run(TrainingContext ctx) {
|
||||
ctx.globalLoss += ctx.localLoss;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class WeightCorrectionStep implements TrainingStep {
|
||||
|
||||
@@ -13,7 +12,7 @@ public class WeightCorrectionStep implements TrainingStep {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(TrainingContext ctx) {
|
||||
this.correctionStrategy.apply(ctx);
|
||||
public void run() {
|
||||
this.correctionStrategy.apply();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user