Start to add test coverage

This commit is contained in:
2026-03-25 22:36:26 +01:00
parent 65d3a0e3e4
commit 76465ab6ee
16 changed files with 112 additions and 28 deletions

View File

@@ -20,8 +20,8 @@ public class Main {
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv"); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
DataSet orDataset = new DatasetExtractor() DataSet andDataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv"); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
List<Synapse> syns = new ArrayList<>(); List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -29,12 +29,12 @@ public class Main {
Bias bias = new Bias(new Weight(0)); Bias bias = new Bias(new Weight(0));
Neuron neuron = new SimplePerceptron(syns, bias, new Linear()); Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
Layer layer = new Layer(List.of(neuron)); Layer layer = new Layer(List.of(neuron));
Network network = new Network(List.of(layer)); Network network = new Network(List.of(layer));
Trainer trainer = new GradientDescentTraining(); Trainer trainer = new SimpleTraining();
trainer.train(network, dataset); trainer.train(network, andDataset);
} }
} }

View File

@@ -1,5 +1,6 @@
package com.naaturel.ANN.domain.abstraction; package com.naaturel.ANN.domain.abstraction;
@FunctionalInterface
public interface AlgorithmStrategy { public interface AlgorithmStrategy {
void apply(); void apply();

View File

@@ -8,7 +8,7 @@ import java.util.function.Consumer;
public interface Model { public interface Model {
int synCount(); int synCount();
void applyOnSynapses(Consumer<Synapse> consumer); void forEachSynapse(Consumer<Synapse> consumer);
List<Float> predict(List<Input> inputs); List<Float> predict(List<Input> inputs);
} }

View File

@@ -9,7 +9,7 @@ public class DataSet implements Iterable<DataSetEntry>{
private Map<DataSetEntry, Label> data; private Map<DataSetEntry, Label> data;
public DataSet() { public DataSet() {
this(new HashMap<>()); this(new LinkedHashMap<>());
} }
public DataSet(Map<DataSetEntry, Label> data){ public DataSet(Map<DataSetEntry, Label> data){

View File

@@ -5,15 +5,12 @@ import com.naaturel.ANN.domain.model.neuron.Input;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.FileReader; import java.io.FileReader;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class DatasetExtractor { public class DatasetExtractor {
public DataSet extract(String path) { public DataSet extract(String path) {
Map<DataSetEntry, Label> data = new HashMap<>(); Map<DataSetEntry, Label> data = new LinkedHashMap<>();
try (BufferedReader reader = new BufferedReader(new FileReader(path))) { try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
String line; String line;

View File

@@ -35,7 +35,7 @@ public class Layer implements Model {
} }
@Override @Override
public void applyOnSynapses(Consumer<Synapse> consumer) { public void forEachSynapse(Consumer<Synapse> consumer) {
this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer)); this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
} }
} }

View File

@@ -34,7 +34,7 @@ public class Network implements Model {
} }
@Override @Override
public void applyOnSynapses(Consumer<Synapse> consumer) { public void forEachSynapse(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.applyOnSynapses(consumer)); this.layers.forEach(layer -> layer.forEachSynapse(consumer));
} }
} }

View File

@@ -56,9 +56,6 @@ public class TrainingPipeline {
this.beforeEpoch.accept(ctx); this.beforeEpoch.accept(ctx);
this.executeSteps(ctx); this.executeSteps(ctx);
this.afterEpoch.accept(ctx); this.afterEpoch.accept(ctx);
if(this.verbose) {
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
}
} while (!this.stopCondition.test(ctx)); } while (!this.stopCondition.test(ctx));
} }
@@ -77,6 +74,9 @@ public class TrainingPipeline {
System.out.printf("loss : %.5f\n", ctx.localLoss); System.out.printf("loss : %.5f\n", ctx.localLoss);
} }
} }
if(this.verbose) {
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
}
ctx.epoch += 1; ctx.epoch += 1;
} }
} }

View File

@@ -15,7 +15,7 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
@Override @Override
public void apply() { public void apply() {
AtomicInteger i = new AtomicInteger(0); AtomicInteger i = new AtomicInteger(0);
context.model.applyOnSynapses(syn -> { context.model.forEachSynapse(syn -> {
float corrector = context.correctorTerms.get(i.get()); float corrector = context.correctorTerms.get(i.get());
float c = syn.getWeight() + corrector; float c = syn.getWeight() + corrector;
syn.setWeight(c); syn.setWeight(c);

View File

@@ -16,7 +16,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
@Override @Override
public void apply() { public void apply() {
AtomicInteger i = new AtomicInteger(0); AtomicInteger i = new AtomicInteger(0);
context.model.applyOnSynapses(syn -> { context.model.forEachSynapse(syn -> {
float corrector = context.correctorTerms.get(i.get()); float corrector = context.correctorTerms.get(i.get());
corrector += context.learningRate * context.delta * syn.getInput(); corrector += context.learningRate * context.delta * syn.getInput();
context.correctorTerms.set(i.get(), corrector); context.correctorTerms.set(i.get(), corrector);

View File

@@ -22,7 +22,7 @@ public class SimplePerceptron extends Neuron {
} }
@Override @Override
public void applyOnSynapses(Consumer<Synapse> consumer) { public void forEachSynapse(Consumer<Synapse> consumer) {
consumer.accept(this.bias); consumer.accept(this.bias);
this.synapses.forEach(consumer); this.synapses.forEach(consumer);
} }
@@ -30,10 +30,10 @@ public class SimplePerceptron extends Neuron {
@Override @Override
public float calculateWeightedSum() { public float calculateWeightedSum() {
float res = 0; float res = 0;
res += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : super.synapses){ for(Synapse syn : super.synapses){
res += syn.getWeight() * syn.getInput(); res += syn.getWeight() * syn.getInput();
} }
res += this.bias.getWeight() * this.bias.getInput();
return res; return res;
} }

View File

@@ -12,6 +12,6 @@ public class Heaviside implements ActivationFunction {
@Override @Override
public float accept(Neuron n) { public float accept(Neuron n) {
float weightedSum = n.calculateWeightedSum(); float weightedSum = n.calculateWeightedSum();
return weightedSum <= 0 ? 0:1; return weightedSum < 0 ? 0:1;
} }
} }

View File

@@ -2,6 +2,7 @@ package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
public class SimpleCorrectionStrategy implements AlgorithmStrategy { public class SimpleCorrectionStrategy implements AlgorithmStrategy {
private final SimpleTrainingContext context; private final SimpleTrainingContext context;
@@ -13,7 +14,7 @@ public class SimpleCorrectionStrategy implements AlgorithmStrategy {
@Override @Override
public void apply() { public void apply() {
if(context.currentLabel.getValue() == context.prediction) return ; if(context.currentLabel.getValue() == context.prediction) return ;
context.model.applyOnSynapses(syn -> { context.model.forEachSynapse(syn -> {
float currentW = syn.getWeight(); float currentW = syn.getWeight();
float currentInput = syn.getInput(); float currentInput = syn.getInput();
float newValue = currentW + (context.learningRate * context.delta * currentInput); float newValue = currentW + (context.learningRate * context.delta * currentInput);

View File

@@ -27,7 +27,7 @@ public class GradientDescentTraining implements Trainer {
GradientDescentTrainingContext context = new GradientDescentTrainingContext(); GradientDescentTrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset; context.dataset = dataset;
context.model = model; context.model = model;
context.learningRate = 0.00011F; context.learningRate = 0.2F;
context.correctorTerms = new ArrayList<>(); context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of( List<TrainingStep> steps = List.of(
@@ -40,7 +40,7 @@ public class GradientDescentTraining implements Trainer {
TrainingPipeline pipeline = new TrainingPipeline(steps); TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000) .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 50)
.beforeEpoch(ctx -> { .beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F; ctx.globalLoss = 0.0F;
for (int i = 0; i < model.synCount(); i++){ for (int i = 0; i < model.synCount(); i++){

View File

@@ -33,7 +33,7 @@ public class SimpleTraining implements Trainer {
TrainingPipeline pipeline = new TrainingPipeline(steps); TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100) .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 10)
.beforeEpoch(ctx -> ctx.globalLoss = 0) .beforeEpoch(ctx -> ctx.globalLoss = 0)
.withVerbose(true) .withVerbose(true)
.run(context); .run(context);

View File

@@ -0,0 +1,85 @@
package perceptron;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
public class simplePerceptronTest {
private DataSet dataset;
private SimpleTrainingContext context;
private List<Synapse> synapses;
private Bias bias;
private Network network;
private TrainingPipeline pipeline;
@BeforeEach
public void init(){
dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
bias = new Bias(new Weight(0));
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer));
context = new SimpleTrainingContext();
context.dataset = dataset;
context.model = network;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SimpleLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
);
pipeline = new TrainingPipeline(steps);
pipeline.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100);
pipeline.beforeEpoch(ctx -> ctx.globalLoss = 0);
}
@Test
public void test_the_whole_algorithm(){
List<Float> expectedGlobalLosses = List.of(
2.0F,
3.0F,
3.0F,
2.0F,
1.0F,
0.0F
);
context.learningRate = 1F;
pipeline.afterEpoch(ctx -> {
int index = ctx.epoch-1;
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
});
pipeline.run(context);
assertEquals(6, context.epoch);
}
}