Fix multi layer implementation

This commit is contained in:
2026-03-30 21:13:03 +02:00
parent ada01d350b
commit fd97d0853c
15 changed files with 108 additions and 71 deletions

View File

@@ -17,12 +17,12 @@ public class Main {
public static void main(String[] args){
int nbrClass = 1;
int nbrClass = 3;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
int[] neuronPerLayer = new int[]{3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 37, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs();
List<Layer> layers = new ArrayList<>();
@@ -40,7 +40,7 @@ public class Main {
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(syns, bias, new Sigmoid(2));
Neuron n = new Neuron(syns, bias, new TanH());
neurons.add(n);
}
Layer layer = new Layer(neurons);
@@ -50,7 +50,7 @@ public class Main {
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.5F, network, dataset);
trainer.train(0.001F, 1000, network, dataset);
/*GraphVisualizer visualizer = new GraphVisualizer();
@@ -59,7 +59,7 @@ public class Main {
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
}
float min = -2F;
float min = -3F;
float max = 2F;
float step = 0.01F;
for (float x = min; x < max; x+=step){