Optimize some stuff

This commit is contained in:
2026-04-01 16:14:13 +02:00
parent daba4f8420
commit 1e8b02089c
20 changed files with 150 additions and 102 deletions

View File

@@ -24,21 +24,23 @@ public class Main {
int nbrClass = 1;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass);
int[] neuronPerLayer = new int[]{50, 50, 50, dataset.getNbrLabels()};
int[] neuronPerLayer = new int[]{1800, 2, 1800, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
System.out.println(network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.01F, 2000, network, dataset);
//plotGraph(dataset, network);
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
int neuronId = 0;
List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){
@@ -54,8 +56,9 @@ public class Main {
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH());
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
neurons.add(n);
neuronId++;
}
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
layers.add(layer);
@@ -77,7 +80,7 @@ public class Main {
float min = -5F;
float max = 5F;
float step = 0.01F;
float step = 0.03F;
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));