Files
ANN-framework/src/main/java/com/naaturel/ANN/Main.java
2026-05-09 16:44:12 +02:00

132 lines
5.0 KiB
Java

package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
import java.util.*;
/**
 * Command-line entry point of the ANN framework.
 *
 * <p>Loads a JSON configuration, then either builds and trains a new fully connected network
 * (saving it to the configured model path) or loads a previously saved one, and finally
 * visualizes the dataset, the decision regions (2-input datasets only) and the network itself.
 */
public class Main {

    /** Configuration file used when no path is given as the first command-line argument. */
    private static final String DEFAULT_CONFIG_PATH = "C:/Users/Laurent/Desktop/ANN-framework/config.json";

    /** Lower bound of both axes of the decision-region grid. */
    private static final float PLOT_MIN = -10F;
    /** Upper bound (exclusive) of both axes of the decision-region grid. */
    private static final float PLOT_MAX = 10F;
    /** Number of grid samples per axis in the decision-region plot. */
    private static final int PLOT_RESOLUTION = 300;

    /**
     * Runs the load / train-or-restore / visualize pipeline.
     *
     * @param args optional; {@code args[0]} overrides the configuration file path
     * @throws Exception propagated from configuration loading, dataset extraction, training or
     *     model persistence
     */
    public static void main(String[] args) throws Exception {
        // Let the config file be chosen on the command line; keep the historical default otherwise.
        String configPath = args.length > 0 ? args[0] : DEFAULT_CONFIG_PATH;
        ConfigDto config = ConfigLoader.load(configPath);

        boolean newModel = config.getModelProperty("new", Boolean.class);
        int[] modelParameters = config.getModelProperty("parameters", int[].class);
        String modelPath = config.getModelProperty("path", String.class);
        int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
        float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
        String datasetPath = config.getDatasetProperty("path", String.class);

        if (modelParameters == null || modelParameters.length == 0) {
            throw new IllegalStateException(
                    "Config model property 'parameters' must list at least one layer size");
        }

        // The last configured layer is the output layer: one neuron per class.
        int configuredClasses = modelParameters[modelParameters.length - 1];
        Scanner sc = new Scanner(System.in);
        int nbrClass = readClassCount(sc, configuredClasses);
        System.out.println();

        DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
        int nbrInput = dataset.getNbrInputs();

        ModelSnapshot snapshot;
        Model network;
        if (newModel) {
            // The labels are extracted for nbrClass classes, so the output layer must match that
            // count even when the user overrode the configured value.
            int[] layerSizes = modelParameters.clone();
            layerSizes[layerSizes.length - 1] = nbrClass;
            network = createNetwork(layerSizes, nbrInput);
            snapshot = new ModelSnapshot(network);
            System.out.println("Parameters: " + network.synCount());
            Trainer trainer = new GradientBackpropagationTraining();
            trainer.train(learningRate, maxEpoch, network, dataset);
            snapshot.saveToFile(modelPath);
        } else {
            snapshot = new ModelSnapshot();
            snapshot.loadFromFile(modelPath);
            network = snapshot.getModel();
        }

        plotGraph(dataset, network);
        new ModelVisualizer(network)
                .withWeights(true)
                .display();
    }

    /**
     * Asks the user how many classes the dataset has, re-prompting until a positive integer or an
     * empty line is entered.
     *
     * @param sc scanner reading the user's answers
     * @param defaultCount value used on an empty answer or when no more input is available
     * @return the chosen, strictly positive class count
     */
    private static int readClassCount(Scanner sc, int defaultCount) {
        while (true) {
            System.out.printf("How many classes ? [%d] : ", defaultCount);
            if (!sc.hasNextLine()) {
                // Input stream closed (e.g. non-interactive run): keep the configured value.
                return defaultCount;
            }
            String input = sc.nextLine().trim();
            if (input.isEmpty()) {
                return defaultCount;
            }
            try {
                int count = Integer.parseInt(input);
                if (count > 0) {
                    return count;
                }
            } catch (NumberFormatException ignored) {
                // Not a number: fall through and ask again.
            }
            System.out.println("Please enter a positive integer.");
        }
    }

    /**
     * Builds a fully connected network of TanH neurons.
     *
     * <p>Every neuron of layer {@code i} gets one synapse per neuron of layer {@code i-1} (or one
     * per dataset input for the first layer) plus a bias. Neuron ids are assigned sequentially
     * across all layers, starting at 0.
     *
     * @param neuronPerLayer number of neurons in each layer, first layer first
     * @param nbrInput number of inputs feeding the first layer
     * @return the newly created network
     * @throws IllegalArgumentException if a layer size is not strictly positive
     */
    private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput) {
        for (int size : neuronPerLayer) {
            if (size <= 0) {
                throw new IllegalArgumentException(
                        "Layer sizes must be positive: " + Arrays.toString(neuronPerLayer));
            }
        }

        int neuronId = 0;
        Layer[] layers = new Layer[neuronPerLayer.length];
        for (int i = 0; i < neuronPerLayer.length; i++) {
            // Fan-in is the same for every neuron of a layer, so compute it once per layer.
            int nbrSyn = i == 0 ? nbrInput : neuronPerLayer[i - 1];
            Neuron[] neurons = new Neuron[neuronPerLayer[i]];
            for (int j = 0; j < neurons.length; j++) {
                Synapse[] syns = new Synapse[nbrSyn];
                for (int k = 0; k < nbrSyn; k++) {
                    syns[k] = new Synapse(new Input(0), new Weight());
                }
                // Same creation order as before (synapse weights, then bias weight, then
                // activation), so weight initialisation order is unchanged.
                Bias bias = new Bias(new Weight());
                neurons[j] = new Neuron(neuronId, syns, bias, new TanH());
                neuronId++;
            }
            layers[i] = new Layer(neurons);
        }
        return new FullyConnectedNetwork(layers);
    }

    /**
     * Draws a scatter plot of the dataset labels and of the network's predicted class over a
     * regular grid. Does nothing unless the dataset has exactly two inputs.
     *
     * @param dataset dataset whose entries are plotted by expected class
     * @param network model queried at every grid point
     */
    private static void plotGraph(DataSet dataset, Model network) {
        if (dataset.getNbrInputs() != 2) return;
        GraphVisualizer visualizer = new GraphVisualizer();

        // Plot each entry under the index of its highest label value.
        for (DataSetEntry entry : dataset) {
            int expectedClass = indexOfMax(dataset.getLabelsAsFloat(entry));
            visualizer.addPoint("Label " + expectedClass,
                    entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
        }

        // Plot predictions on a PLOT_RESOLUTION x PLOT_RESOLUTION grid. Integer counters are used
        // because accumulating a float step drifts and can add or drop a row/column.
        float step = (PLOT_MAX - PLOT_MIN) / PLOT_RESOLUTION;
        for (int i = 0; i < PLOT_RESOLUTION; i++) {
            float x = PLOT_MIN + i * step;
            for (int j = 0; j < PLOT_RESOLUTION; j++) {
                float y = PLOT_MIN + j * step;
                List<Float> predictions = new ArrayList<>();
                for (float p : network.predict(new float[]{x, y})) predictions.add(p);
                int predictedClass = indexOfMax(predictions);
                visualizer.addPoint("Predict " + predictedClass, x, y);
            }
        }
        visualizer.buildScatterGraph((int) PLOT_MIN - 1, (int) PLOT_MAX + 1);
    }

    /**
     * Returns the index of the first largest element, ordering values with
     * {@link Float#compare(float, float)} (so NaN ranks above every other value).
     *
     * @param values non-empty list of values
     * @return index of the first maximum
     * @throws IllegalArgumentException if {@code values} is empty
     */
    private static int indexOfMax(List<Float> values) {
        if (values.isEmpty()) {
            throw new IllegalArgumentException("Cannot take the argmax of an empty vector");
        }
        int best = 0;
        for (int i = 1; i < values.size(); i++) {
            if (Float.compare(values.get(i), values.get(best)) > 0) best = i;
        }
        return best;
    }
}