Fix implementation
This commit is contained in:
@@ -2,28 +2,26 @@ package com.naaturel.ANN;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
||||||
import com.naaturel.ANN.implementation.activation.Heaviside;
|
|
||||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
|
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||||
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
|
||||||
|
|
||||||
import javax.xml.crypto.Data;
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
public static void main(String[] args){
|
public static void main(String[] args){
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
|
DataSet dataset = new DatasetExtractor()
|
||||||
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
||||||
|
|
||||||
|
DataSet orDataset = new DatasetExtractor()
|
||||||
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
List<Synapse> syns = new ArrayList<>();
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
@@ -31,11 +29,11 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight(0));
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
||||||
Layer layer = new Layer(List.of(neuron));
|
Layer layer = new Layer(List.of(neuron));
|
||||||
Network network = new Network(List.of(layer));
|
Network network = new Network(List.of(layer));
|
||||||
|
|
||||||
Trainer trainer = new SimpleTraining();
|
Trainer trainer = new GradientDescentTraining();
|
||||||
trainer.train(network, dataset);
|
trainer.train(network, dataset);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public interface AlgorithmStrategy {
|
public interface AlgorithmStrategy {
|
||||||
|
|
||||||
void apply(TrainingContext ctx);
|
void apply();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
public interface Trainable {
|
public interface Model {
|
||||||
|
int synCount();
|
||||||
|
void applyOnSynapses(Consumer<Synapse> consumer);
|
||||||
List<Float> predict(List<Input> inputs);
|
List<Float> predict(List<Input> inputs);
|
||||||
|
|
||||||
void applyOnSynapses(Consumer<Synapse> consumer);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -4,10 +4,9 @@ import com.naaturel.ANN.domain.model.neuron.Input;
|
|||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public abstract class Neuron implements Trainable {
|
public abstract class Neuron implements Model {
|
||||||
|
|
||||||
protected List<Synapse> synapses;
|
protected List<Synapse> synapses;
|
||||||
protected Bias bias;
|
protected Bias bias;
|
||||||
@@ -35,4 +34,9 @@ public abstract class Neuron implements Trainable {
|
|||||||
syn.setInput(inputs.get(i));
|
syn.setInput(inputs.get(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int synCount() {
|
||||||
|
return this.synapses.size()+1; //take the bias in account
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
|
|||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
|
||||||
public interface Trainer {
|
public interface Trainer {
|
||||||
void train(Trainable model, DataSet dataset);
|
void train(Model model, DataSet dataset);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
package com.naaturel.ANN.domain.model.training;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
|
|
||||||
public class TrainingContext {
|
public abstract class TrainingContext {
|
||||||
public Trainable model;
|
public Model model;
|
||||||
public DataSet dataset;
|
public DataSet dataset;
|
||||||
public DataSetEntry currentEntry;
|
public DataSetEntry currentEntry;
|
||||||
public Label currentLabel;
|
|
||||||
|
|
||||||
|
public Label currentLabel;
|
||||||
public float prediction;
|
public float prediction;
|
||||||
public float delta;
|
public float delta;
|
||||||
public float localLoss;
|
|
||||||
public float globalLoss;
|
|
||||||
public float learningRate;
|
|
||||||
|
|
||||||
|
public float globalLoss;
|
||||||
|
public float localLoss;
|
||||||
|
|
||||||
|
public float learningRate;
|
||||||
public int epoch;
|
public int epoch;
|
||||||
}
|
}
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public interface TrainingStep {
|
public interface TrainingStep {
|
||||||
|
|
||||||
void run(TrainingContext ctx);
|
void run();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
package com.naaturel.ANN.domain.model.neuron;
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
public class Layer implements Trainable {
|
public class Layer implements Model {
|
||||||
|
|
||||||
private final List<Neuron> neurons;
|
private final List<Neuron> neurons;
|
||||||
|
|
||||||
@@ -25,6 +25,15 @@ public class Layer implements Trainable {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int synCount() {
|
||||||
|
int res = 0;
|
||||||
|
for (Neuron neuron : this.neurons) {
|
||||||
|
res += neuron.synCount();
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||||
this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer));
|
this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer));
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
package com.naaturel.ANN.domain.model.neuron;
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
public class Network implements Trainable {
|
public class Network implements Model {
|
||||||
|
|
||||||
private final List<Layer> layers;
|
private final List<Layer> layers;
|
||||||
|
|
||||||
@@ -24,6 +24,15 @@ public class Network implements Trainable {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int synCount() {
|
||||||
|
int res = 0;
|
||||||
|
for(Layer layer : this.layers){
|
||||||
|
res += layer.synCount();
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||||
this.layers.forEach(layer -> layer.applyOnSynapses(consumer));
|
this.layers.forEach(layer -> layer.applyOnSynapses(consumer));
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.naaturel.ANN.domain.model.training;
|
package com.naaturel.ANN.domain.model.training;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
|
||||||
@@ -55,6 +56,9 @@ public class TrainingPipeline {
|
|||||||
this.beforeEpoch.accept(ctx);
|
this.beforeEpoch.accept(ctx);
|
||||||
this.executeSteps(ctx);
|
this.executeSteps(ctx);
|
||||||
this.afterEpoch.accept(ctx);
|
this.afterEpoch.accept(ctx);
|
||||||
|
if(this.verbose) {
|
||||||
|
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
||||||
|
}
|
||||||
} while (!this.stopCondition.test(ctx));
|
} while (!this.stopCondition.test(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,18 +67,16 @@ public class TrainingPipeline {
|
|||||||
ctx.currentEntry = entry;
|
ctx.currentEntry = entry;
|
||||||
ctx.currentLabel = ctx.dataset.getLabel(entry);
|
ctx.currentLabel = ctx.dataset.getLabel(entry);
|
||||||
for (TrainingStep step : steps) {
|
for (TrainingStep step : steps) {
|
||||||
step.run(ctx);
|
step.run();
|
||||||
}
|
}
|
||||||
if(this.verbose) {
|
if(this.verbose) {
|
||||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||||
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
||||||
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
|
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
|
||||||
System.out.printf("delta : %.2f\n", ctx.delta);
|
System.out.printf("delta : %.2f, ", ctx.delta);
|
||||||
|
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(this.verbose) {
|
|
||||||
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
|
||||||
}
|
|
||||||
ctx.epoch += 1;
|
ctx.epoch += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.correction;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
|
||||||
|
|
||||||
List<Float> correctorTerms;
|
|
||||||
|
|
||||||
public GradientDescentCorrectionStrategy(int nbrCorrectors){
|
|
||||||
this.correctorTerms = new ArrayList<>();
|
|
||||||
for (int i = 0; i < nbrCorrectors; i++){
|
|
||||||
correctorTerms.add(0F);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void apply(TrainingContext context) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final GradientDescentTrainingContext context;
|
||||||
|
|
||||||
|
public GradientDescentCorrectionStrategy(GradientDescentTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
AtomicInteger i = new AtomicInteger(0);
|
||||||
|
context.model.applyOnSynapses(syn -> {
|
||||||
|
float corrector = context.correctorTerms.get(i.get());
|
||||||
|
float c = syn.getWeight() + corrector;
|
||||||
|
syn.setWeight(c);
|
||||||
|
i.incrementAndGet();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final GradientDescentTrainingContext context;
|
||||||
|
|
||||||
|
public GradientDescentErrorStrategy(GradientDescentTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
AtomicInteger i = new AtomicInteger(0);
|
||||||
|
context.model.applyOnSynapses(syn -> {
|
||||||
|
float corrector = context.correctorTerms.get(i.get());
|
||||||
|
corrector += context.learningRate * context.delta * syn.getInput();
|
||||||
|
context.correctorTerms.set(i.get(), corrector);
|
||||||
|
i.incrementAndGet();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class GradientDescentTrainingContext extends TrainingContext {
|
||||||
|
|
||||||
|
public List<Float> correctorTerms;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.implementation.activation;
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||||
|
|
||||||
|
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final GradientDescentTrainingContext context;
|
||||||
|
|
||||||
|
public SquareLossStrategy(GradientDescentTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
|
||||||
|
this.context.globalLoss += context.localLoss;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.loss;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public class SimpleLossStrategy implements AlgorithmStrategy {
|
|
||||||
@Override
|
|
||||||
public void apply(TrainingContext ctx) {
|
|
||||||
ctx.localLoss = Math.abs(ctx.delta);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.loss;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public class SquareLossStrategy implements AlgorithmStrategy {
|
|
||||||
@Override
|
|
||||||
public void apply(TrainingContext ctx) {
|
|
||||||
ctx.localLoss = (float)Math.pow(ctx.delta, 2) / 2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.implementation.activation;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
@@ -1,14 +1,18 @@
|
|||||||
package com.naaturel.ANN.implementation.correction;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
@Override
|
private final SimpleTrainingContext context;
|
||||||
public void apply(TrainingContext context) {
|
|
||||||
if(context.currentLabel.getValue() == context.prediction) return ;
|
|
||||||
|
|
||||||
|
public SimpleCorrectionStrategy(SimpleTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
if(context.currentLabel.getValue() == context.prediction) return ;
|
||||||
context.model.applyOnSynapses(syn -> {
|
context.model.applyOnSynapses(syn -> {
|
||||||
float currentW = syn.getWeight();
|
float currentW = syn.getWeight();
|
||||||
float currentInput = syn.getInput();
|
float currentInput = syn.getInput();
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
|
|
||||||
|
public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final TrainingContext context;
|
||||||
|
|
||||||
|
public SimpleDeltaStrategy(TrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
DataSet dataSet = context.dataset;
|
||||||
|
DataSetEntry entry = context.currentEntry;
|
||||||
|
Label label = dataSet.getLabel(entry);
|
||||||
|
|
||||||
|
context.delta = label.getValue() - context.prediction;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
|
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final SimpleTrainingContext context;
|
||||||
|
|
||||||
|
public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
context.globalLoss += context.localLoss;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
|
public class SimpleLossStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final SimpleTrainingContext context;
|
||||||
|
|
||||||
|
public SimpleLossStrategy(SimpleTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
this.context.localLoss = Math.abs(this.context.delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class SimplePredictionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final TrainingContext context;
|
||||||
|
|
||||||
|
public SimplePredictionStrategy(TrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
List<Float> predictions = context.model.predict(context.currentEntry.getData());
|
||||||
|
context.prediction = predictions.getFirst();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
|
public class SimpleTrainingContext extends TrainingContext {
|
||||||
|
}
|
||||||
@@ -1,16 +1,19 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.correction.GradientDescentCorrectionStrategy;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
||||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||||
import com.naaturel.ANN.implementation.loss.SquareLossStrategy;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class GradientDescentTraining implements Trainer {
|
public class GradientDescentTraining implements Trainer {
|
||||||
@@ -20,25 +23,31 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(Trainable model, DataSet dataset) {
|
public void train(Model model, DataSet dataset) {
|
||||||
TrainingContext context = new TrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.3F;
|
context.learningRate = 0.00011F;
|
||||||
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
new PredictionStep(),
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
new DeltaStep(),
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
new LossStep(new SquareLossStrategy()),
|
new LossStep(new SquareLossStrategy(context)),
|
||||||
new SimpleErrorDetectionStep(),
|
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
|
||||||
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(2))
|
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
|
||||||
);
|
);
|
||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
pipeline
|
pipeline
|
||||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000)
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
.beforeEpoch(ctx -> {
|
||||||
.afterEpoch(ctx -> ())
|
ctx.globalLoss = 0.0F;
|
||||||
|
for (int i = 0; i < model.synCount(); i++){
|
||||||
|
context.correctorTerms.add(0F);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
|
||||||
.withVerbose(true)
|
.withVerbose(true)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
|
||||||
import com.naaturel.ANN.implementation.loss.SimpleLossStrategy;
|
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -19,18 +17,18 @@ public class SimpleTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(Trainable model, DataSet dataset) {
|
public void train(Model model, DataSet dataset) {
|
||||||
TrainingContext context = new TrainingContext();
|
SimpleTrainingContext context = new SimpleTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.3F;
|
context.learningRate = 0.3F;
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
new PredictionStep(),
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
new DeltaStep(),
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
new LossStep(new SimpleLossStrategy()),
|
new LossStep(new SimpleLossStrategy(context)),
|
||||||
new SimpleErrorDetectionStep(),
|
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||||
new WeightCorrectionStep(new SimpleCorrectionStrategy())
|
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||||
);
|
);
|
||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
|
|||||||
@@ -1,19 +1,22 @@
|
|||||||
package com.naaturel.ANN.implementation.training.steps;
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public class DeltaStep implements TrainingStep {
|
public class DeltaStep implements TrainingStep {
|
||||||
|
|
||||||
@Override
|
private final AlgorithmStrategy strategy;
|
||||||
public void run(TrainingContext ctx) {
|
|
||||||
DataSet dataSet = ctx.dataset;
|
|
||||||
DataSetEntry entry = ctx.currentEntry;
|
|
||||||
Label label = dataSet.getLabel(entry);
|
|
||||||
|
|
||||||
ctx.delta = label.getValue() - ctx.prediction;
|
public DeltaStep(AlgorithmStrategy strategy) {
|
||||||
|
this.strategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
this.strategy.apply();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
|
||||||
|
public class ErrorRegistrationStep implements TrainingStep {
|
||||||
|
|
||||||
|
private final AlgorithmStrategy strategy;
|
||||||
|
|
||||||
|
public ErrorRegistrationStep(AlgorithmStrategy strategy) {
|
||||||
|
this.strategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
this.strategy.apply();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
package com.naaturel.ANN.implementation.training.steps;
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public class LossStep implements TrainingStep {
|
public class LossStep implements TrainingStep {
|
||||||
|
|
||||||
|
|
||||||
private final AlgorithmStrategy lossStrategy;
|
private final AlgorithmStrategy lossStrategy;
|
||||||
|
|
||||||
public LossStep(AlgorithmStrategy strategy) {
|
public LossStep(AlgorithmStrategy strategy) {
|
||||||
@@ -13,7 +14,7 @@ public class LossStep implements TrainingStep {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run(TrainingContext ctx) {
|
public void run() {
|
||||||
this.lossStrategy.apply(ctx);
|
this.lossStrategy.apply();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,23 @@
|
|||||||
package com.naaturel.ANN.implementation.training.steps;
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class PredictionStep implements TrainingStep {
|
public class PredictionStep implements TrainingStep {
|
||||||
|
|
||||||
|
private final SimplePredictionStrategy strategy;
|
||||||
|
|
||||||
|
public PredictionStep(SimplePredictionStrategy strategy) {
|
||||||
|
this.strategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run(TrainingContext ctx) {
|
public void run() {
|
||||||
List<Float> predictions = ctx.model.predict(ctx.currentEntry.getData());
|
this.strategy.apply();
|
||||||
ctx.prediction = predictions.getFirst();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.training.steps;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public class SimpleErrorDetectionStep implements TrainingStep {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run(TrainingContext ctx) {
|
|
||||||
ctx.globalLoss += ctx.localLoss;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -2,7 +2,6 @@ package com.naaturel.ANN.implementation.training.steps;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
|
||||||
|
|
||||||
public class WeightCorrectionStep implements TrainingStep {
|
public class WeightCorrectionStep implements TrainingStep {
|
||||||
|
|
||||||
@@ -13,7 +12,7 @@ public class WeightCorrectionStep implements TrainingStep {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run(TrainingContext ctx) {
|
public void run() {
|
||||||
this.correctionStrategy.apply(ctx);
|
this.correctionStrategy.apply();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user