Add gradient descent test
This commit is contained in:
@@ -27,29 +27,27 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.2F;
|
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
new LossStep(new SquareLossStrategy(context)),
|
new LossStep(new SquareLossStrategy(context)),
|
||||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
|
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
|
||||||
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
|
|
||||||
);
|
);
|
||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
new TrainingPipeline(steps)
|
||||||
pipeline
|
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
|
||||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 50)
|
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
ctx.globalLoss = 0.0F;
|
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||||
for (int i = 0; i < model.synCount(); i++){
|
gdCtx.globalLoss = 0.0F;
|
||||||
context.correctorTerms.add(0F);
|
gdCtx.correctorTerms.clear();
|
||||||
|
for (int i = 0; i < ctx.model.synCount(); i++){
|
||||||
|
gdCtx.correctorTerms.add(0F);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
|
.afterEpoch(ctx -> new GradientDescentCorrectionStrategy(context).apply())
|
||||||
.withVerbose(true)
|
.withVerbose(true)
|
||||||
.withTimeMeasurement(true)
|
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
4
src/main/resources/assets/and-gradient.csv
Normal file
4
src/main/resources/assets/and-gradient.csv
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
0,0,-1
|
||||||
|
0,1,-1
|
||||||
|
1,0,-1
|
||||||
|
1,1,1
|
||||||
|
98
src/test/java/gradientDescent/GradientDescentTest.java
Normal file
98
src/test/java/gradientDescent/GradientDescentTest.java
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package gradientDescent;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||||
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
|
||||||
|
public class GradientDescentTest {
|
||||||
|
|
||||||
|
private DataSet dataset;
|
||||||
|
private GradientDescentTrainingContext context;
|
||||||
|
|
||||||
|
private List<Synapse> synapses;
|
||||||
|
private Bias bias;
|
||||||
|
private Network network;
|
||||||
|
|
||||||
|
private TrainingPipeline pipeline;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
public void init(){
|
||||||
|
dataset = new DatasetExtractor()
|
||||||
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
|
||||||
|
|
||||||
|
List<Synapse> syns = new ArrayList<>();
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
|
||||||
|
bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
|
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
||||||
|
Layer layer = new Layer(List.of(neuron));
|
||||||
|
network = new Network(List.of(layer));
|
||||||
|
|
||||||
|
context = new GradientDescentTrainingContext();
|
||||||
|
context.dataset = dataset;
|
||||||
|
context.model = network;
|
||||||
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
|
List<TrainingStep> steps = List.of(
|
||||||
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
|
new LossStep(new SquareLossStrategy(context)),
|
||||||
|
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
|
||||||
|
);
|
||||||
|
|
||||||
|
pipeline = new TrainingPipeline(steps)
|
||||||
|
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
|
||||||
|
.beforeEpoch(ctx -> {
|
||||||
|
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||||
|
gdCtx.globalLoss = 0.0F;
|
||||||
|
gdCtx.correctorTerms.clear();
|
||||||
|
for (int i = 0; i < ctx.model.synCount(); i++){
|
||||||
|
gdCtx.correctorTerms.add(0F);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test_the_whole_algorithm(){
|
||||||
|
|
||||||
|
List<Float> expectedGlobalLosses = List.of(
|
||||||
|
0.5F,
|
||||||
|
0.38F,
|
||||||
|
0.3176F,
|
||||||
|
0.272096F,
|
||||||
|
0.237469F
|
||||||
|
);
|
||||||
|
|
||||||
|
context.learningRate = 0.2F;
|
||||||
|
pipeline.afterEpoch(ctx -> {
|
||||||
|
context.globalLoss /= context.dataset.size();
|
||||||
|
new GradientDescentCorrectionStrategy(context).apply();
|
||||||
|
|
||||||
|
int index = ctx.epoch-1;
|
||||||
|
if(index >= expectedGlobalLosses.size()) return;
|
||||||
|
|
||||||
|
assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
|
||||||
|
});
|
||||||
|
|
||||||
|
pipeline.run(context);
|
||||||
|
assertEquals(67, context.epoch);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,7 +18,7 @@ import java.util.List;
|
|||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
|
||||||
public class simplePerceptronTest {
|
public class SimplePerceptronTest {
|
||||||
|
|
||||||
private DataSet dataset;
|
private DataSet dataset;
|
||||||
private SimpleTrainingContext context;
|
private SimpleTrainingContext context;
|
||||||
Reference in New Issue
Block a user