From 1da32862f530dd3c1275465ec9dece85f9290972 Mon Sep 17 00:00:00 2001 From: Laurent <2-naaturel@users.noreply.gitlab.naaturel.be> Date: Mon, 23 Mar 2026 18:47:36 +0100 Subject: [PATCH] Just a regular commit --- src/main/java/com/naaturel/ANN/Main.java | 23 ++++--------------- .../ANN/domain/abstraction/Trainer.java | 7 ++---- .../training/SimpleTraining.java | 22 ++++++++++++++++-- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 16afb02..434e03a 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -1,6 +1,7 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; @@ -11,6 +12,7 @@ import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.activation.Heaviside; import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; import com.naaturel.ANN.implementation.neuron.SimplePerceptron; +import com.naaturel.ANN.implementation.training.SimpleTraining; import com.naaturel.ANN.implementation.training.steps.*; import java.util.*; @@ -67,25 +69,8 @@ public class Main { Layer layer = new Layer(List.of(neuron)); Network network = new Network(List.of(layer)); - TrainingContext context = new TrainingContext(); - context.dataset = orDataSet; - context.model = network; - context.learningRate = 0.3F; - - List steps = List.of( - new PredictionStep(), - new DeltaStep(), - new SimpleLossStep(), - new SimpleErrorDetectionStep(), - new WeightCorrectionStep(new SimpleCorrectionStrategy()) - ); - - TrainingPipeline pipeline = new TrainingPipeline(steps); - pipeline - .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100) - .beforeEpoch(ctx -> ctx.globalLoss = 0) - .withVerbose(true) - .run(context); + Trainer trainer = new SimpleTraining(); + trainer.train(network, orDataSet); } } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java index eec3555..a305cb1 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java @@ -1,10 +1,7 @@ package com.naaturel.ANN.domain.abstraction; -import com.naaturel.ANN.domain.model.training.TrainingContext; - -import java.util.List; +import com.naaturel.ANN.domain.model.dataset.DataSet; public interface Trainer { - - void train(TrainingContext context, List steps); + void train(Trainable model, DataSet dataset); } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 357c7a9..7f8f43f 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -1,9 +1,9 @@ package com.naaturel.ANN.implementation.training; +import com.naaturel.ANN.domain.abstraction.Trainable; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.neuron.Network; import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; @@ -18,8 +18,26 @@ public class SimpleTraining implements Trainer { } @Override - public void train(TrainingContext context, List steps) { + public void train(Trainable model, DataSet dataset) { + TrainingContext context = new TrainingContext(); + context.dataset = dataset; + context.model = model; + context.learningRate = 0.3F; + List steps = List.of( + new PredictionStep(), + new DeltaStep(), + new SimpleLossStep(), + new SimpleErrorDetectionStep(), + new WeightCorrectionStep(new SimpleCorrectionStrategy()) + ); + + TrainingPipeline pipeline = new TrainingPipeline(steps); + pipeline + .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100) + .beforeEpoch(ctx -> ctx.globalLoss = 0) + .withVerbose(true) + .run(context); } /*public void train(Neuron n, float learningRate, DataSet dataSet) {