diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 966a1e9..b551c4c 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -24,21 +24,23 @@ public class Main { int nbrClass = 1; DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass); - int[] neuronPerLayer = new int[]{50, 50, 50, dataset.getNbrLabels()}; + int[] neuronPerLayer = new int[]{1800, 2, 1800, dataset.getNbrLabels()}; int nbrInput = dataset.getNbrInputs(); FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput); + System.out.println(network.synCount()); + Trainer trainer = new GradientBackpropagationTraining(); trainer.train(0.01F, 2000, network, dataset); //plotGraph(dataset, network); - } private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ + int neuronId = 0; List layers = new ArrayList<>(); for (int i = 0; i < neuronPerLayer.length; i++){ @@ -54,8 +56,9 @@ public class Main { Bias bias = new Bias(new Weight()); - Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH()); + Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH()); neurons.add(n); + neuronId++; } Layer layer = new Layer(neurons.toArray(new Neuron[0])); layers.add(layer); @@ -77,7 +80,7 @@ public class Main { float min = -5F; float max = 5F; - float step = 0.01F; + float step = 0.03F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ List predictions = network.predict(List.of(new Input(x), new Input(y))); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java index beeab05..35f65af 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java @@ -12,11 +12,18 @@ public abstract class TrainingContext { public List expectations; public List predictions; - public List deltas; + public float[] deltas; public float globalLoss; public float localLoss; public float learningRate; public int epoch; + + public TrainingContext(Model model, DataSet dataset) { + this.model = model; + this.dataset = dataset; + this.deltas = new float[model.neuronCount()]; + } + } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java index 587febc..c6ee3f8 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -7,18 +7,25 @@ import java.util.function.Consumer; public class Neuron implements Model { - protected Synapse[] synapses; + protected int id; + protected final Synapse[] synapses; protected Bias bias; protected ActivationFunction activationFunction; protected Float output; protected Float weightedSum; + protected final float[] weights; + protected final float[] inputs; - public Neuron(Synapse[] synapses, Bias bias, ActivationFunction func){ + public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){ + this.id = id; this.synapses = synapses; this.bias = bias; this.activationFunction = func; this.output = null; this.weightedSum = null; + + weights = new float[synapses.length]; + inputs = new float[synapses.length]; } public void updateBias(Weight weight) { @@ -53,12 +60,18 @@ public class Neuron implements Model { } public float calculateWeightedSum() { - this.weightedSum = 0F; - this.weightedSum += this.bias.getWeight() * this.bias.getInput(); - for(Synapse syn : this.synapses){ - this.weightedSum += syn.getWeight() * syn.getInput(); + float sum = bias.getWeight() * bias.getInput(); + + for (int i = 0; i < weights.length; i++) { + sum += weights[i] * inputs[i]; } - return this.weightedSum; + + this.weightedSum = sum; + return sum; + } + + public int getId(){ + return this.id; } @Override diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index 18921d1..c8b9b81 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -103,7 +103,7 @@ public class TrainingPipeline { System.out.printf("Epoch : %d, ", ctx.epoch); System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray())); System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray())); - System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas.toArray())); + System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas)); System.out.printf("loss : %.5f\n", ctx.localLoss); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/adaline/AdalineTrainingContext.java b/src/main/java/com/naaturel/ANN/implementation/adaline/AdalineTrainingContext.java index 3b0b623..4c82776 100644 --- a/src/main/java/com/naaturel/ANN/implementation/adaline/AdalineTrainingContext.java +++ b/src/main/java/com/naaturel/ANN/implementation/adaline/AdalineTrainingContext.java @@ -1,6 +1,11 @@ package com.naaturel.ANN.implementation.adaline; +import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.infrastructure.dataset.DataSet; public class AdalineTrainingContext extends TrainingContext { + public AdalineTrainingContext(Model model, DataSet dataset) { + super(model, dataset); + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java index 101bbef..5b36a4d 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java @@ -20,7 +20,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStep { AtomicInteger synIndex = new AtomicInteger(0); context.model.forEachNeuron(neuron -> { - float correspondingDelta = context.deltas.get(neuronIndex.get()); + float correspondingDelta = context.deltas[neuronIndex.get()]; neuron.forEachSynapse(syn -> { float corrector = context.correctorTerms.get(synIndex.get()); diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentTrainingContext.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentTrainingContext.java index 0b1ec5f..8f91189 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentTrainingContext.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentTrainingContext.java @@ -1,6 +1,8 @@ package com.naaturel.ANN.implementation.gradientDescent; +import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.infrastructure.dataset.DataSet; import java.util.List; @@ -8,4 +10,7 @@ public class GradientDescentTrainingContext extends TrainingContext { public List correctorTerms; + public GradientDescentTrainingContext(Model model, DataSet dataset) { + super(model, dataset); + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java index f928f00..76d70f5 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStep.java @@ -15,9 +15,11 @@ public class SquareLossStep implements AlgorithmStep { @Override public void run() { - Stream deltaStream = this.context.deltas.stream(); - this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2))); - this.context.localLoss /= 2; - this.context.globalLoss += this.context.localLoss; //broke MSE en gradientDescentTraining + float loss = 0f; + for (float d : this.context.deltas) { + loss += d * d; + } + this.context.localLoss = loss / 2f; + this.context.globalLoss += this.context.localLoss; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java index 980aaa3..d8034d8 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/BackpropagationCorrectionStep.java @@ -2,41 +2,55 @@ package com.naaturel.ANN.implementation.multiLayers; import com.naaturel.ANN.domain.abstraction.AlgorithmStep; -import java.util.concurrent.atomic.AtomicInteger; - public class BackpropagationCorrectionStep implements AlgorithmStep { - private GradientBackpropagationContext context; + private final GradientBackpropagationContext context; + private final int synCount; + private final float[] inputs; + private final float[] signals; public BackpropagationCorrectionStep(GradientBackpropagationContext context){ this.context = context; + this.synCount = context.correctionBuffer.length; + this.inputs = new float[synCount]; + this.signals = new float[synCount]; } @Override public void run() { - - AtomicInteger synIndex = new AtomicInteger(0); - this.context.model.forEachNeuron(n -> { - float signal = context.errorSignals.get(n); + int[] synIndex = {0}; + context.model.forEachNeuron(n -> { + float signal = context.errorSignals[n.getId()]; n.forEachSynapse(syn -> { - float lr = context.learningRate; - float corrector = lr * signal * syn.getInput(); - float existingCorrector = context.correctionBuffer[synIndex.get()]; - float newCorrector = existingCorrector + corrector; - - if(context.currentSample >= context.batchSize){ - float newWeight = syn.getWeight() + newCorrector; - syn.setWeight(newWeight); - context.correctionBuffer[synIndex.get()] = 0; - } else { - context.correctionBuffer[synIndex.get()] = newCorrector; - } - synIndex.incrementAndGet(); + inputs[synIndex[0]] = syn.getInput(); + signals[synIndex[0]] = signal; + synIndex[0]++; }); }); - if(context.currentSample >= context.batchSize) { + + float lr = context.learningRate; + boolean applyUpdate = context.currentSample >= context.batchSize; + + for (int i = 0; i < synCount; i++) { + context.correctionBuffer[i] += lr * signals[i] * inputs[i]; + } + + if (applyUpdate) { + syncWeights(); context.currentSample = 0; } - context.currentSample += 1; + + context.currentSample++; } -} + + private void syncWeights() { + int[] i = {0}; + context.model.forEachNeuron(n -> { + n.forEachSynapse(syn -> { + syn.setWeight(syn.getWeight() + context.correctionBuffer[i[0]]); + context.correctionBuffer[i[0]] = 0f; + i[0]++; + }); + }); + } +} \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java index cbb0880..d36bf2b 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/ErrorSignalStep.java @@ -1,38 +1,29 @@ package com.naaturel.ANN.implementation.multiLayers; import com.naaturel.ANN.domain.abstraction.AlgorithmStep; -import com.naaturel.ANN.domain.model.neuron.Neuron; - -import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; public class ErrorSignalStep implements AlgorithmStep { - private GradientBackpropagationContext context; + private final GradientBackpropagationContext context; + public ErrorSignalStep(GradientBackpropagationContext context) { this.context = context; } @Override public void run() { - this.context.model.forEachNeuron(n -> { - calculateErrorSignalRecursive(n, this.context.errorSignals); + + context.model.forEachNeuron(n -> { + if (context.errorSignalsComputed[n.getId()]) return; + + int neuronIndex = context.model.indexInLayerOf(n); + float[] signalSum = {0f}; + context.model.forEachNeuronConnectedTo(n, connected -> { + signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex); + }); + + context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0]; + context.errorSignalsComputed[n.getId()] = true; }); } - - private float calculateErrorSignalRecursive(Neuron n, Map signals) { - if (signals.containsKey(n)) return signals.get(n); - - int neuronIndex = this.context.model.indexInLayerOf(n); - AtomicReference signalSum = new AtomicReference<>(0F); - this.context.model.forEachNeuronConnectedTo(n, connected -> { - float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex); - signalSum.set(signalSum.get() + weightedSignal); - }); - - float derivative = n.getActivationFunction().derivative(n.getOutput()); - float finalSignal = derivative * signalSum.get(); - signals.put(n, finalSignal); - return finalSignal; - } -} +} \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java index 11d1cd2..71dcda6 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/GradientBackpropagationContext.java @@ -10,20 +10,21 @@ import java.util.Map; public class GradientBackpropagationContext extends TrainingContext { - public final Map errorSignals; + public final float[] errorSignals; public final float[] correctionBuffer; + public final boolean[] errorSignalsComputed; public int currentSample; public int batchSize; public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){ - this.model = model; - this.dataset = dataSet; + super(model, dataSet); this.learningRate = learningRate; this.batchSize = batchSize; - this.errorSignals = new HashMap<>(); + this.errorSignals = new float[model.neuronCount()]; this.correctionBuffer = new float[model.synCount()]; + this.errorSignalsComputed = new boolean[model.neuronCount()]; this.currentSample = 1; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java index c90f765..a51f0aa 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/OutputLayerErrorStep.java @@ -3,37 +3,40 @@ package com.naaturel.ANN.implementation.multiLayers; import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; -import java.util.ArrayList; -import java.util.HashMap; +import java.util.Arrays; import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; public class OutputLayerErrorStep implements AlgorithmStep { private final GradientBackpropagationContext context; + private final float[] expectations; public OutputLayerErrorStep(GradientBackpropagationContext context){ this.context = context; + this.expectations = new float[context.model.neuronCount()]; } @Override public void run() { - context.deltas = new ArrayList<>(); - DataSetEntry entry = this.context.currentEntry; - List expectations = this.context.dataset.getLabelsAsFloat(entry); - AtomicInteger index = new AtomicInteger(0); + Arrays.fill(context.errorSignals, 0f); + Arrays.fill(context.errorSignalsComputed, false); - context.errorSignals.clear(); - this.context.model.forEachOutputNeurons(n -> { - float expected = expectations.get(index.get()); + DataSetEntry entry = context.currentEntry; + List labels = context.dataset.getLabelsAsFloat(entry); + for (int i = 0; i < labels.size(); i++) { + expectations[i] = labels.get(i); + } + + int[] index = {0}; + context.model.forEachOutputNeurons(n -> { + float expected = expectations[index[0]]; float predicted = n.getOutput(); - float output = n.getOutput(); float delta = expected - predicted; - float signal = delta * n.getActivationFunction().derivative(output); - this.context.deltas.add(delta); - this.context.errorSignals.put(n, signal); - index.incrementAndGet(); + context.deltas[index[0]] = delta; + context.errorSignals[n.getId()] = delta * n.getActivationFunction().derivative(predicted); + context.errorSignalsComputed[n.getId()] = true; + index[0]++; }); } -} +} \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java index 81e8f3c..a726318 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStep.java @@ -21,7 +21,7 @@ public class SimpleCorrectionStep implements AlgorithmStep { AtomicInteger synIndex = new AtomicInteger(0); context.model.forEachNeuron(neuron -> { - float correspondingDelta = context.deltas.get(neuronIndex.get()); + float correspondingDelta = context.deltas[neuronIndex.get()]; neuron.forEachSynapse(syn -> { float currentW = syn.getWeight(); float currentInput = syn.getInput(); diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java index 10b0300..5b5f88c 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java @@ -24,10 +24,9 @@ public class SimpleDeltaStep implements AlgorithmStep { List predicted = context.predictions; List expected = dataSet.getLabelsAsFloat(entry); - //context.delta = label.getValue() - context.predictions; - context.deltas = IntStream.range(0, predicted.size()) - .mapToObj(i -> expected.get(i) - predicted.get(i)) - .collect(Collectors.toList()); + for (int i = 0; i < predicted.size(); i++) { + context.deltas[i] = expected.get(i) - predicted.get(i); + } } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java index 562eec6..e3fc7b0 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java @@ -12,6 +12,10 @@ public class SimpleLossStrategy implements AlgorithmStep { @Override public void run() { - this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum); + float loss = 0f; + for (float d : context.deltas) { + loss += d; + } + context.localLoss = loss; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleTrainingContext.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleTrainingContext.java index b804f21..09bf385 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleTrainingContext.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleTrainingContext.java @@ -1,6 +1,11 @@ package com.naaturel.ANN.implementation.simplePerceptron; +import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.infrastructure.dataset.DataSet; public class SimpleTrainingContext extends TrainingContext { + public SimpleTrainingContext(Model model, DataSet dataset) { + super(model, dataset); + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 6a4805e..25ca682 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -24,9 +24,7 @@ public class AdalineTraining implements Trainer { @Override public void train(float learningRate, int epoch, Model model, DataSet dataset) { - AdalineTrainingContext context = new AdalineTrainingContext(); - context.dataset = dataset; - context.model = model; + AdalineTrainingContext context = new AdalineTrainingContext(model, dataset); context.learningRate = learningRate; List steps = List.of( diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index 84e45cb..c40915f 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -17,7 +17,7 @@ public class GradientBackpropagationTraining implements Trainer { @Override public void train(float learningRate, int epoch, Model model, DataSet dataset) { GradientBackpropagationContext context = - new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()); + new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()/3); List steps = List.of( new SimplePredictionStep(context), @@ -28,14 +28,14 @@ public class GradientBackpropagationTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.001F || ctx.epoch > epoch) + .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch) .beforeEpoch(ctx -> { ctx.globalLoss = 0.0F; }) .afterEpoch(ctx -> { ctx.globalLoss /= dataset.size(); }) - .withVerbose(true,epoch/10) + .withVerbose(false,epoch/10) .withTimeMeasurement(true) .run(context); } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 993fd93..5296936 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -24,9 +24,7 @@ public class GradientDescentTraining implements Trainer { @Override public void train(float learningRate, int epoch, Model model, DataSet dataset) { - GradientDescentTrainingContext context = new GradientDescentTrainingContext(); - context.dataset = dataset; - context.model = model; + GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset); context.learningRate = learningRate; context.correctorTerms = new ArrayList<>(); diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 78fd2c7..0cdf100 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -17,7 +17,7 @@ public class SimpleTraining implements Trainer { @Override public void train(float learningRate, int epoch, Model model, DataSet dataset) { - SimpleTrainingContext context = new SimpleTrainingContext(); + SimpleTrainingContext context = new SimpleTrainingContext(model, dataset); context.dataset = dataset; context.model = model; context.learningRate = learningRate;