Add gradient descent test

This commit is contained in:
2026-03-26 08:23:24 +01:00
parent 76465ab6ee
commit c389646794
4 changed files with 117 additions and 17 deletions

View File

@@ -27,29 +27,27 @@ public class GradientDescentTraining implements Trainer {
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.2F;
context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SquareLossStrategy(context)),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 50)
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
.beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F;
for (int i = 0; i < model.synCount(); i++){
context.correctorTerms.add(0F);
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
gdCtx.globalLoss = 0.0F;
gdCtx.correctorTerms.clear();
for (int i = 0; i < ctx.model.synCount(); i++){
gdCtx.correctorTerms.add(0F);
}
})
.afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
.afterEpoch(ctx -> new GradientDescentCorrectionStrategy(context).apply())
.withVerbose(true)
.withTimeMeasurement(true)
.run(context);
}

View File

@@ -0,0 +1,4 @@
0,0,-1
0,1,-1
1,0,-1
1,1,1
1 0 0 -1
2 0 1 -1
3 1 0 -1
4 1 1 1

View File

@@ -0,0 +1,98 @@
package gradientDescent;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.jupiter.api.Assertions.*;
public class GradientDescentTest {
private DataSet dataset;
private GradientDescentTrainingContext context;
private List<Synapse> synapses;
private Bias bias;
private Network network;
private TrainingPipeline pipeline;
@BeforeEach
public void init(){
dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
bias = new Bias(new Weight(0));
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer));
context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = network;
context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SquareLossStrategy(context)),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
);
pipeline = new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
.beforeEpoch(ctx -> {
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
gdCtx.globalLoss = 0.0F;
gdCtx.correctorTerms.clear();
for (int i = 0; i < ctx.model.synCount(); i++){
gdCtx.correctorTerms.add(0F);
}
});
}
@Test
public void test_the_whole_algorithm(){
List<Float> expectedGlobalLosses = List.of(
0.5F,
0.38F,
0.3176F,
0.272096F,
0.237469F
);
context.learningRate = 0.2F;
pipeline.afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).apply();
int index = ctx.epoch-1;
if(index >= expectedGlobalLosses.size()) return;
assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
});
pipeline.run(context);
assertEquals(67, context.epoch);
}
}

View File

@@ -18,7 +18,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
public class simplePerceptronTest {
public class SimplePerceptronTest {
private DataSet dataset;
private SimpleTrainingContext context;