Minor fixes
This commit is contained in:
@@ -2,6 +2,7 @@ package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
@@ -18,12 +19,12 @@ public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
int nbrClass = 3;
|
||||
int nbrClass = 1;
|
||||
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass);
|
||||
|
||||
int[] neuronPerLayer = new int[]{27, dataset.getNbrLabels()};
|
||||
int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()};
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
List<Layer> layers = new ArrayList<>();
|
||||
@@ -51,7 +52,7 @@ public class Main {
|
||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.0001F, 15000, network, dataset);
|
||||
trainer.train(0.0005F, 15000, network, dataset);
|
||||
|
||||
GraphVisualizer visualizer = new GraphVisualizer();
|
||||
|
||||
@@ -60,13 +61,13 @@ public class Main {
|
||||
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
}
|
||||
|
||||
float min = -5F;
|
||||
float max = 5F;
|
||||
float step = 0.025F;
|
||||
float min = 0F;
|
||||
float max = 15F;
|
||||
float step = 0.03F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
|
||||
float predSeries = prediction > 0.5F ? 1 : 0;
|
||||
float predSeries = prediction > 0.5F ? 1 : -1;
|
||||
visualizer.addPoint(Float.toString(predSeries), x, y);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
public class BackpropagationCorrectionStep implements AlgorithmStep {
|
||||
|
||||
private GradientBackpropagationContext context;
|
||||
private GradientBackpropagationContext context;
|
||||
|
||||
public BackpropagationCorrectionStep(GradientBackpropagationContext context){
|
||||
this.context = context;
|
||||
|
||||
@@ -30,12 +30,12 @@ public class GradientBackpropagationTraining implements Trainer {
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.001F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> {
|
||||
ctx.globalLoss = 0.0F;
|
||||
})
|
||||
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
||||
.withVerbose(true, epoch/10)
|
||||
.withVerbose(true,epoch/10)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
@@ -19,3 +19,10 @@
|
||||
4,7,-1
|
||||
4,9,1
|
||||
4,10,1
|
||||
2,6,-1
|
||||
7,7,-1
|
||||
5,9,1
|
||||
9,10,1
|
||||
7,1,-1
|
||||
5,0,1
|
||||
9,5,1
|
||||
|
Reference in New Issue
Block a user