Implement batch size
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Network;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
@@ -22,11 +24,21 @@ public class Main {
|
||||
int nbrClass = 1;
|
||||
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass);
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
||||
|
||||
int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()};
|
||||
int[] neuronPerLayer = new int[]{50, 50, 50, dataset.getNbrLabels()};
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.01F, 2000, network, dataset);
|
||||
|
||||
//plotGraph(dataset, network);
|
||||
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
List<Layer> layers = new ArrayList<>();
|
||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||
|
||||
@@ -49,26 +61,27 @@ public class Main {
|
||||
layers.add(layer);
|
||||
}
|
||||
|
||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.0005F, 15000, network, dataset);
|
||||
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||
}
|
||||
|
||||
private static void plotGraph(DataSet dataset, Model network){
|
||||
GraphVisualizer visualizer = new GraphVisualizer();
|
||||
|
||||
for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
label.forEach(l -> {
|
||||
visualizer.addPoint("Label " + l,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
});
|
||||
}
|
||||
|
||||
float min = 0F;
|
||||
float max = 15F;
|
||||
float step = 0.03F;
|
||||
float min = -5F;
|
||||
float max = 5F;
|
||||
float step = 0.01F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
|
||||
float predSeries = prediction > 0.5F ? 1 : -1;
|
||||
visualizer.addPoint(Float.toString(predSeries), x, y);
|
||||
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
|
||||
visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user