From 4441b149f93a415131b1abb8f24648b2354a6ab0 Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 1 Apr 2026 17:40:33 +0200 Subject: [PATCH] Fix weighted sum back --- src/main/java/com/naaturel/ANN/Main.java | 10 +++++----- .../ANN/domain/abstraction/TrainingContext.java | 2 +- .../com/naaturel/ANN/domain/model/neuron/Neuron.java | 12 +++++------- .../training/GradientBackpropagationTraining.java | 2 +- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index b551c4c..9cc16bc 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -24,9 +24,9 @@ public class Main { int nbrClass = 1; DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass); - int[] neuronPerLayer = new int[]{1800, 2, 1800, dataset.getNbrLabels()}; + int[] neuronPerLayer = new int[]{10, 5, 10, dataset.getNbrLabels()}; int nbrInput = dataset.getNbrInputs(); FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput); @@ -36,7 +36,7 @@ public class Main { Trainer trainer = new GradientBackpropagationTraining(); trainer.train(0.01F, 2000, network, dataset); - //plotGraph(dataset, network); + plotGraph(dataset, network); } private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ @@ -78,8 +78,8 @@ public class Main { }); } - float min = -5F; - float max = 5F; + float min = -3F; + float max = 3F; float step = 0.03F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java index 35f65af..7de9377 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java @@ -23,7 +23,7 @@ public abstract class TrainingContext { public TrainingContext(Model model, DataSet dataset) { this.model = model; this.dataset = dataset; - this.deltas = new float[model.neuronCount()]; + this.deltas = new float[dataset.getNbrLabels()]; } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java index c6ee3f8..f5356a7 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -60,14 +60,12 @@ public class Neuron implements Model { } public float calculateWeightedSum() { - float sum = bias.getWeight() * bias.getInput(); - - for (int i = 0; i < weights.length; i++) { - sum += weights[i] * inputs[i]; + this.weightedSum = 0F; + this.weightedSum += this.bias.getWeight() * this.bias.getInput(); + for(Synapse syn : this.synapses){ + this.weightedSum += syn.getWeight() * syn.getInput(); } - - this.weightedSum = sum; - return sum; + return this.weightedSum; } public int getId(){ diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index c40915f..2f7c80c 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer { .afterEpoch(ctx -> { ctx.globalLoss /= dataset.size(); }) - .withVerbose(false,epoch/10) + .withVerbose(true,epoch/10) .withTimeMeasurement(true) .run(context); }