Start to add test coverage
This commit is contained in:
@@ -20,8 +20,8 @@ public class Main {
|
|||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
||||||
|
|
||||||
DataSet orDataset = new DatasetExtractor()
|
DataSet andDataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
List<Synapse> syns = new ArrayList<>();
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
@@ -29,12 +29,12 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight(0));
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
||||||
Layer layer = new Layer(List.of(neuron));
|
Layer layer = new Layer(List.of(neuron));
|
||||||
Network network = new Network(List.of(layer));
|
Network network = new Network(List.of(layer));
|
||||||
|
|
||||||
Trainer trainer = new GradientDescentTraining();
|
Trainer trainer = new SimpleTraining();
|
||||||
trainer.train(network, dataset);
|
trainer.train(network, andDataset);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
|
@FunctionalInterface
|
||||||
public interface AlgorithmStrategy {
|
public interface AlgorithmStrategy {
|
||||||
|
|
||||||
void apply();
|
void apply();
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import java.util.function.Consumer;
|
|||||||
|
|
||||||
public interface Model {
|
public interface Model {
|
||||||
int synCount();
|
int synCount();
|
||||||
void applyOnSynapses(Consumer<Synapse> consumer);
|
void forEachSynapse(Consumer<Synapse> consumer);
|
||||||
List<Float> predict(List<Input> inputs);
|
List<Float> predict(List<Input> inputs);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ public class DataSet implements Iterable<DataSetEntry>{
|
|||||||
private Map<DataSetEntry, Label> data;
|
private Map<DataSetEntry, Label> data;
|
||||||
|
|
||||||
public DataSet() {
|
public DataSet() {
|
||||||
this(new HashMap<>());
|
this(new LinkedHashMap<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
public DataSet(Map<DataSetEntry, Label> data){
|
public DataSet(Map<DataSetEntry, Label> data){
|
||||||
|
|||||||
@@ -5,15 +5,12 @@ import com.naaturel.ANN.domain.model.neuron.Input;
|
|||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.FileReader;
|
import java.io.FileReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public class DatasetExtractor {
|
public class DatasetExtractor {
|
||||||
|
|
||||||
public DataSet extract(String path) {
|
public DataSet extract(String path) {
|
||||||
Map<DataSetEntry, Label> data = new HashMap<>();
|
Map<DataSetEntry, Label> data = new LinkedHashMap<>();
|
||||||
|
|
||||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||||
String line;
|
String line;
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ public class Layer implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer));
|
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ public class Network implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
this.layers.forEach(layer -> layer.applyOnSynapses(consumer));
|
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,9 +56,6 @@ public class TrainingPipeline {
|
|||||||
this.beforeEpoch.accept(ctx);
|
this.beforeEpoch.accept(ctx);
|
||||||
this.executeSteps(ctx);
|
this.executeSteps(ctx);
|
||||||
this.afterEpoch.accept(ctx);
|
this.afterEpoch.accept(ctx);
|
||||||
if(this.verbose) {
|
|
||||||
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
|
||||||
}
|
|
||||||
} while (!this.stopCondition.test(ctx));
|
} while (!this.stopCondition.test(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +74,9 @@ public class TrainingPipeline {
|
|||||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if(this.verbose) {
|
||||||
|
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
||||||
|
}
|
||||||
ctx.epoch += 1;
|
ctx.epoch += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
|||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
AtomicInteger i = new AtomicInteger(0);
|
AtomicInteger i = new AtomicInteger(0);
|
||||||
context.model.applyOnSynapses(syn -> {
|
context.model.forEachSynapse(syn -> {
|
||||||
float corrector = context.correctorTerms.get(i.get());
|
float corrector = context.correctorTerms.get(i.get());
|
||||||
float c = syn.getWeight() + corrector;
|
float c = syn.getWeight() + corrector;
|
||||||
syn.setWeight(c);
|
syn.setWeight(c);
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
|||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
AtomicInteger i = new AtomicInteger(0);
|
AtomicInteger i = new AtomicInteger(0);
|
||||||
context.model.applyOnSynapses(syn -> {
|
context.model.forEachSynapse(syn -> {
|
||||||
float corrector = context.correctorTerms.get(i.get());
|
float corrector = context.correctorTerms.get(i.get());
|
||||||
corrector += context.learningRate * context.delta * syn.getInput();
|
corrector += context.learningRate * context.delta * syn.getInput();
|
||||||
context.correctorTerms.set(i.get(), corrector);
|
context.correctorTerms.set(i.get(), corrector);
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ public class SimplePerceptron extends Neuron {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
consumer.accept(this.bias);
|
consumer.accept(this.bias);
|
||||||
this.synapses.forEach(consumer);
|
this.synapses.forEach(consumer);
|
||||||
}
|
}
|
||||||
@@ -30,10 +30,10 @@ public class SimplePerceptron extends Neuron {
|
|||||||
@Override
|
@Override
|
||||||
public float calculateWeightedSum() {
|
public float calculateWeightedSum() {
|
||||||
float res = 0;
|
float res = 0;
|
||||||
|
res += this.bias.getWeight() * this.bias.getInput();
|
||||||
for(Synapse syn : super.synapses){
|
for(Synapse syn : super.synapses){
|
||||||
res += syn.getWeight() * syn.getInput();
|
res += syn.getWeight() * syn.getInput();
|
||||||
}
|
}
|
||||||
res += this.bias.getWeight() * this.bias.getInput();
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,6 @@ public class Heaviside implements ActivationFunction {
|
|||||||
@Override
|
@Override
|
||||||
public float accept(Neuron n) {
|
public float accept(Neuron n) {
|
||||||
float weightedSum = n.calculateWeightedSum();
|
float weightedSum = n.calculateWeightedSum();
|
||||||
return weightedSum <= 0 ? 0:1;
|
return weightedSum < 0 ? 0:1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.naaturel.ANN.implementation.simplePerceptron;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
|
|
||||||
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
private final SimpleTrainingContext context;
|
private final SimpleTrainingContext context;
|
||||||
@@ -13,7 +14,7 @@ public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
|||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
if(context.currentLabel.getValue() == context.prediction) return ;
|
if(context.currentLabel.getValue() == context.prediction) return ;
|
||||||
context.model.applyOnSynapses(syn -> {
|
context.model.forEachSynapse(syn -> {
|
||||||
float currentW = syn.getWeight();
|
float currentW = syn.getWeight();
|
||||||
float currentInput = syn.getInput();
|
float currentInput = syn.getInput();
|
||||||
float newValue = currentW + (context.learningRate * context.delta * currentInput);
|
float newValue = currentW + (context.learningRate * context.delta * currentInput);
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.00011F;
|
context.learningRate = 0.2F;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
@@ -40,7 +40,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
pipeline
|
pipeline
|
||||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000)
|
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 50)
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
ctx.globalLoss = 0.0F;
|
ctx.globalLoss = 0.0F;
|
||||||
for (int i = 0; i < model.synCount(); i++){
|
for (int i = 0; i < model.synCount(); i++){
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ public class SimpleTraining implements Trainer {
|
|||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
pipeline
|
pipeline
|
||||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 10)
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||||
.withVerbose(true)
|
.withVerbose(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
|
|||||||
85
src/test/java/perceptron/simplePerceptronTest.java
Normal file
85
src/test/java/perceptron/simplePerceptronTest.java
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package perceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
|
||||||
|
public class simplePerceptronTest {
|
||||||
|
|
||||||
|
private DataSet dataset;
|
||||||
|
private SimpleTrainingContext context;
|
||||||
|
|
||||||
|
private List<Synapse> synapses;
|
||||||
|
private Bias bias;
|
||||||
|
private Network network;
|
||||||
|
|
||||||
|
private TrainingPipeline pipeline;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
public void init(){
|
||||||
|
dataset = new DatasetExtractor()
|
||||||
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
|
||||||
|
|
||||||
|
List<Synapse> syns = new ArrayList<>();
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
|
||||||
|
bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
|
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
||||||
|
Layer layer = new Layer(List.of(neuron));
|
||||||
|
network = new Network(List.of(layer));
|
||||||
|
|
||||||
|
context = new SimpleTrainingContext();
|
||||||
|
context.dataset = dataset;
|
||||||
|
context.model = network;
|
||||||
|
|
||||||
|
List<TrainingStep> steps = List.of(
|
||||||
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
|
new LossStep(new SimpleLossStrategy(context)),
|
||||||
|
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||||
|
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||||
|
);
|
||||||
|
|
||||||
|
pipeline = new TrainingPipeline(steps);
|
||||||
|
pipeline.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100);
|
||||||
|
pipeline.beforeEpoch(ctx -> ctx.globalLoss = 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test_the_whole_algorithm(){
|
||||||
|
|
||||||
|
List<Float> expectedGlobalLosses = List.of(
|
||||||
|
2.0F,
|
||||||
|
3.0F,
|
||||||
|
3.0F,
|
||||||
|
2.0F,
|
||||||
|
1.0F,
|
||||||
|
0.0F
|
||||||
|
);
|
||||||
|
|
||||||
|
context.learningRate = 1F;
|
||||||
|
pipeline.afterEpoch(ctx -> {
|
||||||
|
int index = ctx.epoch-1;
|
||||||
|
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
|
||||||
|
});
|
||||||
|
|
||||||
|
pipeline.run(context);
|
||||||
|
assertEquals(6, context.epoch);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user