Refactor handling of model creation
This commit is contained in:
@@ -40,22 +40,15 @@ public class Main {
|
|||||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
ModelSnapshot snapshot;
|
ModelSnapshot snapshot = new ModelSnapshot();
|
||||||
|
|
||||||
Model network;
|
if(newModel) {
|
||||||
if(newModel){
|
Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput);
|
||||||
network = createNetwork(modelParameters, nbrInput);
|
trainer.train(learningRate, maxEpoch, dataset);
|
||||||
snapshot = new ModelSnapshot(network);
|
trainer.saveModel(snapshot, modelPath);
|
||||||
System.out.println("Parameters: " + network.synCount());
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
|
||||||
trainer.train(learningRate, maxEpoch, network, dataset);
|
|
||||||
snapshot.saveToFile(modelPath);
|
|
||||||
} else {
|
|
||||||
snapshot = new ModelSnapshot();
|
|
||||||
snapshot.loadFromFile(modelPath);
|
|
||||||
network = snapshot.getModel();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Model network = snapshot.loadFromFile(modelPath);
|
||||||
plotGraph(dataset, network);
|
plotGraph(dataset, network);
|
||||||
|
|
||||||
new ModelVisualizer(network)
|
new ModelVisualizer(network)
|
||||||
@@ -63,33 +56,7 @@ public class Main {
|
|||||||
.display();
|
.display();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
|
||||||
int neuronId = 0;
|
|
||||||
List<Layer> layers = new ArrayList<>();
|
|
||||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
|
||||||
|
|
||||||
List<Neuron> neurons = new ArrayList<>();
|
|
||||||
for (int j = 0; j < neuronPerLayer[i]; j++){
|
|
||||||
|
|
||||||
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
|
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
|
||||||
for (int k=0; k < nbrSyn; k++){
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight()));
|
|
||||||
}
|
|
||||||
|
|
||||||
Bias bias = new Bias(new Weight());
|
|
||||||
|
|
||||||
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
|
||||||
neurons.add(n);
|
|
||||||
neuronId++;
|
|
||||||
}
|
|
||||||
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
|
||||||
layers.add(layer);
|
|
||||||
}
|
|
||||||
|
|
||||||
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void plotGraph(DataSet dataset, Model network){
|
private static void plotGraph(DataSet dataset, Model network){
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||||
|
|
||||||
public interface Trainer {
|
public interface Trainer {
|
||||||
void train(float learningRate, int epoch, Model model, DataSet dataset);
|
void train(float learningRate, int epoch, DataSet dataset);
|
||||||
|
void saveModel(ModelSnapshot snapshot, String path);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.helpers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
|
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class ModelCreator {
|
||||||
|
|
||||||
|
public static Model createModel(int[] neuronPerLayer, int nbrInput){
|
||||||
|
int neuronId = 0;
|
||||||
|
List<Layer> layers = new ArrayList<>();
|
||||||
|
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||||
|
|
||||||
|
List<Neuron> neurons = new ArrayList<>();
|
||||||
|
for (int j = 0; j < neuronPerLayer[i]; j++){
|
||||||
|
|
||||||
|
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
|
||||||
|
|
||||||
|
List<Synapse> syns = new ArrayList<>();
|
||||||
|
for (int k=0; k < nbrSyn; k++){
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Bias bias = new Bias(new Weight());
|
||||||
|
|
||||||
|
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
||||||
|
neurons.add(n);
|
||||||
|
neuronId++;
|
||||||
|
}
|
||||||
|
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
||||||
|
layers.add(layer);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
|
|||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||||
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
|||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||||
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -18,12 +20,14 @@ import java.util.List;
|
|||||||
|
|
||||||
public class AdalineTraining implements Trainer {
|
public class AdalineTraining implements Trainer {
|
||||||
|
|
||||||
public AdalineTraining(){
|
private Model model;
|
||||||
|
|
||||||
|
public AdalineTraining(int[] neurons, int nbrInputs){
|
||||||
|
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, DataSet dataset) {
|
||||||
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
|
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
|
||||||
context.learningRate = learningRate;
|
context.learningRate = learningRate;
|
||||||
|
|
||||||
@@ -45,6 +49,15 @@ public class AdalineTraining implements Trainer {
|
|||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void saveModel(ModelSnapshot snapshot, String path) {
|
||||||
|
try {
|
||||||
|
snapshot.saveToFile(model, path);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int maxEpoch = 202;
|
int maxEpoch = 202;
|
||||||
|
|||||||
@@ -3,19 +3,30 @@ package com.naaturel.ANN.implementation.training;
|
|||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
|
import com.naaturel.ANN.implementation.multiLayers.*;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
|
|
||||||
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
|
|
||||||
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class GradientBackpropagationTraining implements Trainer {
|
public class GradientBackpropagationTraining implements Trainer {
|
||||||
|
|
||||||
|
private Model model;
|
||||||
|
|
||||||
|
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
|
||||||
|
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Model getModel() {
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, DataSet dataset) {
|
||||||
GradientBackpropagationContext context =
|
GradientBackpropagationContext context =
|
||||||
new GradientBackpropagationContext(model, dataset, learningRate, 10);
|
new GradientBackpropagationContext(model, dataset, learningRate, 10);
|
||||||
|
|
||||||
@@ -39,4 +50,13 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void saveModel(ModelSnapshot snapshot, String path) {
|
||||||
|
try {
|
||||||
|
snapshot.saveToFile(model, path);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
|
|||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||||
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrection
|
|||||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||||
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -18,12 +20,14 @@ import java.util.List;
|
|||||||
|
|
||||||
public class GradientDescentTraining implements Trainer {
|
public class GradientDescentTraining implements Trainer {
|
||||||
|
|
||||||
public GradientDescentTraining(){
|
private Model model;
|
||||||
|
|
||||||
|
public GradientDescentTraining(int[] neurons, int nbrInputs){
|
||||||
|
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, DataSet dataset) {
|
||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
|
||||||
context.learningRate = learningRate;
|
context.learningRate = learningRate;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
@@ -55,6 +59,15 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void saveModel(ModelSnapshot snapshot, String path) {
|
||||||
|
try {
|
||||||
|
snapshot.saveToFile(model, path);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int maxEpoch = 402;
|
int maxEpoch = 402;
|
||||||
|
|||||||
@@ -3,20 +3,24 @@ package com.naaturel.ANN.implementation.training;
|
|||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class SimpleTraining implements Trainer {
|
public class SimpleTraining implements Trainer {
|
||||||
|
|
||||||
public SimpleTraining() {
|
private Model model;
|
||||||
|
|
||||||
|
public SimpleTraining(int[] neurons, int nbrInputs){
|
||||||
|
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, DataSet dataset) {
|
||||||
SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
|
SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
@@ -38,6 +42,15 @@ public class SimpleTraining implements Trainer {
|
|||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void saveModel(ModelSnapshot snapshot, String path) {
|
||||||
|
try {
|
||||||
|
snapshot.saveToFile(model, path);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int errorCount;
|
int errorCount;
|
||||||
|
|||||||
@@ -15,23 +15,13 @@ import java.util.Map;
|
|||||||
|
|
||||||
public class ModelSnapshot {
|
public class ModelSnapshot {
|
||||||
|
|
||||||
private Model model;
|
|
||||||
private final ObjectMapper mapper;
|
private final ObjectMapper mapper;
|
||||||
|
|
||||||
public ModelSnapshot(){
|
public ModelSnapshot(){
|
||||||
this(null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ModelSnapshot(Model model){
|
|
||||||
this.model = model;
|
|
||||||
mapper = new ObjectMapper();
|
mapper = new ObjectMapper();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Model getModel() {
|
public void saveToFile(Model model, String path) throws Exception {
|
||||||
return model;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void saveToFile(String path) throws Exception {
|
|
||||||
|
|
||||||
ArrayNode root = mapper.createArrayNode();
|
ArrayNode root = mapper.createArrayNode();
|
||||||
model.forEachNeuron(n -> {
|
model.forEachNeuron(n -> {
|
||||||
@@ -52,7 +42,7 @@ public class ModelSnapshot {
|
|||||||
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
|
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void loadFromFile(String path) throws Exception {
|
public Model loadFromFile(String path) throws Exception {
|
||||||
ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
|
ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
|
||||||
|
|
||||||
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
|
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
|
||||||
@@ -76,6 +66,6 @@ public class ModelSnapshot {
|
|||||||
.map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
|
.map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
|
||||||
.toArray(Layer[]::new);
|
.toArray(Layer[]::new);
|
||||||
|
|
||||||
this.model = new FullyConnectedNetwork(layers);
|
return new FullyConnectedNetwork(layers);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user