Switch Main to GradientDescentTraining, move graph rendering into a dedicated visualize() step, and tune the learning rate and stop condition

This commit is contained in:
2026-03-27 12:40:00 +01:00
parent 572e5c7484
commit 7fb4a7c057
3 changed files with 29 additions and 23 deletions

View File

@@ -38,7 +38,7 @@ public class Main {
Layer layer = new Layer(neurons);
Network network = new Network(List.of(layer));
Trainer trainer = new AdalineTraining();
Trainer trainer = new GradientDescentTraining();
trainer.train(network, dataset);
}

View File

@@ -81,25 +81,7 @@ public class TrainingPipeline {
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
}
AtomicInteger neuronIndex = new AtomicInteger(0);
ctx.model.forEachNeuron(n -> {
List<Float> weights = new ArrayList<>();
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
float b = weights.get(0);
float w1 = weights.get(1);
float w2 = weights.get(2);
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
});
int i = 0;
for(DataSetEntry entry : ctx.dataset){
List<Input> inputs = entry.getData();
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
i++;
}
this.visualizer.build();
if(this.visualization) this.visualize(ctx);
}
private void executeSteps(TrainingContext ctx){
@@ -123,4 +105,26 @@ public class TrainingPipeline {
ctx.epoch += 1;
}
/**
 * Pushes the current model state and the training dataset into the visualizer
 * and triggers a render.
 *
 * <p>For every neuron, the synapse weights are collected in order; the first
 * weight is treated as the bias term and the next two as the 2D input weights
 * (assumes exactly three synapses per neuron — TODO confirm for other topologies).
 * Each neuron contributes one decision-boundary equation plotted over [-3, 3].
 *
 * @param ctx training context supplying the model and the dataset to plot
 */
private void visualize(TrainingContext ctx){
    AtomicInteger boundaryCounter = new AtomicInteger(0);
    ctx.model.forEachNeuron(neuron -> {
        List<Float> synapseWeights = new ArrayList<>();
        neuron.forEachSynapse(synapse -> synapseWeights.add(synapse.getWeight()));
        float bias = synapseWeights.get(0);
        float weightX = synapseWeights.get(1);
        float weightY = synapseWeights.get(2);
        this.visualizer.addEquation("boundary_" + boundaryCounter.getAndIncrement(), weightX, weightY, bias, -3, 3);
    });
    int pointIndex = 0;
    for(DataSetEntry sample : ctx.dataset){
        List<Input> coordinates = sample.getData();
        this.visualizer.addPoint("p"+pointIndex, coordinates.get(0).getValue(), coordinates.get(1).getValue());
        // NOTE(review): second addPoint reuses id "p"+i with a +0.01 offset — looks like a
        // debug/overlap artifact; confirm whether the visualizer intends duplicate ids here.
        this.visualizer.addPoint("p"+pointIndex, coordinates.get(0).getValue()+0.01F, coordinates.get(1).getValue()+0.01F);
        pointIndex++;
    }
    this.visualizer.build();
}
}

View File

@@ -12,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.util.ArrayList;
import java.util.List;
@@ -27,7 +28,7 @@ public class GradientDescentTraining implements Trainer {
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.0011F;
context.learningRate = 0.0005F;
context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of(
@@ -38,7 +39,7 @@ public class GradientDescentTraining implements Trainer {
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 500)
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 50000)
.beforeEpoch(ctx -> {
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
gdCtx.globalLoss = 0.0F;
@@ -49,8 +50,9 @@ public class GradientDescentTraining implements Trainer {
context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).apply();
})
.withVerbose(true)
//.withVerbose(true)
.withTimeMeasurement(true)
.withVisualization(false, new GraphVisualizer())
.run(context);
}