Fix implementation

This commit is contained in:
Laurent
2026-03-25 16:11:09 +01:00
parent 2936bf33bf
commit a2452fb4b8
33 changed files with 318 additions and 154 deletions

View File

@@ -2,28 +2,26 @@ package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.activation.Heaviside;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import com.naaturel.ANN.implementation.training.SimpleTraining;
import com.naaturel.ANN.implementation.training.steps.*;
import javax.xml.crypto.Data;
import java.util.*;
public class Main {
public static void main(String[] args){
DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
DataSet orDataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -31,11 +29,11 @@ public class Main {
Bias bias = new Bias(new Weight(0));
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
Layer layer = new Layer(List.of(neuron));
Network network = new Network(List.of(layer));
Trainer trainer = new SimpleTraining();
Trainer trainer = new GradientDescentTraining();
trainer.train(network, dataset);
}

View File

@@ -1,9 +1,7 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public interface AlgorithmStrategy {
void apply(TrainingContext ctx);
void apply();
}

View File

@@ -6,9 +6,9 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
import java.util.function.Consumer;
public interface Trainable {
public interface Model {
int synCount();
void applyOnSynapses(Consumer<Synapse> consumer);
List<Float> predict(List<Input> inputs);
void applyOnSynapses(Consumer<Synapse> consumer);
}

View File

@@ -4,10 +4,9 @@ import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.ArrayList;
import java.util.List;
public abstract class Neuron implements Trainable {
public abstract class Neuron implements Model {
protected List<Synapse> synapses;
protected Bias bias;
@@ -35,4 +34,9 @@ public abstract class Neuron implements Trainable {
syn.setInput(inputs.get(i));
}
}
@Override
public int synCount() {
return this.synapses.size()+1; //take the bias in account
}
}

View File

@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.dataset.DataSet;
public interface Trainer {
void train(Trainable model, DataSet dataset);
void train(Model model, DataSet dataset);
}

View File

@@ -1,21 +1,21 @@
package com.naaturel.ANN.domain.model.training;
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
public class TrainingContext {
public Trainable model;
public abstract class TrainingContext {
public Model model;
public DataSet dataset;
public DataSetEntry currentEntry;
public Label currentLabel;
public Label currentLabel;
public float prediction;
public float delta;
public float localLoss;
public float globalLoss;
public float learningRate;
public float globalLoss;
public float localLoss;
public float learningRate;
public int epoch;
}

View File

@@ -1,9 +1,7 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public interface TrainingStep {
void run(TrainingContext ctx);
void run();
}

View File

@@ -1,13 +1,13 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
public class Layer implements Trainable {
public class Layer implements Model {
private final List<Neuron> neurons;
@@ -25,6 +25,15 @@ public class Layer implements Trainable {
return result;
}
@Override
public int synCount() {
int res = 0;
for (Neuron neuron : this.neurons) {
res += neuron.synCount();
}
return res;
}
@Override
public void applyOnSynapses(Consumer<Synapse> consumer) {
this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer));

View File

@@ -1,12 +1,12 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
public class Network implements Trainable {
public class Network implements Model {
private final List<Layer> layers;
@@ -24,6 +24,15 @@ public class Network implements Trainable {
return result;
}
@Override
public int synCount() {
int res = 0;
for(Layer layer : this.layers){
res += layer.synCount();
}
return res;
}
@Override
public void applyOnSynapses(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.applyOnSynapses(consumer));

View File

@@ -1,5 +1,6 @@
package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
@@ -55,6 +56,9 @@ public class TrainingPipeline {
this.beforeEpoch.accept(ctx);
this.executeSteps(ctx);
this.afterEpoch.accept(ctx);
if(this.verbose) {
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
}
} while (!this.stopCondition.test(ctx));
}
@@ -63,18 +67,16 @@ public class TrainingPipeline {
ctx.currentEntry = entry;
ctx.currentLabel = ctx.dataset.getLabel(entry);
for (TrainingStep step : steps) {
step.run(ctx);
step.run();
}
if(this.verbose) {
System.out.printf("Epoch : %d, ", ctx.epoch);
System.out.printf("predicted : %.2f, ", ctx.prediction);
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
System.out.printf("delta : %.2f\n", ctx.delta);
System.out.printf("delta : %.2f, ", ctx.delta);
System.out.printf("loss : %.5f\n", ctx.localLoss);
}
}
if(this.verbose) {
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
}
ctx.epoch += 1;
}
}

View File

@@ -1,25 +0,0 @@
package com.naaturel.ANN.implementation.correction;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import java.util.ArrayList;
import java.util.List;
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
List<Float> correctorTerms;
public GradientDescentCorrectionStrategy(int nbrCorrectors){
this.correctorTerms = new ArrayList<>();
for (int i = 0; i < nbrCorrectors; i++){
correctorTerms.add(0F);
}
}
@Override
public void apply(TrainingContext context) {
}
}

View File

@@ -0,0 +1,25 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import java.util.concurrent.atomic.AtomicInteger;
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {

    private final GradientDescentTrainingContext context;

    /**
     * @param context shared gradient-descent training context holding the
     *                model and the accumulated corrector terms (one per
     *                synapse, in {@code applyOnSynapses} visiting order).
     */
    public GradientDescentCorrectionStrategy(GradientDescentTrainingContext context) {
        this.context = context;
    }

    /**
     * Adds each pending corrector term to its synapse weight, then zeroes it.
     *
     * <p>Bug fix: the consumed term is reset to 0 after application. Without
     * the reset, this step (which runs once per dataset entry) re-applied the
     * whole accumulated sum on every subsequent entry, compounding the
     * correction instead of applying each increment exactly once.
     */
    @Override
    public void apply() {
        // AtomicInteger is only used as a mutable index captured by the
        // lambda; there is no concurrency here.
        AtomicInteger i = new AtomicInteger(0);
        context.model.applyOnSynapses(syn -> {
            int idx = i.getAndIncrement();
            float corrector = context.correctorTerms.get(idx);
            syn.setWeight(syn.getWeight() + corrector);
            // Reset so the error strategy re-accumulates from zero and the
            // same correction is never applied twice.
            context.correctorTerms.set(idx, 0F);
        });
    }
}

View File

@@ -0,0 +1,26 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import java.util.concurrent.atomic.AtomicInteger;
public class GradientDescentErrorStrategy implements AlgorithmStrategy {
private final GradientDescentTrainingContext context;
public GradientDescentErrorStrategy(GradientDescentTrainingContext context) {
this.context = context;
}
@Override
public void apply() {
AtomicInteger i = new AtomicInteger(0);
context.model.applyOnSynapses(syn -> {
float corrector = context.correctorTerms.get(i.get());
corrector += context.learningRate * context.delta * syn.getInput();
context.correctorTerms.set(i.get(), corrector);
i.incrementAndGet();
});
}
}

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.List;
public class GradientDescentTrainingContext extends TrainingContext {
public List<Float> correctorTerms;
}

View File

@@ -1,4 +1,4 @@
package com.naaturel.ANN.implementation.activation;
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;

View File

@@ -0,0 +1,19 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
public class SquareLossStrategy implements AlgorithmStrategy {
private final GradientDescentTrainingContext context;
public SquareLossStrategy(GradientDescentTrainingContext context) {
this.context = context;
}
@Override
public void apply() {
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
this.context.globalLoss += context.localLoss;
}
}

View File

@@ -1,11 +0,0 @@
package com.naaturel.ANN.implementation.loss;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class SimpleLossStrategy implements AlgorithmStrategy {
@Override
public void apply(TrainingContext ctx) {
ctx.localLoss = Math.abs(ctx.delta);
}
}

View File

@@ -1,11 +0,0 @@
package com.naaturel.ANN.implementation.loss;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class SquareLossStrategy implements AlgorithmStrategy {
@Override
public void apply(TrainingContext ctx) {
ctx.localLoss = (float)Math.pow(ctx.delta, 2) / 2;
}
}

View File

@@ -1,4 +1,4 @@
package com.naaturel.ANN.implementation.activation;
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;

View File

@@ -1,14 +1,18 @@
package com.naaturel.ANN.implementation.correction;
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
@Override
public void apply(TrainingContext context) {
if(context.currentLabel.getValue() == context.prediction) return ;
private final SimpleTrainingContext context;
public SimpleCorrectionStrategy(SimpleTrainingContext context) {
this.context = context;
}
@Override
public void apply() {
if(context.currentLabel.getValue() == context.prediction) return ;
context.model.applyOnSynapses(syn -> {
float currentW = syn.getWeight();
float currentInput = syn.getInput();

View File

@@ -0,0 +1,26 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
public class SimpleDeltaStrategy implements AlgorithmStrategy {
private final TrainingContext context;
public SimpleDeltaStrategy(TrainingContext context) {
this.context = context;
}
@Override
public void apply() {
DataSet dataSet = context.dataset;
DataSetEntry entry = context.currentEntry;
Label label = dataSet.getLabel(entry);
context.delta = label.getValue() - context.prediction;
}
}

View File

@@ -0,0 +1,17 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
private final SimpleTrainingContext context;
public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) {
this.context = context;
}
@Override
public void apply() {
context.globalLoss += context.localLoss;
}
}

View File

@@ -0,0 +1,17 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
public class SimpleLossStrategy implements AlgorithmStrategy {
private final SimpleTrainingContext context;
public SimpleLossStrategy(SimpleTrainingContext context) {
this.context = context;
}
@Override
public void apply() {
this.context.localLoss = Math.abs(this.context.delta);
}
}

View File

@@ -0,0 +1,21 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.List;
public class SimplePredictionStrategy implements AlgorithmStrategy {
private final TrainingContext context;
public SimplePredictionStrategy(TrainingContext context) {
this.context = context;
}
@Override
public void apply() {
List<Float> predictions = context.model.predict(context.currentEntry.getData());
context.prediction = predictions.getFirst();
}
}

View File

@@ -0,0 +1,6 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
public class SimpleTrainingContext extends TrainingContext {
}

View File

@@ -1,16 +1,19 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.correction.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.loss.SquareLossStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.ArrayList;
import java.util.List;
public class GradientDescentTraining implements Trainer {
@@ -20,25 +23,31 @@ public class GradientDescentTraining implements Trainer {
}
@Override
public void train(Trainable model, DataSet dataset) {
TrainingContext context = new TrainingContext();
public void train(Model model, DataSet dataset) {
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.3F;
context.learningRate = 0.00011F;
context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new LossStep(new SquareLossStrategy()),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(2))
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SquareLossStrategy(context)),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
.beforeEpoch(ctx -> ctx.globalLoss = 0)
.afterEpoch(ctx -> ())
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000)
.beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F;
for (int i = 0; i < model.synCount(); i++){
context.correctorTerms.add(0F);
}
})
.afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
.withVerbose(true)
.withTimeMeasurement(true)
.run(context);

View File

@@ -1,13 +1,11 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.loss.SimpleLossStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.List;
@@ -19,18 +17,18 @@ public class SimpleTraining implements Trainer {
}
@Override
public void train(Trainable model, DataSet dataset) {
TrainingContext context = new TrainingContext();
public void train(Model model, DataSet dataset) {
SimpleTrainingContext context = new SimpleTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.3F;
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new LossStep(new SimpleLossStrategy()),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new SimpleCorrectionStrategy())
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SimpleLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
);
TrainingPipeline pipeline = new TrainingPipeline(steps);

View File

@@ -1,19 +1,22 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class DeltaStep implements TrainingStep {
@Override
public void run(TrainingContext ctx) {
DataSet dataSet = ctx.dataset;
DataSetEntry entry = ctx.currentEntry;
Label label = dataSet.getLabel(entry);
private final AlgorithmStrategy strategy;
ctx.delta = label.getValue() - ctx.prediction;
public DeltaStep(AlgorithmStrategy strategy) {
this.strategy = strategy;
}
@Override
public void run() {
this.strategy.apply();
}
}

View File

@@ -0,0 +1,18 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class ErrorRegistrationStep implements TrainingStep {
private final AlgorithmStrategy strategy;
public ErrorRegistrationStep(AlgorithmStrategy strategy) {
this.strategy = strategy;
}
@Override
public void run() {
this.strategy.apply();
}
}

View File

@@ -1,11 +1,12 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class LossStep implements TrainingStep {
private final AlgorithmStrategy lossStrategy;
public LossStep(AlgorithmStrategy strategy) {
@@ -13,7 +14,7 @@ public class LossStep implements TrainingStep {
}
@Override
public void run(TrainingContext ctx) {
this.lossStrategy.apply(ctx);
public void run() {
this.lossStrategy.apply();
}
}

View File

@@ -1,17 +1,23 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
import java.util.ArrayList;
import java.util.List;
public class PredictionStep implements TrainingStep {
private final SimplePredictionStrategy strategy;
public PredictionStep(SimplePredictionStrategy strategy) {
this.strategy = strategy;
}
@Override
public void run(TrainingContext ctx) {
List<Float> predictions = ctx.model.predict(ctx.currentEntry.getData());
ctx.prediction = predictions.getFirst();
public void run() {
this.strategy.apply();
}
}

View File

@@ -1,13 +0,0 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class SimpleErrorDetectionStep implements TrainingStep {
@Override
public void run(TrainingContext ctx) {
ctx.globalLoss += ctx.localLoss;
}
}

View File

@@ -2,7 +2,6 @@ package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class WeightCorrectionStep implements TrainingStep {
@@ -13,7 +12,7 @@ public class WeightCorrectionStep implements TrainingStep {
}
@Override
public void run(TrainingContext ctx) {
this.correctionStrategy.apply(ctx);
public void run() {
this.correctionStrategy.apply();
}
}