Refactor handling of model creation
This commit is contained in:
@@ -40,22 +40,15 @@ public class Main {
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
ModelSnapshot snapshot;
|
||||
ModelSnapshot snapshot = new ModelSnapshot();
|
||||
|
||||
Model network;
|
||||
if(newModel) {
|
||||
network = createNetwork(modelParameters, nbrInput);
|
||||
snapshot = new ModelSnapshot(network);
|
||||
System.out.println("Parameters: " + network.synCount());
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(learningRate, maxEpoch, network, dataset);
|
||||
snapshot.saveToFile(modelPath);
|
||||
} else {
|
||||
snapshot = new ModelSnapshot();
|
||||
snapshot.loadFromFile(modelPath);
|
||||
network = snapshot.getModel();
|
||||
Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput);
|
||||
trainer.train(learningRate, maxEpoch, dataset);
|
||||
trainer.saveModel(snapshot, modelPath);
|
||||
}
|
||||
|
||||
Model network = snapshot.loadFromFile(modelPath);
|
||||
plotGraph(dataset, network);
|
||||
|
||||
new ModelVisualizer(network)
|
||||
@@ -63,33 +56,7 @@ public class Main {
|
||||
.display();
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
int neuronId = 0;
|
||||
List<Layer> layers = new ArrayList<>();
|
||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||
|
||||
List<Neuron> neurons = new ArrayList<>();
|
||||
for (int j = 0; j < neuronPerLayer[i]; j++){
|
||||
|
||||
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
for (int k=0; k < nbrSyn; k++){
|
||||
syns.add(new Synapse(new Input(0), new Weight()));
|
||||
}
|
||||
|
||||
Bias bias = new Bias(new Weight());
|
||||
|
||||
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
||||
neurons.add(n);
|
||||
neuronId++;
|
||||
}
|
||||
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
||||
layers.add(layer);
|
||||
}
|
||||
|
||||
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||
}
|
||||
|
||||
private static void plotGraph(DataSet dataset, Model network){
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
|
||||
public interface Trainer {
|
||||
void train(float learningRate, int epoch, Model model, DataSet dataset);
|
||||
void train(float learningRate, int epoch, DataSet dataset);
|
||||
void saveModel(ModelSnapshot snapshot, String path);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.naaturel.ANN.domain.model.helpers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ModelCreator {
|
||||
|
||||
public static Model createModel(int[] neuronPerLayer, int nbrInput){
|
||||
int neuronId = 0;
|
||||
List<Layer> layers = new ArrayList<>();
|
||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||
|
||||
List<Neuron> neurons = new ArrayList<>();
|
||||
for (int j = 0; j < neuronPerLayer[i]; j++){
|
||||
|
||||
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
for (int k=0; k < nbrSyn; k++){
|
||||
syns.add(new Synapse(new Input(0), new Weight()));
|
||||
}
|
||||
|
||||
Bias bias = new Bias(new Weight());
|
||||
|
||||
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
||||
neurons.add(n);
|
||||
neuronId++;
|
||||
}
|
||||
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
||||
layers.add(layer);
|
||||
}
|
||||
|
||||
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
|
||||
import java.util.List;
|
||||
@@ -18,12 +20,14 @@ import java.util.List;
|
||||
|
||||
public class AdalineTraining implements Trainer {
|
||||
|
||||
public AdalineTraining(){
|
||||
private Model model;
|
||||
|
||||
public AdalineTraining(int[] neurons, int nbrInputs){
|
||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
public void train(float learningRate, int epoch, DataSet dataset) {
|
||||
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
|
||||
context.learningRate = learningRate;
|
||||
|
||||
@@ -45,6 +49,15 @@ public class AdalineTraining implements Trainer {
|
||||
.run(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveModel(ModelSnapshot snapshot, String path) {
|
||||
try {
|
||||
snapshot.saveToFile(model, path);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int maxEpoch = 202;
|
||||
|
||||
@@ -3,19 +3,30 @@ package com.naaturel.ANN.implementation.training;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
|
||||
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.*;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class GradientBackpropagationTraining implements Trainer {
|
||||
|
||||
private Model model;
|
||||
|
||||
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
|
||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||
}
|
||||
|
||||
public Model getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
public void train(float learningRate, int epoch, DataSet dataset) {
|
||||
GradientBackpropagationContext context =
|
||||
new GradientBackpropagationContext(model, dataset, learningRate, 10);
|
||||
|
||||
@@ -39,4 +50,13 @@ public class GradientBackpropagationTraining implements Trainer {
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveModel(ModelSnapshot snapshot, String path) {
|
||||
try {
|
||||
snapshot.saveToFile(model, path);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrection
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -18,12 +20,14 @@ import java.util.List;
|
||||
|
||||
public class GradientDescentTraining implements Trainer {
|
||||
|
||||
public GradientDescentTraining(){
|
||||
private Model model;
|
||||
|
||||
public GradientDescentTraining(int[] neurons, int nbrInputs){
|
||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
public void train(float learningRate, int epoch, DataSet dataset) {
|
||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
|
||||
context.learningRate = learningRate;
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
@@ -55,6 +59,15 @@ public class GradientDescentTraining implements Trainer {
|
||||
.run(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveModel(ModelSnapshot snapshot, String path) {
|
||||
try {
|
||||
snapshot.saveToFile(model, path);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int maxEpoch = 402;
|
||||
|
||||
@@ -3,20 +3,24 @@ package com.naaturel.ANN.implementation.training;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SimpleTraining implements Trainer {
|
||||
|
||||
public SimpleTraining() {
|
||||
private Model model;
|
||||
|
||||
public SimpleTraining(int[] neurons, int nbrInputs){
|
||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
public void train(float learningRate, int epoch, DataSet dataset) {
|
||||
SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
@@ -38,6 +42,15 @@ public class SimpleTraining implements Trainer {
|
||||
.run(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveModel(ModelSnapshot snapshot, String path) {
|
||||
try {
|
||||
snapshot.saveToFile(model, path);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int errorCount;
|
||||
|
||||
@@ -15,23 +15,13 @@ import java.util.Map;
|
||||
|
||||
public class ModelSnapshot {
|
||||
|
||||
private Model model;
|
||||
private final ObjectMapper mapper;
|
||||
|
||||
public ModelSnapshot(){
|
||||
this(null);
|
||||
}
|
||||
|
||||
public ModelSnapshot(Model model){
|
||||
this.model = model;
|
||||
mapper = new ObjectMapper();
|
||||
}
|
||||
|
||||
public Model getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
public void saveToFile(String path) throws Exception {
|
||||
public void saveToFile(Model model, String path) throws Exception {
|
||||
|
||||
ArrayNode root = mapper.createArrayNode();
|
||||
model.forEachNeuron(n -> {
|
||||
@@ -52,7 +42,7 @@ public class ModelSnapshot {
|
||||
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
|
||||
}
|
||||
|
||||
public void loadFromFile(String path) throws Exception {
|
||||
public Model loadFromFile(String path) throws Exception {
|
||||
ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
|
||||
|
||||
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
|
||||
@@ -76,6 +66,6 @@ public class ModelSnapshot {
|
||||
.map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
|
||||
.toArray(Layer[]::new);
|
||||
|
||||
this.model = new FullyConnectedNetwork(layers);
|
||||
return new FullyConnectedNetwork(layers);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user