Just a regular commit

This commit is contained in:
Laurent
2026-03-23 18:47:36 +01:00
parent fbf2a571ef
commit 1da32862f5
3 changed files with 26 additions and 26 deletions

View File

@@ -1,6 +1,7 @@
package com.naaturel.ANN; package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
@@ -11,6 +12,7 @@ import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.activation.Heaviside; import com.naaturel.ANN.implementation.activation.Heaviside;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.SimpleTraining;
import com.naaturel.ANN.implementation.training.steps.*; import com.naaturel.ANN.implementation.training.steps.*;
import java.util.*; import java.util.*;
@@ -67,25 +69,8 @@ public class Main {
Layer layer = new Layer(List.of(neuron)); Layer layer = new Layer(List.of(neuron));
Network network = new Network(List.of(layer)); Network network = new Network(List.of(layer));
TrainingContext context = new TrainingContext(); Trainer trainer = new SimpleTraining();
context.dataset = orDataSet; trainer.train(network, orDataSet);
context.model = network;
context.learningRate = 0.3F;
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new SimpleLossStep(),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new SimpleCorrectionStrategy())
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
.beforeEpoch(ctx -> ctx.globalLoss = 0)
.withVerbose(true)
.run(context);
} }
} }

View File

@@ -1,10 +1,7 @@
package com.naaturel.ANN.domain.abstraction; package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.dataset.DataSet;
import java.util.List;
public interface Trainer { public interface Trainer {
void train(Trainable model, DataSet dataset);
void train(TrainingContext context, List<TrainingStep> steps);
} }

View File

@@ -1,9 +1,9 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.neuron.Network;
import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
@@ -18,8 +18,26 @@ public class SimpleTraining implements Trainer {
} }
@Override @Override
public void train(TrainingContext context, List<TrainingStep> steps) { public void train(Trainable model, DataSet dataset) {
TrainingContext context = new TrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.3F;
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new SimpleLossStep(),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new SimpleCorrectionStrategy())
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
.beforeEpoch(ctx -> ctx.globalLoss = 0)
.withVerbose(true)
.run(context);
} }
/*public void train(Neuron n, float learningRate, DataSet dataSet) { /*public void train(Neuron n, float learningRate, DataSet dataSet) {