From 6742b184732fc29e96ff9d82f4c09f224da0cbe2 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 20 Mar 2026 16:58:51 +0100 Subject: [PATCH] Add AdalineTraining and retune GradientDescentTraining stopping criteria --- src/main/java/com/naaturel/ANN/Main.java | 5 +- .../training/AdalineTraining.java | 85 +++++++++++++++++++ .../training/GradientDescentTraining.java | 16 ++-- 3 files changed, 96 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index ee77513..856eca2 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -11,6 +11,7 @@ import com.naaturel.ANN.domain.model.neuron.Weight; import com.naaturel.ANN.implementation.activationFunction.Heaviside; import com.naaturel.ANN.implementation.activationFunction.Linear; import com.naaturel.ANN.implementation.neuron.SimplePerceptron; +import com.naaturel.ANN.implementation.training.AdalineTraining; import com.naaturel.ANN.implementation.training.GradientDescentTraining; import java.util.*; @@ -64,11 +65,11 @@ public class Main { Bias bias = new Bias(new Weight(0)); Neuron n = new SimplePerceptron(syns, bias, new Linear()); - GradientDescentTraining st = new GradientDescentTraining(); + AdalineTraining st = new AdalineTraining(); long start = System.currentTimeMillis(); - st.train(n, 0.2F, andDataSet); + st.train(n, 0.03F, andDataSet); long end = System.currentTimeMillis(); System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 1a58e99..0c2790d 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -1,4 +1,89 @@ package com.naaturel.ANN.implementation.training; +import com.naaturel.ANN.domain.abstraction.Neuron; +import 
com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.domain.model.neuron.Bias; +import com.naaturel.ANN.domain.model.neuron.Input; +import com.naaturel.ANN.domain.model.neuron.Synapse; +import com.naaturel.ANN.domain.model.neuron.Weight; + +import java.util.ArrayList; +import java.util.List; + public class AdalineTraining { + + public AdalineTraining(){ + + } + + public void train(Neuron n, float learningRate, DataSet dataSet) { + int epoch = 1; + int maxEpoch = 1000; + float errorThreshold = 0.0F; + float mse; + + do { + if(epoch > maxEpoch) break; + mse = 0; + + for(DataSetEntry entry : dataSet) { + this.updateInputs(n, entry); + float prediction = n.predict(); + float expectation = dataSet.getLabel(entry).getValue(); + float delta = this.calculateDelta(expectation, prediction); + float loss = this.calculateLoss(delta); + + mse += loss; + + float currentBias = n.getBias().getWeight(); + float biasCorrector = currentBias + (learningRate * delta * n.getBias().getInput()); + n.updateBias(new Weight(biasCorrector)); + + for(Synapse syn : n.getSynapses()){ + float synCorrector = syn.getWeight() + (learningRate * delta * syn.getInput()); + syn.setWeight(synCorrector); + } + + System.out.printf("Epoch : %d ", epoch); + System.out.printf("predicted : %.2f, ", prediction); + System.out.printf("expected : %.2f, ", expectation); + System.out.printf("delta : %.2f, ", delta); + System.out.printf("loss : %.2f\n", loss); + } + System.out.printf("[Total error : %f]\n", mse); + + epoch++; + } while(mse > errorThreshold); + + } + + private List initCorrectorTerms(int number){ + List res = new ArrayList<>(); + for(int i = 0; i < number; i++){ + res.add(0F); + } + return res; + } + + private void updateInputs(Neuron n, DataSetEntry entry){ + int index = 0; + for(float value : entry){ + n.setInput(index, new Input(value)); + index++; + } + } + + private float calculateDelta(float expected, float 
predicted){ + return expected - predicted; + } + + private float calculateLoss(float delta){ + return (float) Math.pow(delta, 2)/2; + } + + private float calculateWeightCorrection(float value, float delta){ + return value * delta; + } + } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 42022e3..f86c91b 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -19,15 +19,15 @@ public class GradientDescentTraining { public void train(Neuron n, float learningRate, DataSet dataSet) { int epoch = 1; - int maxEpoch = 200; - float errorThreshold = 0.125F; - float currentError; + int maxEpoch = 1000; + float errorThreshold = 0.0F; + float mse; do { if(epoch > maxEpoch) break; float biasCorrector = 0; - currentError = 0; + mse = 0; List correctorTerms = this.initCorrectorTerms(n.getSynCount()); for(DataSetEntry entry : dataSet) { @@ -37,7 +37,7 @@ public class GradientDescentTraining { float delta = this.calculateDelta(expectation, prediction); float loss = this.calculateLoss(delta); - currentError += loss/dataSet.size(); + mse += loss; biasCorrector += learningRate * delta * n.getBias().getInput(); @@ -54,7 +54,7 @@ public class GradientDescentTraining { System.out.printf("delta : %.2f, ", delta); System.out.printf("loss : %.2f\n", loss); } - System.out.printf("[Total error : %.3f]\n", currentError); + System.out.printf("[Total error : %f]\n", mse); float currentBias = n.getBias().getWeight(); float newBias = currentBias + biasCorrector; @@ -67,7 +67,7 @@ public class GradientDescentTraining { } epoch++; - } while(currentError > errorThreshold); + } while(mse > errorThreshold); } @@ -92,7 +92,7 @@ public class GradientDescentTraining { } private float calculateLoss(float delta){ - return ((float) 
Math.pow(delta, 2))/2; + return (float) Math.pow(delta, 2)/2; } private float calculateWeightCorrection(float value, float delta){