Reimplement Adaline

2026-03-26 11:27:10 +01:00
parent c389646794
commit 0d3ab0de8d
11 changed files with 187 additions and 25 deletions

View File

@@ -8,6 +8,7 @@ import com.naaturel.ANN.domain.model.neuron.*;
 import com.naaturel.ANN.implementation.gradientDescent.Linear;
 import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
 import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
+import com.naaturel.ANN.implementation.training.AdalineTraining;
 import com.naaturel.ANN.implementation.training.GradientDescentTraining;
 import com.naaturel.ANN.implementation.training.SimpleTraining;
@@ -21,7 +22,7 @@ public class Main {
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv"); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
DataSet andDataset = new DatasetExtractor() DataSet andDataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
List<Synapse> syns = new ArrayList<>(); List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -29,12 +30,12 @@ public class Main {
         Bias bias = new Bias(new Weight(0));
-        Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
+        Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
         Layer layer = new Layer(List.of(neuron));
         Network network = new Network(List.of(layer));
-        Trainer trainer = new SimpleTraining();
-        trainer.train(network, andDataset);
+        Trainer trainer = new AdalineTraining();
+        trainer.train(network, dataset);
     }
 }
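
The Main changes above are the crux of the reimplementation: the neuron keeps its SimplePerceptron structure but swaps Heaviside for Linear, so the error is measured on the raw weighted sum instead of the thresholded output. A minimal standalone sketch of the Adaline/LMS update this enables (illustrative names, not the framework's API):

    // Standalone sketch of one Adaline/LMS update; names are illustrative.
    public class AdalineSketch {
        public static void main(String[] args) {
            float[] w = {0f, 0f};       // synapse weights (start at 0, as above)
            float b = 0f;               // bias weight
            float lr = 0.003f;          // learning rate AdalineTraining uses
            float[] x = {1f, 1f};       // one dataset entry's inputs
            float target = 1f;          // its expected output

            float y = w[0] * x[0] + w[1] * x[1] + b;  // Linear: no threshold
            float delta = target - y;                 // error on the linear output
            for (int i = 0; i < w.length; i++) {
                w[i] += lr * delta * x[i];            // LMS weight update
            }
            b += lr * delta;            // bias treated as a weight on a fixed input of 1
        }
    }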

View File

@@ -4,6 +4,7 @@ import com.naaturel.ANN.domain.abstraction.TrainingContext;
 import com.naaturel.ANN.domain.abstraction.TrainingStep;
 import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
+import java.sql.Time;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.function.Consumer;
@@ -52,11 +53,23 @@ public class TrainingPipeline {
     }
     public void run(TrainingContext ctx) {
+        long start = this.timeMeasurement ? System.currentTimeMillis() : 0;
         do {
             this.beforeEpoch.accept(ctx);
             this.executeSteps(ctx);
             this.afterEpoch.accept(ctx);
+            if(this.verbose) {
+                System.out.printf("[Global error] : %f\n", ctx.globalLoss);
+            }
         } while (!this.stopCondition.test(ctx));
+        if(this.timeMeasurement) {
+            long end = System.currentTimeMillis();
+            System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
+        }
     }
     private void executeSteps(TrainingContext ctx){
@@ -74,9 +87,6 @@ public class TrainingPipeline {
System.out.printf("loss : %.5f\n", ctx.localLoss); System.out.printf("loss : %.5f\n", ctx.localLoss);
} }
} }
if(this.verbose) {
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
}
ctx.epoch += 1; ctx.epoch += 1;
} }
} }
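
Relocating the [Global error] print from executeSteps into run() changes what gets logged: it now fires after the afterEpoch hook, so trainers that average globalLoss there (as both trainers below do) log the normalized value instead of the raw per-entry sum. The resulting per-epoch control flow, in sketch form mirroring the diff:

    // One iteration of the do/while in run(), after this commit (sketch):
    beforeEpoch.accept(ctx);   // trainers reset ctx.globalLoss to 0 here
    executeSteps(ctx);         // per-entry steps accumulate loss; epoch += 1
    afterEpoch.accept(ctx);    // trainers divide globalLoss by dataset size here
    if (verbose) {             // logged only after that normalization
        System.out.printf("[Global error] : %f\n", ctx.globalLoss);
    }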

View File

@@ -0,0 +1,6 @@
+package com.naaturel.ANN.implementation.adaline;
+import com.naaturel.ANN.domain.abstraction.TrainingContext;
+public class AdalineTrainingContext extends TrainingContext {
+}

View File

@@ -22,5 +22,6 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
             context.correctorTerms.set(i.get(), corrector);
             i.incrementAndGet();
         });
+        context.globalLoss += context.localLoss;
     }
 }

View File

@@ -1,19 +1,19 @@
 package com.naaturel.ANN.implementation.gradientDescent;
 import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
+import com.naaturel.ANN.domain.abstraction.TrainingContext;
 import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
 public class SquareLossStrategy implements AlgorithmStrategy {
-    private final GradientDescentTrainingContext context;
+    private final TrainingContext context;
-    public SquareLossStrategy(GradientDescentTrainingContext context) {
+    public SquareLossStrategy(TrainingContext context) {
         this.context = context;
     }
     @Override
     public void apply() {
         this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
-        this.context.globalLoss += context.localLoss;
     }
 }
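
Two moves cooperate here: SquareLossStrategy is narrowed to computing the per-entry loss only, against the abstract TrainingContext, while the globalLoss accumulation it used to do now lives in GradientDescentErrorStrategy (see the hunk above). That split is what lets AdalineTraining below reuse the same loss step without pulling in gradient-descent bookkeeping. The quantity itself is the usual half squared error:

    // Per-entry loss as computed by SquareLossStrategy: L = delta^2 / 2.
    // The 1/2 cancels when differentiating, so dL/d(delta) = delta and the
    // correction terms carry no stray constant factor.
    float delta = 0.7f;
    float localLoss = (float) Math.pow(delta, 2) / 2;  // 0.245f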

View File

@@ -1,13 +1,14 @@
 package com.naaturel.ANN.implementation.simplePerceptron;
 import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
+import com.naaturel.ANN.domain.abstraction.TrainingContext;
 public class SimpleCorrectionStrategy implements AlgorithmStrategy {
-    private final SimpleTrainingContext context;
+    private final TrainingContext context;
-    public SimpleCorrectionStrategy(SimpleTrainingContext context) {
+    public SimpleCorrectionStrategy(TrainingContext context) {
         this.context = context;
     }

View File

@@ -1,12 +1,13 @@
 package com.naaturel.ANN.implementation.simplePerceptron;
 import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
+import com.naaturel.ANN.domain.abstraction.TrainingContext;
 public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
-    private final SimpleTrainingContext context;
+    private final TrainingContext context;
-    public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) {
+    public SimpleErrorRegistrationStrategy(TrainingContext context) {
         this.context = context;
     }
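
The same widening is applied to SimpleCorrectionStrategy and SimpleErrorRegistrationStrategy: both now accept any TrainingContext, not just SimpleTrainingContext. Since AdalineTrainingContext extends TrainingContext, the Adaline trainer can hand its own context straight to these shared strategies. A sketch of the pattern (assuming the strategies touch only base-class fields such as delta and localLoss):

    // Any TrainingContext subtype can now feed the shared strategies
    // (sketch; assumes they read/write only base-class fields):
    TrainingContext ctx = new AdalineTrainingContext();
    AlgorithmStrategy registration = new SimpleErrorRegistrationStrategy(ctx);
    registration.apply();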

View File

@@ -1,21 +1,65 @@
 package com.naaturel.ANN.implementation.training;
+import com.naaturel.ANN.domain.abstraction.Model;
 import com.naaturel.ANN.domain.abstraction.Neuron;
 import com.naaturel.ANN.domain.abstraction.Trainer;
+import com.naaturel.ANN.domain.abstraction.TrainingStep;
 import com.naaturel.ANN.domain.model.dataset.DataSet;
 import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
 import com.naaturel.ANN.domain.model.neuron.Input;
 import com.naaturel.ANN.domain.model.neuron.Synapse;
 import com.naaturel.ANN.domain.model.neuron.Weight;
+import com.naaturel.ANN.domain.model.training.TrainingPipeline;
+import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
+import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
+import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
+import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
+import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
+import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
+import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
+import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
+import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
+import com.naaturel.ANN.implementation.training.steps.*;
+import java.util.ArrayList;
+import java.util.List;
-/*public class AdalineTraining implements Trainer {
+public class AdalineTraining implements Trainer {
     public AdalineTraining(){
     }
-    public void train(Neuron n, float learningRate, DataSet dataSet) {
+    @Override
+    public void train(Model model, DataSet dataset) {
+        AdalineTrainingContext context = new AdalineTrainingContext();
+        context.dataset = dataset;
+        context.model = model;
+        context.learningRate = 0.003F;
+        List<TrainingStep> steps = List.of(
+            new PredictionStep(new SimplePredictionStrategy(context)),
+            new DeltaStep(new SimpleDeltaStrategy(context)),
+            new LossStep(new SquareLossStrategy(context)),
+            new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
+            new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
+        );
+        new TrainingPipeline(steps)
+            .stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 10000)
+            .beforeEpoch(ctx -> {
+                ctx.globalLoss = 0.0F;
+            })
+            .afterEpoch(ctx -> {
+                ctx.globalLoss /= context.dataset.size();
+            })
+            .withVerbose(true)
+            .withTimeMeasurement(true)
+            .run(context);
+    }
+    /*public void train(Neuron n, float learningRate, DataSet dataSet) {
     int epoch = 1;
     int maxEpoch = 202;
     float errorThreshold = 0.0F;
@@ -76,6 +120,6 @@ import com.naaturel.ANN.domain.model.neuron.Weight;
     private float calculateLoss(float delta){
         return (float) Math.pow(delta, 2)/2;
-    }
     }*/
+}
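
Uncommenting the class and retargeting train() to the Trainer interface's (Model, DataSet) signature makes AdalineTraining a drop-in peer of GradientDescentTraining; the old Neuron-based train() survives only inside the comment block. Invocation is exactly what Main now does:

    // Usage, as wired in Main in this same commit:
    Trainer trainer = new AdalineTraining();
    trainer.train(network, dataset);   // learning rate (0.003F) is fixed inside train()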

View File

@@ -27,6 +27,7 @@ public class GradientDescentTraining implements Trainer {
         GradientDescentTrainingContext context = new GradientDescentTrainingContext();
         context.dataset = dataset;
         context.model = model;
+        context.learningRate = 0.0011F;
         context.correctorTerms = new ArrayList<>();
         List<TrainingStep> steps = List.of(
@@ -37,17 +38,19 @@ public class GradientDescentTraining implements Trainer {
         );
         new TrainingPipeline(steps)
-            .stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
+            .stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 5000)
             .beforeEpoch(ctx -> {
                 GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
                 gdCtx.globalLoss = 0.0F;
                 gdCtx.correctorTerms.clear();
-                for (int i = 0; i < ctx.model.synCount(); i++){
-                    gdCtx.correctorTerms.add(0F);
-                }
-            })
-            .afterEpoch(ctx -> new GradientDescentCorrectionStrategy(context).apply())
+                gdCtx.model.forEachSynapse(syn -> gdCtx.correctorTerms.add(0F));
+            })
+            .afterEpoch(ctx -> {
+                context.globalLoss /= context.dataset.size();
+                new GradientDescentCorrectionStrategy(context).apply();
+            })
             .withVerbose(true)
+            .withTimeMeasurement(true)
             .run(context);
     }
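
Besides the looser epoch cap and the forEachSynapse rewrite of the corrector-term reset, the key change sits in afterEpoch: globalLoss is divided by the dataset size before the correction is applied and before the stop condition runs, so the 0.08F threshold is now a mean per-entry loss rather than a sum. A sketch of what the stop condition compares:

    // globalLoss after one epoch, with per-entry losses delta_k^2 / 2
    // summed during the steps and averaged in afterEpoch (illustrative deltas):
    float[] deltas = {0.3f, -0.2f, 0.4f, 0.1f};
    float globalLoss = 0f;
    for (float d : deltas) globalLoss += d * d / 2;
    globalLoss /= deltas.length;              // mean per-entry loss
    boolean stop = globalLoss <= 0.08F;       // the new threshold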

View File

@@ -0,0 +1,93 @@
+package adaline;
+import com.naaturel.ANN.domain.abstraction.Neuron;
+import com.naaturel.ANN.domain.abstraction.TrainingStep;
+import com.naaturel.ANN.domain.model.dataset.DataSet;
+import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
+import com.naaturel.ANN.domain.model.neuron.*;
+import com.naaturel.ANN.domain.model.training.TrainingPipeline;
+import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
+import com.naaturel.ANN.implementation.gradientDescent.*;
+import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
+import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
+import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
+import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
+import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
+import com.naaturel.ANN.implementation.training.steps.*;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import java.util.ArrayList;
+import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+public class AdalineTest {
+    private DataSet dataset;
+    private AdalineTrainingContext context;
+    private List<Synapse> synapses;
+    private Bias bias;
+    private Network network;
+    private TrainingPipeline pipeline;
+    @BeforeEach
+    public void init(){
+        dataset = new DatasetExtractor()
+            .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
+        List<Synapse> syns = new ArrayList<>();
+        syns.add(new Synapse(new Input(0), new Weight(0)));
+        syns.add(new Synapse(new Input(0), new Weight(0)));
+        bias = new Bias(new Weight(0));
+        Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
+        Layer layer = new Layer(List.of(neuron));
+        network = new Network(List.of(layer));
+        context = new AdalineTrainingContext();
+        context.dataset = dataset;
+        context.model = network;
+        List<TrainingStep> steps = List.of(
+            new PredictionStep(new SimplePredictionStrategy(context)),
+            new DeltaStep(new SimpleDeltaStrategy(context)),
+            new LossStep(new SquareLossStrategy(context)),
+            new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
+            new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
+        );
+        pipeline = new TrainingPipeline(steps)
+            .stopCondition(ctx -> ctx.globalLoss <= 0.1329F || ctx.epoch > 10000)
+            .beforeEpoch(ctx -> {
+                ctx.globalLoss = 0.0F;
+            });
+    }
+    @Test
+    public void test_the_whole_algorithm(){
+        List<Float> expectedGlobalLosses = List.of(
+            0.501522F,
+            0.498601F
+        );
+        context.learningRate = 0.03F;
+        pipeline.afterEpoch(ctx -> {
+            ctx.globalLoss /= context.dataset.size();
+            int index = ctx.epoch-1;
+            if(index >= expectedGlobalLosses.size()) return;
+            //assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
+        });
+        pipeline.run(context);
+        assertEquals(214, context.epoch);
+    }
+}
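
The expected first-epoch losses near 0.5 are consistent with and-gradient.csv encoding the AND targets as +/-1 (an assumption; the file's contents aren't shown). With every weight and the bias at zero, the Linear neuron predicts 0 for each entry, so each delta is +/-1 and each per-entry loss is 0.5; because weights shift between entries within the epoch, the observed average lands slightly off at 0.501522F. Worked out:

    // First-epoch loss with zero weights, assuming +/-1 AND targets
    // (hypothetical encoding; and-gradient.csv is not shown here):
    float[] targets = {-1f, -1f, -1f, 1f};
    float sum = 0f;
    for (float t : targets) {
        float delta = t - 0f;        // prediction is 0 with zero weights
        sum += delta * delta / 2;    // 0.5 per entry
    }
    float firstLoss = sum / targets.length;  // 0.5, vs. 0.501522F observed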

View File

@@ -92,7 +92,9 @@ public class GradientDescentTest {
             assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
         });
-        pipeline.run(context);
+        pipeline
+            .withVerbose(true)
+            .run(context);
         assertEquals(67, context.epoch);
     }
 }