package com.naaturel.ANN; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.implementation.gradientDescent.Linear; import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import java.io.Console; import java.util.*; public class Main { public static void main(String[] args){ int nbrClass = 1; DataSet dataset = new DatasetExtractor() .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass); int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()}; int nbrInput = dataset.getNbrInputs(); List layers = new ArrayList<>(); for (int i = 0; i < neuronPerLayer.length; i++){ List neurons = new ArrayList<>(); for (int j = 0; j < neuronPerLayer[i]; j++){ int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1]; List syns = new ArrayList<>(); for (int k=0; k < nbrSyn; k++){ syns.add(new Synapse(new Input(0), new Weight())); } Bias bias = new Bias(new Weight()); Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH()); neurons.add(n); } Layer layer = new Layer(neurons.toArray(new Neuron[0])); layers.add(layer); } FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0])); Trainer trainer = new GradientBackpropagationTraining(); trainer.train(0.0005F, 15000, network, dataset); GraphVisualizer visualizer = new GraphVisualizer(); for (DataSetEntry entry : dataset) { List label = dataset.getLabelsAsFloat(entry); visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); } float min = 0F; float max = 15F; float step = 0.03F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); float predSeries = prediction > 0.5F ? 1 : -1; visualizer.addPoint(Float.toString(predSeries), x, y); } } visualizer.buildScatterGraph((int)min-1, (int)max+1); } }