Minor changes
This commit is contained in:
@@ -38,7 +38,7 @@ public class Main {
|
|||||||
Layer layer = new Layer(neurons);
|
Layer layer = new Layer(neurons);
|
||||||
Network network = new Network(List.of(layer));
|
Network network = new Network(List.of(layer));
|
||||||
|
|
||||||
Trainer trainer = new AdalineTraining();
|
Trainer trainer = new GradientDescentTraining();
|
||||||
trainer.train(network, dataset);
|
trainer.train(network, dataset);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,25 +81,7 @@ public class TrainingPipeline {
|
|||||||
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
|
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
if(this.visualization) this.visualize(ctx);
|
||||||
ctx.model.forEachNeuron(n -> {
|
|
||||||
List<Float> weights = new ArrayList<>();
|
|
||||||
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
|
||||||
|
|
||||||
float b = weights.get(0);
|
|
||||||
float w1 = weights.get(1);
|
|
||||||
float w2 = weights.get(2);
|
|
||||||
|
|
||||||
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
|
||||||
});
|
|
||||||
int i = 0;
|
|
||||||
for(DataSetEntry entry : ctx.dataset){
|
|
||||||
List<Input> inputs = entry.getData();
|
|
||||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
|
||||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
this.visualizer.build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void executeSteps(TrainingContext ctx){
|
private void executeSteps(TrainingContext ctx){
|
||||||
@@ -123,4 +105,26 @@ public class TrainingPipeline {
|
|||||||
ctx.epoch += 1;
|
ctx.epoch += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void visualize(TrainingContext ctx){
|
||||||
|
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||||
|
ctx.model.forEachNeuron(n -> {
|
||||||
|
List<Float> weights = new ArrayList<>();
|
||||||
|
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
||||||
|
|
||||||
|
float b = weights.get(0);
|
||||||
|
float w1 = weights.get(1);
|
||||||
|
float w2 = weights.get(2);
|
||||||
|
|
||||||
|
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
||||||
|
});
|
||||||
|
int i = 0;
|
||||||
|
for(DataSetEntry entry : ctx.dataset){
|
||||||
|
List<Input> inputs = entry.getData();
|
||||||
|
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
||||||
|
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
this.visualizer.build();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
|||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -27,7 +28,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.0011F;
|
context.learningRate = 0.0005F;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
@@ -38,7 +39,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 500)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 50000)
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||||
gdCtx.globalLoss = 0.0F;
|
gdCtx.globalLoss = 0.0F;
|
||||||
@@ -49,8 +50,9 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
context.globalLoss /= context.dataset.size();
|
context.globalLoss /= context.dataset.size();
|
||||||
new GradientDescentCorrectionStrategy(context).apply();
|
new GradientDescentCorrectionStrategy(context).apply();
|
||||||
})
|
})
|
||||||
.withVerbose(true)
|
//.withVerbose(true)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
|
.withVisualization(false, new GraphVisualizer())
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user