Move dataset components
This commit is contained in:
@@ -2,11 +2,10 @@ package com.naaturel.ANN;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||||
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
|
||||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
@@ -18,25 +17,31 @@ public class Main {
|
|||||||
int nbrInput = 2;
|
int nbrInput = 2;
|
||||||
int nbrClass = 3;
|
int nbrClass = 3;
|
||||||
|
|
||||||
|
int nbrLayers = 1;
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass);
|
||||||
|
|
||||||
List<Neuron> neurons = new ArrayList<>();
|
List<Layer> layers = new ArrayList<>();
|
||||||
|
for(int i = 0; i < nbrLayers; i++){
|
||||||
|
|
||||||
for (int i=0; i < nbrClass; i++){
|
List<Neuron> neurons = new ArrayList<>();
|
||||||
List<Synapse> syns = new ArrayList<>();
|
for (int j=0; j < nbrClass; j++){
|
||||||
for (int j=0; j < nbrInput; j++){
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
List<Synapse> syns = new ArrayList<>();
|
||||||
|
for (int k=0; k < nbrInput; k++){
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
|
Neuron n = new Neuron(syns, bias, new Linear());
|
||||||
|
neurons.add(n);
|
||||||
}
|
}
|
||||||
|
Layer layer = new Layer(neurons);
|
||||||
Bias bias = new Bias(new Weight(0));
|
layers.add(layer);
|
||||||
|
|
||||||
Neuron n = new Neuron(syns, bias, new Linear());
|
|
||||||
neurons.add(n);
|
|
||||||
}
|
}
|
||||||
|
Network network = new Network(layers);
|
||||||
Layer layer = new Layer(neurons);
|
|
||||||
Network network = new Network(List.of(layer));
|
|
||||||
|
|
||||||
Trainer trainer = new GradientDescentTraining();
|
Trainer trainer = new GradientDescentTraining();
|
||||||
trainer.train(network, dataset);
|
trainer.train(network, dataset);
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
public interface Trainer {
|
public interface Trainer {
|
||||||
void train(Model model, DataSet dataset);
|
void train(Model model, DataSet dataset);
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
|||||||
@@ -16,12 +16,11 @@ public class Network implements Model {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Float> predict(List<Input> inputs) {
|
public List<Float> predict(List<Input> inputs) {
|
||||||
List<Float> result = new ArrayList<>();
|
List<Input> currentLayerOutput = new ArrayList<>(inputs);
|
||||||
for(Layer layer : this.layers){
|
for(Layer layer : this.layers){
|
||||||
List<Float> res = layer.predict(inputs);
|
currentLayerOutput = layer.predict(currentLayerOutput).stream().map(Input::new).toList();
|
||||||
result.addAll(res);
|
|
||||||
}
|
}
|
||||||
return result;
|
return currentLayerOutput.stream().map(Input::getValue).toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.naaturel.ANN.domain.model.training;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ package com.naaturel.ANN.implementation.simplePerceptron;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.naaturel.ANN.implementation.training;
|
|||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.naaturel.ANN.implementation.training;
|
|||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
@@ -28,7 +28,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.0005F;
|
context.learningRate = 0.0008F;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
@@ -39,7 +39,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 50000)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 150)
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||||
gdCtx.globalLoss = 0.0F;
|
gdCtx.globalLoss = 0.0F;
|
||||||
@@ -50,9 +50,9 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
context.globalLoss /= context.dataset.size();
|
context.globalLoss /= context.dataset.size();
|
||||||
new GradientDescentCorrectionStrategy(context).apply();
|
new GradientDescentCorrectionStrategy(context).apply();
|
||||||
})
|
})
|
||||||
//.withVerbose(true)
|
.withVerbose(true)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.withVisualization(false, new GraphVisualizer())
|
.withVisualization(true, new GraphVisualizer())
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.naaturel.ANN.implementation.training;
|
|||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.domain.model.dataset;
|
package com.naaturel.ANN.infrastructure.dataset;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.domain.model.dataset;
|
package com.naaturel.ANN.infrastructure.dataset;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.domain.model.dataset;
|
package com.naaturel.ANN.infrastructure.dataset;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.domain.model.dataset;
|
package com.naaturel.ANN.infrastructure.dataset;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -3,8 +3,8 @@ package adaline;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package gradientDescent;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package perceptron;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
|
|||||||
Reference in New Issue
Block a user