Files
ANN-framework/src/main/java/com/naaturel/ANN/Main.java
2026-03-31 16:26:28 +02:00

79 lines
2.9 KiB
Java

package com.naaturel.ANN;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.io.Console;
import java.util.*;
/**
 * Demo entry point: loads a dataset, builds a fully connected network whose neurons all use
 * {@link TanH}, trains it with {@link GradientBackpropagationTraining} and finally plots the
 * dataset entries together with the network's thresholded predictions over a 2D grid.
 */
public class Main {

    /** Dataset read when no path is supplied on the command line. */
    private static final String DEFAULT_DATASET_PATH =
            "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv";

    /** Lower bound (inclusive) of the prediction grid on both axes. */
    private static final float GRID_MIN = 0F;

    /** Upper bound (exclusive) of the prediction grid on both axes. */
    private static final float GRID_MAX = 15F;

    /** Distance between two neighbouring grid samples. */
    private static final float GRID_STEP = 0.03F;

    /**
     * Runs the whole demo: load, build, train, plot.
     *
     * @param args optional; {@code args[0]} overrides the dataset path, otherwise
     *             {@link #DEFAULT_DATASET_PATH} is used
     */
    public static void main(String[] args) {
        String datasetPath = (args != null && args.length > 0) ? args[0] : DEFAULT_DATASET_PATH;

        int nbrClass = 1;
        DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);

        // Three leading layers of fixed size, then a last layer sized from the dataset's labels.
        int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()};
        int nbrInput = dataset.getNbrInputs();
        FullyConnectedNetwork network = buildNetwork(nbrInput, neuronPerLayer);

        Trainer trainer = new GradientBackpropagationTraining();
        // NOTE(review): the first two arguments are presumably the learning rate and the
        // iteration/epoch limit - confirm against Trainer.train.
        trainer.train(0.0005F, 15000, network, dataset);

        GraphVisualizer visualizer = new GraphVisualizer();
        plotDataset(visualizer, dataset);
        plotPredictionGrid(visualizer, network);
        visualizer.buildScatterGraph((int) GRID_MIN - 1, (int) GRID_MAX + 1);
    }

    /**
     * Creates a network with one layer per entry of {@code neuronPerLayer}. Each neuron gets one
     * synapse per value it receives: {@code nbrInput} for the first layer, the size of the
     * previous layer otherwise.
     *
     * @param nbrInput       number of synapses of every neuron in the first layer
     * @param neuronPerLayer neuron count of each layer, in order
     * @return the assembled network
     */
    private static FullyConnectedNetwork buildNetwork(int nbrInput, int[] neuronPerLayer) {
        Layer[] layers = new Layer[neuronPerLayer.length];
        for (int layerIdx = 0; layerIdx < neuronPerLayer.length; layerIdx++) {
            // Same for every neuron of the layer, so computed once per layer.
            int nbrSyn = (layerIdx == 0) ? nbrInput : neuronPerLayer[layerIdx - 1];

            Neuron[] neurons = new Neuron[neuronPerLayer[layerIdx]];
            for (int neuronIdx = 0; neuronIdx < neurons.length; neuronIdx++) {
                neurons[neuronIdx] = buildNeuron(nbrSyn);
            }
            layers[layerIdx] = new Layer(neurons);
        }
        return new FullyConnectedNetwork(layers);
    }

    /**
     * Creates a single {@link TanH} neuron with {@code nbrSyn} synapses and one bias.
     * Synapse weights are instantiated before the bias weight, one fresh {@link Weight} each.
     */
    private static Neuron buildNeuron(int nbrSyn) {
        Synapse[] synapses = new Synapse[nbrSyn];
        for (int s = 0; s < nbrSyn; s++) {
            synapses[s] = new Synapse(new Input(0), new Weight());
        }
        Bias bias = new Bias(new Weight());
        return new Neuron(synapses, bias, new TanH());
    }

    /**
     * Adds every dataset entry to the graph, using its first two data values as coordinates and
     * its first label to name the series.
     */
    private static void plotDataset(GraphVisualizer visualizer, DataSet dataset) {
        for (DataSetEntry entry : dataset) {
            List<Float> label = dataset.getLabelsAsFloat(entry);
            float x = entry.getData().get(0).getValue();
            float y = entry.getData().get(1).getValue();
            visualizer.addPoint("Label " + label.getFirst(), x, y);
        }
    }

    /**
     * Samples the network on a regular grid over [GRID_MIN, GRID_MAX) x [GRID_MIN, GRID_MAX) and
     * adds each sample to the series named after its thresholded prediction.
     */
    private static void plotPredictionGrid(GraphVisualizer visualizer, FullyConnectedNetwork network) {
        // Integer counters instead of "x += step": accumulating a float step lets rounding error
        // build up, which makes both the iteration count and the sample positions drift.
        int samplesPerAxis = (int) Math.ceil((GRID_MAX - GRID_MIN) / GRID_STEP);

        for (int xi = 0; xi < samplesPerAxis; xi++) {
            float x = GRID_MIN + xi * GRID_STEP;
            for (int yi = 0; yi < samplesPerAxis; yi++) {
                float y = GRID_MIN + yi * GRID_STEP;
                float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
                // NOTE(review): the last layer uses TanH; if labels are encoded as -1/1 a
                // threshold of 0 may be what is intended - confirm. Kept at 0.5 as in the original.
                float predSeries = prediction > 0.5F ? 1 : -1;
                visualizer.addPoint(Float.toString(predSeries), x, y);
            }
        }
    }
}