Minor changes
This commit is contained in:
@@ -38,7 +38,7 @@ public class Main {
|
||||
Layer layer = new Layer(neurons);
|
||||
Network network = new Network(List.of(layer));
|
||||
|
||||
Trainer trainer = new AdalineTraining();
|
||||
Trainer trainer = new GradientDescentTraining();
|
||||
trainer.train(network, dataset);
|
||||
|
||||
}
|
||||
|
||||
@@ -81,25 +81,7 @@ public class TrainingPipeline {
|
||||
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
|
||||
}
|
||||
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
ctx.model.forEachNeuron(n -> {
|
||||
List<Float> weights = new ArrayList<>();
|
||||
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
||||
|
||||
float b = weights.get(0);
|
||||
float w1 = weights.get(1);
|
||||
float w2 = weights.get(2);
|
||||
|
||||
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
||||
});
|
||||
int i = 0;
|
||||
for(DataSetEntry entry : ctx.dataset){
|
||||
List<Input> inputs = entry.getData();
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
||||
i++;
|
||||
}
|
||||
this.visualizer.build();
|
||||
if(this.visualization) this.visualize(ctx);
|
||||
}
|
||||
|
||||
private void executeSteps(TrainingContext ctx){
|
||||
@@ -123,4 +105,26 @@ public class TrainingPipeline {
|
||||
ctx.epoch += 1;
|
||||
}
|
||||
|
||||
private void visualize(TrainingContext ctx){
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
ctx.model.forEachNeuron(n -> {
|
||||
List<Float> weights = new ArrayList<>();
|
||||
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
||||
|
||||
float b = weights.get(0);
|
||||
float w1 = weights.get(1);
|
||||
float w2 = weights.get(2);
|
||||
|
||||
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
||||
});
|
||||
int i = 0;
|
||||
for(DataSetEntry entry : ctx.dataset){
|
||||
List<Input> inputs = entry.getData();
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
||||
i++;
|
||||
}
|
||||
this.visualizer.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@@ -27,7 +28,7 @@ public class GradientDescentTraining implements Trainer {
|
||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.0011F;
|
||||
context.learningRate = 0.0005F;
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
@@ -38,7 +39,7 @@ public class GradientDescentTraining implements Trainer {
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 500)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 50000)
|
||||
.beforeEpoch(ctx -> {
|
||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||
gdCtx.globalLoss = 0.0F;
|
||||
@@ -49,8 +50,9 @@ public class GradientDescentTraining implements Trainer {
|
||||
context.globalLoss /= context.dataset.size();
|
||||
new GradientDescentCorrectionStrategy(context).apply();
|
||||
})
|
||||
.withVerbose(true)
|
||||
//.withVerbose(true)
|
||||
.withTimeMeasurement(true)
|
||||
.withVisualization(false, new GraphVisualizer())
|
||||
.run(context);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user