Reimplement Adaline
This commit is contained in:
@@ -8,6 +8,7 @@ import com.naaturel.ANN.domain.model.neuron.*;
|
|||||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
|
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
||||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||||
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ public class Main {
|
|||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
||||||
|
|
||||||
DataSet andDataset = new DatasetExtractor()
|
DataSet andDataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
List<Synapse> syns = new ArrayList<>();
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
@@ -29,12 +30,12 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight(0));
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
||||||
Layer layer = new Layer(List.of(neuron));
|
Layer layer = new Layer(List.of(neuron));
|
||||||
Network network = new Network(List.of(layer));
|
Network network = new Network(List.of(layer));
|
||||||
|
|
||||||
Trainer trainer = new SimpleTraining();
|
Trainer trainer = new AdalineTraining();
|
||||||
trainer.train(network, andDataset);
|
trainer.train(network, dataset);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
|
||||||
|
import java.sql.Time;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
@@ -52,11 +53,23 @@ public class TrainingPipeline {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void run(TrainingContext ctx) {
|
public void run(TrainingContext ctx) {
|
||||||
|
|
||||||
|
long start = this.timeMeasurement ? System.currentTimeMillis() : 0;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
this.beforeEpoch.accept(ctx);
|
this.beforeEpoch.accept(ctx);
|
||||||
this.executeSteps(ctx);
|
this.executeSteps(ctx);
|
||||||
this.afterEpoch.accept(ctx);
|
this.afterEpoch.accept(ctx);
|
||||||
|
if(this.verbose) {
|
||||||
|
System.out.printf("[Global error] : %f\n", ctx.globalLoss);
|
||||||
|
}
|
||||||
} while (!this.stopCondition.test(ctx));
|
} while (!this.stopCondition.test(ctx));
|
||||||
|
|
||||||
|
if(this.timeMeasurement) {
|
||||||
|
long end = System.currentTimeMillis();
|
||||||
|
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void executeSteps(TrainingContext ctx){
|
private void executeSteps(TrainingContext ctx){
|
||||||
@@ -74,9 +87,6 @@ public class TrainingPipeline {
|
|||||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(this.verbose) {
|
|
||||||
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
|
||||||
}
|
|
||||||
ctx.epoch += 1;
|
ctx.epoch += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
package com.naaturel.ANN.implementation.adaline;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
|
public class AdalineTrainingContext extends TrainingContext {
|
||||||
|
}
|
||||||
@@ -22,5 +22,6 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
|||||||
context.correctorTerms.set(i.get(), corrector);
|
context.correctorTerms.set(i.get(), corrector);
|
||||||
i.incrementAndGet();
|
i.incrementAndGet();
|
||||||
});
|
});
|
||||||
|
context.globalLoss += context.localLoss;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
package com.naaturel.ANN.implementation.gradientDescent;
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||||
|
|
||||||
public class SquareLossStrategy implements AlgorithmStrategy {
|
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
private final GradientDescentTrainingContext context;
|
private final TrainingContext context;
|
||||||
|
|
||||||
public SquareLossStrategy(GradientDescentTrainingContext context) {
|
public SquareLossStrategy(TrainingContext context) {
|
||||||
this.context = context;
|
this.context = context;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
|
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
|
||||||
this.context.globalLoss += context.localLoss;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
|
|
||||||
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
private final SimpleTrainingContext context;
|
private final TrainingContext context;
|
||||||
|
|
||||||
public SimpleCorrectionStrategy(SimpleTrainingContext context) {
|
public SimpleCorrectionStrategy(TrainingContext context) {
|
||||||
this.context = context;
|
this.context = context;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
|
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
private final SimpleTrainingContext context;
|
private final TrainingContext context;
|
||||||
|
|
||||||
public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) {
|
public SimpleErrorRegistrationStrategy(TrainingContext context) {
|
||||||
this.context = context;
|
this.context = context;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,65 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
|
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
/*public class AdalineTraining implements Trainer {
|
public class AdalineTraining implements Trainer {
|
||||||
|
|
||||||
public AdalineTraining(){
|
public AdalineTraining(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
@Override
|
||||||
|
public void train(Model model, DataSet dataset) {
|
||||||
|
AdalineTrainingContext context = new AdalineTrainingContext();
|
||||||
|
context.dataset = dataset;
|
||||||
|
context.model = model;
|
||||||
|
context.learningRate = 0.003F;
|
||||||
|
|
||||||
|
List<TrainingStep> steps = List.of(
|
||||||
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
|
new LossStep(new SquareLossStrategy(context)),
|
||||||
|
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||||
|
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||||
|
);
|
||||||
|
|
||||||
|
new TrainingPipeline(steps)
|
||||||
|
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 10000)
|
||||||
|
.beforeEpoch(ctx -> {
|
||||||
|
ctx.globalLoss = 0.0F;
|
||||||
|
})
|
||||||
|
.afterEpoch(ctx -> {
|
||||||
|
ctx.globalLoss /= context.dataset.size();
|
||||||
|
})
|
||||||
|
.withVerbose(true)
|
||||||
|
.withTimeMeasurement(true)
|
||||||
|
.run(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int maxEpoch = 202;
|
int maxEpoch = 202;
|
||||||
float errorThreshold = 0.0F;
|
float errorThreshold = 0.0F;
|
||||||
@@ -76,6 +120,6 @@ import com.naaturel.ANN.domain.model.neuron.Weight;
|
|||||||
|
|
||||||
private float calculateLoss(float delta){
|
private float calculateLoss(float delta){
|
||||||
return (float) Math.pow(delta, 2)/2;
|
return (float) Math.pow(delta, 2)/2;
|
||||||
}
|
}*/
|
||||||
|
|
||||||
}*/
|
}
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
|
context.learningRate = 0.0011F;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
@@ -37,17 +38,19 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 5000)
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||||
gdCtx.globalLoss = 0.0F;
|
gdCtx.globalLoss = 0.0F;
|
||||||
gdCtx.correctorTerms.clear();
|
gdCtx.correctorTerms.clear();
|
||||||
for (int i = 0; i < ctx.model.synCount(); i++){
|
gdCtx.model.forEachSynapse(syn -> gdCtx.correctorTerms.add(0F));
|
||||||
gdCtx.correctorTerms.add(0F);
|
})
|
||||||
}
|
.afterEpoch(ctx -> {
|
||||||
|
context.globalLoss /= context.dataset.size();
|
||||||
|
new GradientDescentCorrectionStrategy(context).apply();
|
||||||
})
|
})
|
||||||
.afterEpoch(ctx -> new GradientDescentCorrectionStrategy(context).apply())
|
|
||||||
.withVerbose(true)
|
.withVerbose(true)
|
||||||
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
93
src/test/java/adaline/AdalineTest.java
Normal file
93
src/test/java/adaline/AdalineTest.java
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package adaline;
|
||||||
|
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
|
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||||
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
public class AdalineTest {
|
||||||
|
|
||||||
|
private DataSet dataset;
|
||||||
|
private AdalineTrainingContext context;
|
||||||
|
|
||||||
|
private List<Synapse> synapses;
|
||||||
|
private Bias bias;
|
||||||
|
private Network network;
|
||||||
|
|
||||||
|
private TrainingPipeline pipeline;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
public void init(){
|
||||||
|
dataset = new DatasetExtractor()
|
||||||
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
|
||||||
|
|
||||||
|
List<Synapse> syns = new ArrayList<>();
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
|
||||||
|
bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
|
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
||||||
|
Layer layer = new Layer(List.of(neuron));
|
||||||
|
network = new Network(List.of(layer));
|
||||||
|
|
||||||
|
context = new AdalineTrainingContext();
|
||||||
|
context.dataset = dataset;
|
||||||
|
context.model = network;
|
||||||
|
|
||||||
|
List<TrainingStep> steps = List.of(
|
||||||
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
|
new LossStep(new SquareLossStrategy(context)),
|
||||||
|
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||||
|
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||||
|
);
|
||||||
|
|
||||||
|
pipeline = new TrainingPipeline(steps)
|
||||||
|
.stopCondition(ctx -> ctx.globalLoss <= 0.1329F || ctx.epoch > 10000)
|
||||||
|
.beforeEpoch(ctx -> {
|
||||||
|
ctx.globalLoss = 0.0F;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test_the_whole_algorithm(){
|
||||||
|
|
||||||
|
List<Float> expectedGlobalLosses = List.of(
|
||||||
|
0.501522F,
|
||||||
|
0.498601F
|
||||||
|
);
|
||||||
|
|
||||||
|
context.learningRate = 0.03F;
|
||||||
|
pipeline.afterEpoch(ctx -> {
|
||||||
|
ctx.globalLoss /= context.dataset.size();
|
||||||
|
|
||||||
|
int index = ctx.epoch-1;
|
||||||
|
if(index >= expectedGlobalLosses.size()) return;
|
||||||
|
|
||||||
|
//assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
|
||||||
|
});
|
||||||
|
|
||||||
|
pipeline.run(context);
|
||||||
|
assertEquals(214, context.epoch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ public class GradientDescentTest {
|
|||||||
context = new GradientDescentTrainingContext();
|
context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = network;
|
context.model = network;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
@@ -92,7 +92,9 @@ public class GradientDescentTest {
|
|||||||
assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
|
assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
|
||||||
});
|
});
|
||||||
|
|
||||||
pipeline.run(context);
|
pipeline
|
||||||
|
.withVerbose(true)
|
||||||
|
.run(context);
|
||||||
assertEquals(67, context.epoch);
|
assertEquals(67, context.epoch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user