132 lines
5.0 KiB
Java
132 lines
5.0 KiB
Java
package com.naaturel.ANN;
|
|
|
|
import com.naaturel.ANN.domain.abstraction.Model;
|
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
|
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
|
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
|
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
|
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
|
import com.naaturel.ANN.domain.model.neuron.*;
|
|
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
|
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
|
|
|
|
import java.util.*;
|
|
|
|
public class Main {
|
|
|
|
public static void main(String[] args) throws Exception {
|
|
|
|
Scanner sc = new Scanner(System.in);
|
|
|
|
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
|
|
|
boolean newModel = config.getModelProperty("new", Boolean.class);
|
|
int[] modelParameters = config.getModelProperty("parameters", int[].class);
|
|
String modelPath = config.getModelProperty("path", String.class);
|
|
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
|
|
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
|
String datasetPath = config.getDatasetProperty("path", String.class);
|
|
|
|
System.out.printf("How many classes ? [%d] : ", modelParameters[modelParameters.length-1]);
|
|
String input = sc.nextLine().trim();
|
|
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
|
|
System.out.println();
|
|
|
|
|
|
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
|
int nbrInput = dataset.getNbrInputs();
|
|
|
|
ModelSnapshot snapshot;
|
|
|
|
Model network;
|
|
if(newModel){
|
|
network = createNetwork(modelParameters, nbrInput);
|
|
snapshot = new ModelSnapshot(network);
|
|
System.out.println("Parameters: " + network.synCount());
|
|
Trainer trainer = new GradientBackpropagationTraining();
|
|
trainer.train(learningRate, maxEpoch, network, dataset);
|
|
snapshot.saveToFile(modelPath);
|
|
} else {
|
|
snapshot = new ModelSnapshot();
|
|
snapshot.loadFromFile(modelPath);
|
|
network = snapshot.getModel();
|
|
}
|
|
|
|
plotGraph(dataset, network);
|
|
|
|
new ModelVisualizer(network)
|
|
.withWeights(true)
|
|
.display();
|
|
}
|
|
|
|
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
|
int neuronId = 0;
|
|
List<Layer> layers = new ArrayList<>();
|
|
for (int i = 0; i < neuronPerLayer.length; i++){
|
|
|
|
List<Neuron> neurons = new ArrayList<>();
|
|
for (int j = 0; j < neuronPerLayer[i]; j++){
|
|
|
|
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
|
|
|
|
List<Synapse> syns = new ArrayList<>();
|
|
for (int k=0; k < nbrSyn; k++){
|
|
syns.add(new Synapse(new Input(0), new Weight()));
|
|
}
|
|
|
|
Bias bias = new Bias(new Weight());
|
|
|
|
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
|
neurons.add(n);
|
|
neuronId++;
|
|
}
|
|
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
|
layers.add(layer);
|
|
}
|
|
|
|
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
|
}
|
|
|
|
private static void plotGraph(DataSet dataset, Model network){
|
|
|
|
if(dataset.getNbrInputs() != 2) return;
|
|
|
|
GraphVisualizer visualizer = new GraphVisualizer();
|
|
float min = -10F;
|
|
float max = 10F;
|
|
float step = Math.abs(max - min)/300;
|
|
|
|
//plot labels
|
|
for (DataSetEntry entry : dataset) {
|
|
List<Float> label = dataset.getLabelsAsFloat(entry);
|
|
int expectedClass = 0;
|
|
for (int i = 1; i < label.size(); i++) {
|
|
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
|
}
|
|
visualizer.addPoint("Label " + expectedClass,
|
|
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
|
}
|
|
|
|
//plot predictions
|
|
for (float x = min; x < max; x+=step){
|
|
for (float y = min; y < max; y+=step){
|
|
List<Float> predictions = new ArrayList<>();
|
|
|
|
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
|
float highest = Collections.max(predictions);
|
|
int predictedClass = predictions.indexOf(highest);
|
|
|
|
visualizer.addPoint("Predict " + predictedClass, x, y);
|
|
}
|
|
}
|
|
|
|
visualizer.buildScatterGraph((int)min-1, (int)max+1);
|
|
}
|
|
|
|
}
|