package com.naaturel.ANN; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.infrastructure.config.ConfigDto; import com.naaturel.ANN.infrastructure.config.ConfigLoader; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer; import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot; import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer; import java.util.*; public class Main { public static void main(String[] args) throws Exception { Scanner sc = new Scanner(System.in); ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json"); boolean newModel = config.getModelProperty("new", Boolean.class); int[] modelParameters = config.getModelProperty("parameters", int[].class); String modelPath = config.getModelProperty("path", String.class); int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class); float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue(); String datasetPath = config.getDatasetProperty("path", String.class); System.out.printf("How many classes ? [%d] : ", modelParameters[modelParameters.length-1]); String input = sc.nextLine().trim(); int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input); System.out.println(); DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass); int nbrInput = dataset.getNbrInputs(); ModelSnapshot snapshot; Model network; if(newModel){ network = createNetwork(modelParameters, nbrInput); snapshot = new ModelSnapshot(network); System.out.println("Parameters: " + network.synCount()); Trainer trainer = new GradientBackpropagationTraining(); trainer.train(learningRate, maxEpoch, network, dataset); snapshot.saveToFile(modelPath); } else { snapshot = new ModelSnapshot(); snapshot.loadFromFile(modelPath); network = snapshot.getModel(); } plotGraph(dataset, network); new ModelVisualizer(network) .withWeights(true) .display(); } private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ int neuronId = 0; List layers = new ArrayList<>(); for (int i = 0; i < neuronPerLayer.length; i++){ List neurons = new ArrayList<>(); for (int j = 0; j < neuronPerLayer[i]; j++){ int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1]; List syns = new ArrayList<>(); for (int k=0; k < nbrSyn; k++){ syns.add(new Synapse(new Input(0), new Weight())); } Bias bias = new Bias(new Weight()); Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH()); neurons.add(n); neuronId++; } Layer layer = new Layer(neurons.toArray(new Neuron[0])); layers.add(layer); } return new FullyConnectedNetwork(layers.toArray(new Layer[0])); } private static void plotGraph(DataSet dataset, Model network){ if(dataset.getNbrInputs() != 2) return; GraphVisualizer visualizer = new GraphVisualizer(); float min = -10F; float max = 10F; float step = Math.abs(max - min)/300; //plot labels for (DataSetEntry entry : dataset) { List label = dataset.getLabelsAsFloat(entry); int expectedClass = 0; for (int i = 1; i < label.size(); i++) { if (label.get(i) > label.get(expectedClass)) expectedClass = i; } visualizer.addPoint("Label " + expectedClass, entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); } //plot predictions for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ List predictions = new ArrayList<>(); for (float p : network.predict(new float[]{x, y})) predictions.add(p); float highest = Collections.max(predictions); int predictedClass = predictions.indexOf(highest); visualizer.addPoint("Predict " + predictedClass, x, y); } } visualizer.buildScatterGraph((int)min-1, (int)max+1); } }