Implement gradient backpropagation stub

This commit is contained in:
2026-03-28 13:19:36 +01:00
parent 6d88651385
commit 17cff89b44
5 changed files with 53 additions and 3 deletions

View File

@@ -2,6 +2,9 @@ package com.naaturel.ANN;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
@@ -35,7 +38,7 @@ public class Main {
Bias bias = new Bias(new Weight(0)); Bias bias = new Bias(new Weight(0));
Neuron n = new Neuron(syns, bias, new Linear()); Neuron n = new Neuron(syns, bias, new Sigmoid(1));
neurons.add(n); neurons.add(n);
} }
Layer layer = new Layer(neurons); Layer layer = new Layer(neurons);
@@ -43,7 +46,7 @@ public class Main {
} }
Network network = new Network(layers); Network network = new Network(layers);
Trainer trainer = new GradientDescentTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(network, dataset); trainer.train(network, dataset);
} }

View File

@@ -0,0 +1,18 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
public class Sigmoid implements ActivationFunction {

    /** Slope factor k of the logistic curve around 0; immutable once constructed. */
    private final float steepness;

    /**
     * Creates a sigmoid activation with the given steepness.
     *
     * @param steepness slope factor k in 1 / (1 + e^(-k * x)); larger values make
     *                  the transition between 0 and 1 sharper
     */
    public Sigmoid(float steepness) {
        this.steepness = steepness;
    }

    /**
     * Applies the logistic function to the neuron's weighted input sum.
     *
     * @param n neuron whose weighted sum is activated
     * @return value in (0, 1); saturates toward 0 or 1 for large |input|
     */
    @Override
    public float accept(Neuron n) {
        return (float) (1.0 / (1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
    }
}

View File

@@ -0,0 +1,17 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
public class TanH implements ActivationFunction {

    /**
     * Applies the hyperbolic tangent to the neuron's weighted input sum.
     *
     * <p>For educational purposes the formula is written out rather than calling
     * {@link Math#tanh(double)}. The naive form (e^x - 1/e^x)/(e^x + 1/e^x)
     * produces NaN once Math.exp overflows to Infinity (Inf/Inf), and has the
     * symmetric problem via 1/e^x for large negative input. This version uses
     * e^(-2|x|), which can only underflow to 0, so the result correctly
     * saturates to +/-1 for large |input| instead of becoming NaN.
     *
     * @param n neuron whose weighted sum is activated
     * @return value in (-1, 1), saturating to +/-1 at the extremes
     */
    @Override
    public float accept(Neuron n) {
        float weightedSum = n.calculateWeightedSum();
        // tanh(x) = sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|)) — overflow-free.
        double e = Math.exp(-2.0 * Math.abs(weightedSum));
        double magnitude = (1.0 - e) / (1.0 + e);
        return (float) (weightedSum < 0 ? -magnitude : magnitude);
    }
}

View File

@@ -0,0 +1,12 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
// Trainer implementation for gradient backpropagation.
// NOTE(review): intentionally a stub at this commit (commit title:
// "Implement gradient backpropagation stub") — Main already wires it in,
// but the training loop itself is not implemented yet.
public class GradientBackpropagationTraining implements Trainer {
// Trains the given model on the dataset via backpropagation.
// TODO: implement forward pass, loss computation, and weight correction;
// the existing GradientDescentTraining shows the intended loop structure.
@Override
public void train(Model model, DataSet dataset) {
}
}

View File

@@ -50,7 +50,7 @@ public class GradientDescentTraining implements Trainer {
context.globalLoss /= context.dataset.size(); context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).apply(); new GradientDescentCorrectionStrategy(context).apply();
}) })
.withVerbose(true) //.withVerbose(true)
.withTimeMeasurement(true) .withTimeMeasurement(true)
.withVisualization(true, new GraphVisualizer()) .withVisualization(true, new GraphVisualizer())
.run(context); .run(context);