package com.naaturel.ANN; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import java.util.*; public class Main { public static void main(String[] args){ int nbrClass = 1; DataSet dataset = new DatasetExtractor() .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass); int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()}; int nbrInput = dataset.getNbrInputs(); List layers = new ArrayList<>(); for (int i = 0; i < neuronPerLayer.length; i++){ List neurons = new ArrayList<>(); for (int j = 0; j < neuronPerLayer[i]; j++){ int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1]; List syns = new ArrayList<>(); for (int k=0; k < nbrSyn; k++){ syns.add(new Synapse(new Input(0), new Weight())); } Bias bias = new Bias(new Weight()); Neuron n = new Neuron(syns, bias, new Sigmoid(2)); neurons.add(n); } Layer layer = new Layer(neurons); layers.add(layer); } FullyConnectedNetwork network = new FullyConnectedNetwork(layers); Trainer trainer = new GradientBackpropagationTraining(); trainer.train(0.5F, network, dataset); /*GraphVisualizer visualizer = new GraphVisualizer(); for (DataSetEntry entry : dataset) { List label = dataset.getLabelsAsFloat(entry); visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); } float min = -2F; float max = 2F; float step = 0.01F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); float predSeries = prediction > 0.5F ? 1 : 0; visualizer.addPoint(Float.toString(predSeries), x, y); } } visualizer.buildScatterGraph();*/ } }