diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
index fa1893d..09804f4 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
@@ -27,30 +27,28 @@ public class GradientDescentTraining implements Trainer {
         GradientDescentTrainingContext context = new GradientDescentTrainingContext();
         context.dataset = dataset;
         context.model = model;
-        context.learningRate = 0.2F;
         context.correctorTerms = new ArrayList<>();

         List steps = List.of(
                 new PredictionStep(new SimplePredictionStrategy(context)),
                 new DeltaStep(new SimpleDeltaStrategy(context)),
                 new LossStep(new SquareLossStrategy(context)),
-                new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
-                new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
+                new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
         );

-        TrainingPipeline pipeline = new TrainingPipeline(steps);
-        pipeline
-                .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 50)
-                .beforeEpoch(ctx -> {
-                    ctx.globalLoss = 0.0F;
-                    for (int i = 0; i < model.synCount(); i++){
-                        context.correctorTerms.add(0F);
-                    }
-                })
-                .afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
-                .withVerbose(true)
-                .withTimeMeasurement(true)
-                .run(context);
+        new TrainingPipeline(steps)
+                .stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
+                .beforeEpoch(ctx -> {
+                    GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
+                    gdCtx.globalLoss = 0.0F;
+                    gdCtx.correctorTerms.clear();
+                    for (int i = 0; i < ctx.model.synCount(); i++){
+                        gdCtx.correctorTerms.add(0F);
+                    }
+                })
+                .afterEpoch(ctx -> new GradientDescentCorrectionStrategy(context).apply())
+                .withVerbose(true)
+                .run(context);
     }

     /*public void train(Neuron n, float learningRate, DataSet dataSet) {
diff --git a/src/main/resources/assets/and-gradient.csv b/src/main/resources/assets/and-gradient.csv
new file mode 100644
index 0000000..998919c
--- /dev/null
+++ b/src/main/resources/assets/and-gradient.csv
@@ -0,0 +1,4 @@
+0,0,-1
+0,1,-1
+1,0,-1
+1,1,1
\ No newline at end of file
diff --git a/src/test/java/gradientDescent/GradientDescentTest.java b/src/test/java/gradientDescent/GradientDescentTest.java
new file mode 100644
index 0000000..948025f
--- /dev/null
+++ b/src/test/java/gradientDescent/GradientDescentTest.java
@@ -0,0 +1,98 @@
+package gradientDescent;
+
+import com.naaturel.ANN.domain.abstraction.Neuron;
+import com.naaturel.ANN.domain.abstraction.TrainingStep;
+import com.naaturel.ANN.domain.model.dataset.DataSet;
+import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
+import com.naaturel.ANN.domain.model.neuron.*;
+import com.naaturel.ANN.domain.model.training.TrainingPipeline;
+import com.naaturel.ANN.implementation.gradientDescent.*;
+import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
+import com.naaturel.ANN.implementation.simplePerceptron.*;
+import com.naaturel.ANN.implementation.training.steps.*;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+
+public class GradientDescentTest {
+
+    private DataSet dataset;
+    private GradientDescentTrainingContext context;
+
+    private List synapses;
+    private Bias bias;
+    private Network network;
+
+    private TrainingPipeline pipeline;
+
+    @BeforeEach
+    public void init(){
+        dataset = new DatasetExtractor()
+                .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
+
+        List syns = new ArrayList<>();
+        syns.add(new Synapse(new Input(0), new Weight(0)));
+        syns.add(new Synapse(new Input(0), new Weight(0)));
+
+        bias = new Bias(new Weight(0));
+
+        Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
+        Layer layer = new Layer(List.of(neuron));
+        network = new Network(List.of(layer));
+
+        context = new GradientDescentTrainingContext();
+        context.dataset = dataset;
+        context.model = network;
+        context.correctorTerms = new ArrayList<>();
+
+        List steps = List.of(
+                new PredictionStep(new SimplePredictionStrategy(context)),
+                new DeltaStep(new SimpleDeltaStrategy(context)),
+                new LossStep(new SquareLossStrategy(context)),
+                new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
+        );
+
+        pipeline = new TrainingPipeline(steps)
+                .stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
+                .beforeEpoch(ctx -> {
+                    GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
+                    gdCtx.globalLoss = 0.0F;
+                    gdCtx.correctorTerms.clear();
+                    for (int i = 0; i < ctx.model.synCount(); i++){
+                        gdCtx.correctorTerms.add(0F);
+                    }
+                });
+    }
+
+    @Test
+    public void test_the_whole_algorithm(){
+
+        List expectedGlobalLosses = List.of(
+                0.5F,
+                0.38F,
+                0.3176F,
+                0.272096F,
+                0.237469F
+        );
+
+        context.learningRate = 0.2F;
+        pipeline.afterEpoch(ctx -> {
+            context.globalLoss /= context.dataset.size();
+            new GradientDescentCorrectionStrategy(context).apply();
+
+            int index = ctx.epoch-1;
+            if(index >= expectedGlobalLosses.size()) return;
+
+            assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
+        });
+
+        pipeline.run(context);
+        assertEquals(67, context.epoch);
+    }
+}
diff --git a/src/test/java/perceptron/simplePerceptronTest.java b/src/test/java/perceptron/SimplePerceptronTest.java
similarity index 98%
rename from src/test/java/perceptron/simplePerceptronTest.java
rename to src/test/java/perceptron/SimplePerceptronTest.java
index 4b8fd79..13ac89f 100644
--- a/src/test/java/perceptron/simplePerceptronTest.java
+++ b/src/test/java/perceptron/SimplePerceptronTest.java
@@ -18,7 +18,7 @@ import java.util.List;

 import static org.junit.jupiter.api.Assertions.*;

-public class simplePerceptronTest {
+public class SimplePerceptronTest {

     private DataSet dataset;
     private SimpleTrainingContext context;
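
Note on the training flow above, for anyone tracing the expected losses in GradientDescentTest: the WeightCorrectionStep is removed from the per-sample step list, the corrector terms registered by GradientDescentErrorStrategy are accumulated over a whole epoch, and the test's afterEpoch hook first averages globalLoss over the dataset and then applies GradientDescentCorrectionStrategy once, i.e. batch gradient descent. The sketch below is a standalone illustration, not the framework's API; the class and variable names are invented for this note, and it assumes a linear unit, a per-sample square loss of (target - output)^2 / 2, and the delta-rule correction learningRate * sum((target - output) * input), which is consistent with the loss values asserted in the test.

// Standalone sketch, not part of the framework: batch gradient descent on the
// AND data from and-gradient.csv, with targets -1/-1/-1/1 and zero-initialised weights.
public class BatchGradientDescentSketch {

    public static void main(String[] args) {
        float[][] inputs  = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        float[]   targets = {-1, -1, -1, 1};

        float[] weights = {0F, 0F};   // synapse weights, initialised to 0 as in the test
        float bias = 0F;              // bias weight, also 0
        float learningRate = 0.2F;    // matches context.learningRate in the test

        int epoch = 0;
        float globalLoss;
        do {
            epoch++;
            globalLoss = 0F;
            // corrector terms accumulated over the epoch: one per weight plus the bias
            float[] correctorTerms = {0F, 0F, 0F};

            for (int s = 0; s < inputs.length; s++) {
                // linear unit: output = w . x + bias
                float output = weights[0] * inputs[s][0] + weights[1] * inputs[s][1] + bias;
                float delta = targets[s] - output;

                // square loss, assumed here to be (target - output)^2 / 2 per sample
                globalLoss += delta * delta / 2F;

                // error registration: accumulate the delta-rule corrections
                correctorTerms[0] += learningRate * delta * inputs[s][0];
                correctorTerms[1] += learningRate * delta * inputs[s][1];
                correctorTerms[2] += learningRate * delta;
            }

            // afterEpoch: average the loss, then apply the accumulated correction once
            globalLoss /= inputs.length;
            weights[0] += correctorTerms[0];
            weights[1] += correctorTerms[1];
            bias += correctorTerms[2];

            System.out.printf("epoch %d: globalLoss %f%n", epoch, globalLoss);
        } while (globalLoss > 0.125F && epoch <= 100);
    }
}

Run as a plain main class, this traces the same loss sequence the test asserts, within float rounding: 0.5 for epoch 1, 0.38 for epoch 2, about 0.3176 for epoch 3, and so on.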