From 17cff89b44ffad141ceefe3c1cea75b2471ea70f Mon Sep 17 00:00:00 2001 From: Laurent Date: Sat, 28 Mar 2026 13:19:36 +0100 Subject: [PATCH] Implement gradient backpropagation stub --- src/main/java/com/naaturel/ANN/Main.java | 7 +++++-- .../implementation/multiLayers/Sigmoid.java | 18 ++++++++++++++++++ .../ANN/implementation/multiLayers/TanH.java | 17 +++++++++++++++++ .../GradientBackpropagationTraining.java | 12 ++++++++++++ .../training/GradientDescentTraining.java | 2 +- 5 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java create mode 100644 src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index d53a1ab..d1735b0 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -2,6 +2,9 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; +import com.naaturel.ANN.implementation.multiLayers.Sigmoid; +import com.naaturel.ANN.implementation.multiLayers.TanH; +import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; @@ -35,7 +38,7 @@ public class Main { Bias bias = new Bias(new Weight(0)); - Neuron n = new Neuron(syns, bias, new Linear()); + Neuron n = new Neuron(syns, bias, new Sigmoid(1)); neurons.add(n); } Layer layer = new Layer(neurons); @@ -43,7 +46,7 @@ public class Main { } Network network = new Network(layers); - Trainer trainer = new GradientDescentTraining(); + Trainer trainer = new GradientBackpropagationTraining(); trainer.train(network, 
dataset); } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java new file mode 100644 index 0000000..01149cd --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java @@ -0,0 +1,18 @@ +package com.naaturel.ANN.implementation.multiLayers; + +import com.naaturel.ANN.domain.abstraction.ActivationFunction; +import com.naaturel.ANN.domain.model.neuron.Neuron; + +public class Sigmoid implements ActivationFunction { + + private float steepness; + + public Sigmoid(float steepness) { + this.steepness = steepness; + } + + @Override + public float accept(Neuron n) { + return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum()))); + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java new file mode 100644 index 0000000..a3b8508 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java @@ -0,0 +1,17 @@ +package com.naaturel.ANN.implementation.multiLayers; + +import com.naaturel.ANN.domain.abstraction.ActivationFunction; +import com.naaturel.ANN.domain.model.neuron.Neuron; + +public class TanH implements ActivationFunction { + + @Override + public float accept(Neuron n) { + //For educational purposes. 
Math.tanh() could have been used here + float weightedSum = n.calculateWeightedSum(); + double exp = Math.exp(weightedSum); + double res = (exp-(1/exp))/(exp+(1/exp)); + return (float)(res); + } + +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java new file mode 100644 index 0000000..45452df --- /dev/null +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -0,0 +1,12 @@ +package com.naaturel.ANN.implementation.training; + +import com.naaturel.ANN.domain.abstraction.Model; +import com.naaturel.ANN.domain.abstraction.Trainer; +import com.naaturel.ANN.infrastructure.dataset.DataSet; + +public class GradientBackpropagationTraining implements Trainer { + @Override + public void train(Model model, DataSet dataset) { + + } +} diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index f6a08d8..bbcf03f 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -50,7 +50,7 @@ public class GradientDescentTraining implements Trainer { context.globalLoss /= context.dataset.size(); new GradientDescentCorrectionStrategy(context).apply(); }) - .withVerbose(true) + //.withVerbose(true) .withTimeMeasurement(true) .withVisualization(true, new GraphVisualizer()) .run(context);