78 lines
2.7 KiB
Java
78 lines
2.7 KiB
Java
package com.naaturel.ANN;
|
|
|
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
|
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
|
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
|
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
|
import com.naaturel.ANN.domain.model.neuron.*;
|
|
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
|
|
|
import java.util.*;
|
|
|
|
public class Main {
|
|
|
|
public static void main(String[] args){
|
|
|
|
int nbrClass = 1;
|
|
|
|
DataSet dataset = new DatasetExtractor()
|
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
|
|
|
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
|
|
int nbrInput = dataset.getNbrInputs();
|
|
|
|
List<Layer> layers = new ArrayList<>();
|
|
for (int i = 0; i < neuronPerLayer.length; i++){
|
|
|
|
List<Neuron> neurons = new ArrayList<>();
|
|
for (int j = 0; j < neuronPerLayer[i]; j++){
|
|
|
|
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
|
|
|
|
List<Synapse> syns = new ArrayList<>();
|
|
for (int k=0; k < nbrSyn; k++){
|
|
syns.add(new Synapse(new Input(0), new Weight()));
|
|
}
|
|
|
|
Bias bias = new Bias(new Weight());
|
|
|
|
Neuron n = new Neuron(syns, bias, new Sigmoid(2));
|
|
neurons.add(n);
|
|
}
|
|
Layer layer = new Layer(neurons);
|
|
layers.add(layer);
|
|
}
|
|
|
|
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
|
|
|
Trainer trainer = new GradientBackpropagationTraining();
|
|
trainer.train(0.5F, network, dataset);
|
|
|
|
/*GraphVisualizer visualizer = new GraphVisualizer();
|
|
|
|
for (DataSetEntry entry : dataset) {
|
|
List<Float> label = dataset.getLabelsAsFloat(entry);
|
|
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
|
}
|
|
|
|
float min = -2F;
|
|
float max = 2F;
|
|
float step = 0.01F;
|
|
for (float x = min; x < max; x+=step){
|
|
for (float y = min; y < max; y+=step){
|
|
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
|
|
float predSeries = prediction > 0.5F ? 1 : 0;
|
|
visualizer.addPoint(Float.toString(predSeries), x, y);
|
|
}
|
|
}
|
|
|
|
|
|
visualizer.buildScatterGraph();*/
|
|
}
|
|
|
|
}
|