Fix weighted sum back

This commit is contained in:
2026-04-01 17:40:33 +02:00
parent 1e8b02089c
commit 4441b149f9
4 changed files with 12 additions and 14 deletions

View File

@@ -24,9 +24,9 @@ public class Main {
int nbrClass = 1; int nbrClass = 1;
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
int[] neuronPerLayer = new int[]{1800, 2, 1800, dataset.getNbrLabels()}; int[] neuronPerLayer = new int[]{10, 5, 10, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput); FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
@@ -36,7 +36,7 @@ public class Main {
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.01F, 2000, network, dataset); trainer.train(0.01F, 2000, network, dataset);
//plotGraph(dataset, network); plotGraph(dataset, network);
} }
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
@@ -78,8 +78,8 @@ public class Main {
}); });
} }
float min = -5F; float min = -3F;
float max = 5F; float max = 3F;
float step = 0.03F; float step = 0.03F;
for (float x = min; x < max; x+=step){ for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){ for (float y = min; y < max; y+=step){

View File

@@ -23,7 +23,7 @@ public abstract class TrainingContext {
public TrainingContext(Model model, DataSet dataset) { public TrainingContext(Model model, DataSet dataset) {
this.model = model; this.model = model;
this.dataset = dataset; this.dataset = dataset;
this.deltas = new float[model.neuronCount()]; this.deltas = new float[dataset.getNbrLabels()];
} }
} }

View File

@@ -60,14 +60,12 @@ public class Neuron implements Model {
} }
public float calculateWeightedSum() { public float calculateWeightedSum() {
float sum = bias.getWeight() * bias.getInput(); this.weightedSum = 0F;
this.weightedSum += this.bias.getWeight() * this.bias.getInput();
for (int i = 0; i < weights.length; i++) { for(Synapse syn : this.synapses){
sum += weights[i] * inputs[i]; this.weightedSum += syn.getWeight() * syn.getInput();
} }
return this.weightedSum;
this.weightedSum = sum;
return sum;
} }
public int getId(){ public int getId(){

View File

@@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer {
.afterEpoch(ctx -> { .afterEpoch(ctx -> {
ctx.globalLoss /= dataset.size(); ctx.globalLoss /= dataset.size();
}) })
.withVerbose(false,epoch/10) .withVerbose(true,epoch/10)
.withTimeMeasurement(true) .withTimeMeasurement(true)
.run(context); .run(context);
} }