Optimize some stuff
This commit is contained in:
@@ -24,21 +24,23 @@ public class Main {
|
|||||||
int nbrClass = 1;
|
int nbrClass = 1;
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass);
|
||||||
|
|
||||||
int[] neuronPerLayer = new int[]{50, 50, 50, dataset.getNbrLabels()};
|
int[] neuronPerLayer = new int[]{1800, 2, 1800, dataset.getNbrLabels()};
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
||||||
|
|
||||||
|
System.out.println(network.synCount());
|
||||||
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(0.01F, 2000, network, dataset);
|
trainer.train(0.01F, 2000, network, dataset);
|
||||||
|
|
||||||
//plotGraph(dataset, network);
|
//plotGraph(dataset, network);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||||
|
int neuronId = 0;
|
||||||
List<Layer> layers = new ArrayList<>();
|
List<Layer> layers = new ArrayList<>();
|
||||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||||
|
|
||||||
@@ -54,8 +56,9 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight());
|
Bias bias = new Bias(new Weight());
|
||||||
|
|
||||||
Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH());
|
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
||||||
neurons.add(n);
|
neurons.add(n);
|
||||||
|
neuronId++;
|
||||||
}
|
}
|
||||||
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
||||||
layers.add(layer);
|
layers.add(layer);
|
||||||
@@ -77,7 +80,7 @@ public class Main {
|
|||||||
|
|
||||||
float min = -5F;
|
float min = -5F;
|
||||||
float max = 5F;
|
float max = 5F;
|
||||||
float step = 0.01F;
|
float step = 0.03F;
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
|
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
|
||||||
|
|||||||
@@ -12,11 +12,18 @@ public abstract class TrainingContext {
|
|||||||
|
|
||||||
public List<Float> expectations;
|
public List<Float> expectations;
|
||||||
public List<Float> predictions;
|
public List<Float> predictions;
|
||||||
public List<Float> deltas;
|
public float[] deltas;
|
||||||
|
|
||||||
public float globalLoss;
|
public float globalLoss;
|
||||||
public float localLoss;
|
public float localLoss;
|
||||||
|
|
||||||
public float learningRate;
|
public float learningRate;
|
||||||
public int epoch;
|
public int epoch;
|
||||||
|
|
||||||
|
public TrainingContext(Model model, DataSet dataset) {
|
||||||
|
this.model = model;
|
||||||
|
this.dataset = dataset;
|
||||||
|
this.deltas = new float[model.neuronCount()];
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,18 +7,25 @@ import java.util.function.Consumer;
|
|||||||
|
|
||||||
public class Neuron implements Model {
|
public class Neuron implements Model {
|
||||||
|
|
||||||
protected Synapse[] synapses;
|
protected int id;
|
||||||
|
protected final Synapse[] synapses;
|
||||||
protected Bias bias;
|
protected Bias bias;
|
||||||
protected ActivationFunction activationFunction;
|
protected ActivationFunction activationFunction;
|
||||||
protected Float output;
|
protected Float output;
|
||||||
protected Float weightedSum;
|
protected Float weightedSum;
|
||||||
|
protected final float[] weights;
|
||||||
|
protected final float[] inputs;
|
||||||
|
|
||||||
public Neuron(Synapse[] synapses, Bias bias, ActivationFunction func){
|
public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
|
||||||
|
this.id = id;
|
||||||
this.synapses = synapses;
|
this.synapses = synapses;
|
||||||
this.bias = bias;
|
this.bias = bias;
|
||||||
this.activationFunction = func;
|
this.activationFunction = func;
|
||||||
this.output = null;
|
this.output = null;
|
||||||
this.weightedSum = null;
|
this.weightedSum = null;
|
||||||
|
|
||||||
|
weights = new float[synapses.length];
|
||||||
|
inputs = new float[synapses.length];
|
||||||
}
|
}
|
||||||
|
|
||||||
public void updateBias(Weight weight) {
|
public void updateBias(Weight weight) {
|
||||||
@@ -53,12 +60,18 @@ public class Neuron implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public float calculateWeightedSum() {
|
public float calculateWeightedSum() {
|
||||||
this.weightedSum = 0F;
|
float sum = bias.getWeight() * bias.getInput();
|
||||||
this.weightedSum += this.bias.getWeight() * this.bias.getInput();
|
|
||||||
for(Synapse syn : this.synapses){
|
for (int i = 0; i < weights.length; i++) {
|
||||||
this.weightedSum += syn.getWeight() * syn.getInput();
|
sum += weights[i] * inputs[i];
|
||||||
}
|
}
|
||||||
return this.weightedSum;
|
|
||||||
|
this.weightedSum = sum;
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getId(){
|
||||||
|
return this.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ public class TrainingPipeline {
|
|||||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||||
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
|
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
|
||||||
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
||||||
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas.toArray()));
|
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
|
||||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package com.naaturel.ANN.implementation.adaline;
|
package com.naaturel.ANN.implementation.adaline;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
public class AdalineTrainingContext extends TrainingContext {
|
public class AdalineTrainingContext extends TrainingContext {
|
||||||
|
public AdalineTrainingContext(Model model, DataSet dataset) {
|
||||||
|
super(model, dataset);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStep {
|
|||||||
AtomicInteger synIndex = new AtomicInteger(0);
|
AtomicInteger synIndex = new AtomicInteger(0);
|
||||||
|
|
||||||
context.model.forEachNeuron(neuron -> {
|
context.model.forEachNeuron(neuron -> {
|
||||||
float correspondingDelta = context.deltas.get(neuronIndex.get());
|
float correspondingDelta = context.deltas[neuronIndex.get()];
|
||||||
|
|
||||||
neuron.forEachSynapse(syn -> {
|
neuron.forEachSynapse(syn -> {
|
||||||
float corrector = context.correctorTerms.get(synIndex.get());
|
float corrector = context.correctorTerms.get(synIndex.get());
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package com.naaturel.ANN.implementation.gradientDescent;
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -8,4 +10,7 @@ public class GradientDescentTrainingContext extends TrainingContext {
|
|||||||
|
|
||||||
public List<Float> correctorTerms;
|
public List<Float> correctorTerms;
|
||||||
|
|
||||||
|
public GradientDescentTrainingContext(Model model, DataSet dataset) {
|
||||||
|
super(model, dataset);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,9 +15,11 @@ public class SquareLossStep implements AlgorithmStep {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
Stream<Float> deltaStream = this.context.deltas.stream();
|
float loss = 0f;
|
||||||
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
|
for (float d : this.context.deltas) {
|
||||||
this.context.localLoss /= 2;
|
loss += d * d;
|
||||||
this.context.globalLoss += this.context.localLoss; //broke MSE en gradientDescentTraining
|
}
|
||||||
|
this.context.localLoss = loss / 2f;
|
||||||
|
this.context.globalLoss += this.context.localLoss;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,41 +2,55 @@ package com.naaturel.ANN.implementation.multiLayers;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
public class BackpropagationCorrectionStep implements AlgorithmStep {
|
public class BackpropagationCorrectionStep implements AlgorithmStep {
|
||||||
|
|
||||||
private GradientBackpropagationContext context;
|
private final GradientBackpropagationContext context;
|
||||||
|
private final int synCount;
|
||||||
|
private final float[] inputs;
|
||||||
|
private final float[] signals;
|
||||||
|
|
||||||
public BackpropagationCorrectionStep(GradientBackpropagationContext context){
|
public BackpropagationCorrectionStep(GradientBackpropagationContext context){
|
||||||
this.context = context;
|
this.context = context;
|
||||||
|
this.synCount = context.correctionBuffer.length;
|
||||||
|
this.inputs = new float[synCount];
|
||||||
|
this.signals = new float[synCount];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
int[] synIndex = {0};
|
||||||
AtomicInteger synIndex = new AtomicInteger(0);
|
context.model.forEachNeuron(n -> {
|
||||||
this.context.model.forEachNeuron(n -> {
|
float signal = context.errorSignals[n.getId()];
|
||||||
float signal = context.errorSignals.get(n);
|
|
||||||
n.forEachSynapse(syn -> {
|
n.forEachSynapse(syn -> {
|
||||||
float lr = context.learningRate;
|
inputs[synIndex[0]] = syn.getInput();
|
||||||
float corrector = lr * signal * syn.getInput();
|
signals[synIndex[0]] = signal;
|
||||||
float existingCorrector = context.correctionBuffer[synIndex.get()];
|
synIndex[0]++;
|
||||||
float newCorrector = existingCorrector + corrector;
|
});
|
||||||
|
});
|
||||||
|
|
||||||
if(context.currentSample >= context.batchSize){
|
float lr = context.learningRate;
|
||||||
float newWeight = syn.getWeight() + newCorrector;
|
boolean applyUpdate = context.currentSample >= context.batchSize;
|
||||||
syn.setWeight(newWeight);
|
|
||||||
context.correctionBuffer[synIndex.get()] = 0;
|
for (int i = 0; i < synCount; i++) {
|
||||||
} else {
|
context.correctionBuffer[i] += lr * signals[i] * inputs[i];
|
||||||
context.correctionBuffer[synIndex.get()] = newCorrector;
|
|
||||||
}
|
}
|
||||||
synIndex.incrementAndGet();
|
|
||||||
});
|
if (applyUpdate) {
|
||||||
});
|
syncWeights();
|
||||||
if(context.currentSample >= context.batchSize) {
|
|
||||||
context.currentSample = 0;
|
context.currentSample = 0;
|
||||||
}
|
}
|
||||||
context.currentSample += 1;
|
|
||||||
|
context.currentSample++;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void syncWeights() {
|
||||||
|
int[] i = {0};
|
||||||
|
context.model.forEachNeuron(n -> {
|
||||||
|
n.forEachSynapse(syn -> {
|
||||||
|
syn.setWeight(syn.getWeight() + context.correctionBuffer[i[0]]);
|
||||||
|
context.correctionBuffer[i[0]] = 0f;
|
||||||
|
i[0]++;
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,38 +1,29 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
package com.naaturel.ANN.implementation.multiLayers;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
|
||||||
|
|
||||||
public class ErrorSignalStep implements AlgorithmStep {
|
public class ErrorSignalStep implements AlgorithmStep {
|
||||||
|
|
||||||
private GradientBackpropagationContext context;
|
private final GradientBackpropagationContext context;
|
||||||
|
|
||||||
public ErrorSignalStep(GradientBackpropagationContext context) {
|
public ErrorSignalStep(GradientBackpropagationContext context) {
|
||||||
this.context = context;
|
this.context = context;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
this.context.model.forEachNeuron(n -> {
|
|
||||||
calculateErrorSignalRecursive(n, this.context.errorSignals);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
|
context.model.forEachNeuron(n -> {
|
||||||
if (signals.containsKey(n)) return signals.get(n);
|
if (context.errorSignalsComputed[n.getId()]) return;
|
||||||
|
|
||||||
int neuronIndex = this.context.model.indexInLayerOf(n);
|
int neuronIndex = context.model.indexInLayerOf(n);
|
||||||
AtomicReference<Float> signalSum = new AtomicReference<>(0F);
|
float[] signalSum = {0f};
|
||||||
this.context.model.forEachNeuronConnectedTo(n, connected -> {
|
context.model.forEachNeuronConnectedTo(n, connected -> {
|
||||||
float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
|
signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
|
||||||
signalSum.set(signalSum.get() + weightedSignal);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
float derivative = n.getActivationFunction().derivative(n.getOutput());
|
context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
|
||||||
float finalSignal = derivative * signalSum.get();
|
context.errorSignalsComputed[n.getId()] = true;
|
||||||
signals.put(n, finalSignal);
|
});
|
||||||
return finalSignal;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -10,20 +10,21 @@ import java.util.Map;
|
|||||||
|
|
||||||
public class GradientBackpropagationContext extends TrainingContext {
|
public class GradientBackpropagationContext extends TrainingContext {
|
||||||
|
|
||||||
public final Map<Neuron, Float> errorSignals;
|
public final float[] errorSignals;
|
||||||
public final float[] correctionBuffer;
|
public final float[] correctionBuffer;
|
||||||
|
public final boolean[] errorSignalsComputed;
|
||||||
|
|
||||||
public int currentSample;
|
public int currentSample;
|
||||||
public int batchSize;
|
public int batchSize;
|
||||||
|
|
||||||
public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){
|
public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){
|
||||||
this.model = model;
|
super(model, dataSet);
|
||||||
this.dataset = dataSet;
|
|
||||||
this.learningRate = learningRate;
|
this.learningRate = learningRate;
|
||||||
this.batchSize = batchSize;
|
this.batchSize = batchSize;
|
||||||
|
|
||||||
this.errorSignals = new HashMap<>();
|
this.errorSignals = new float[model.neuronCount()];
|
||||||
this.correctionBuffer = new float[model.synCount()];
|
this.correctionBuffer = new float[model.synCount()];
|
||||||
|
this.errorSignalsComputed = new boolean[model.neuronCount()];
|
||||||
this.currentSample = 1;
|
this.currentSample = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,37 +3,40 @@ package com.naaturel.ANN.implementation.multiLayers;
|
|||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
public class OutputLayerErrorStep implements AlgorithmStep {
|
public class OutputLayerErrorStep implements AlgorithmStep {
|
||||||
|
|
||||||
private final GradientBackpropagationContext context;
|
private final GradientBackpropagationContext context;
|
||||||
|
private final float[] expectations;
|
||||||
|
|
||||||
public OutputLayerErrorStep(GradientBackpropagationContext context){
|
public OutputLayerErrorStep(GradientBackpropagationContext context){
|
||||||
this.context = context;
|
this.context = context;
|
||||||
|
this.expectations = new float[context.model.neuronCount()];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
context.deltas = new ArrayList<>();
|
Arrays.fill(context.errorSignals, 0f);
|
||||||
DataSetEntry entry = this.context.currentEntry;
|
Arrays.fill(context.errorSignalsComputed, false);
|
||||||
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
|
|
||||||
AtomicInteger index = new AtomicInteger(0);
|
|
||||||
|
|
||||||
context.errorSignals.clear();
|
DataSetEntry entry = context.currentEntry;
|
||||||
this.context.model.forEachOutputNeurons(n -> {
|
List<Float> labels = context.dataset.getLabelsAsFloat(entry);
|
||||||
float expected = expectations.get(index.get());
|
for (int i = 0; i < labels.size(); i++) {
|
||||||
|
expectations[i] = labels.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
int[] index = {0};
|
||||||
|
context.model.forEachOutputNeurons(n -> {
|
||||||
|
float expected = expectations[index[0]];
|
||||||
float predicted = n.getOutput();
|
float predicted = n.getOutput();
|
||||||
float output = n.getOutput();
|
|
||||||
float delta = expected - predicted;
|
float delta = expected - predicted;
|
||||||
float signal = delta * n.getActivationFunction().derivative(output);
|
|
||||||
|
|
||||||
this.context.deltas.add(delta);
|
context.deltas[index[0]] = delta;
|
||||||
this.context.errorSignals.put(n, signal);
|
context.errorSignals[n.getId()] = delta * n.getActivationFunction().derivative(predicted);
|
||||||
index.incrementAndGet();
|
context.errorSignalsComputed[n.getId()] = true;
|
||||||
|
index[0]++;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -21,7 +21,7 @@ public class SimpleCorrectionStep implements AlgorithmStep {
|
|||||||
AtomicInteger synIndex = new AtomicInteger(0);
|
AtomicInteger synIndex = new AtomicInteger(0);
|
||||||
|
|
||||||
context.model.forEachNeuron(neuron -> {
|
context.model.forEachNeuron(neuron -> {
|
||||||
float correspondingDelta = context.deltas.get(neuronIndex.get());
|
float correspondingDelta = context.deltas[neuronIndex.get()];
|
||||||
neuron.forEachSynapse(syn -> {
|
neuron.forEachSynapse(syn -> {
|
||||||
float currentW = syn.getWeight();
|
float currentW = syn.getWeight();
|
||||||
float currentInput = syn.getInput();
|
float currentInput = syn.getInput();
|
||||||
|
|||||||
@@ -24,10 +24,9 @@ public class SimpleDeltaStep implements AlgorithmStep {
|
|||||||
List<Float> predicted = context.predictions;
|
List<Float> predicted = context.predictions;
|
||||||
List<Float> expected = dataSet.getLabelsAsFloat(entry);
|
List<Float> expected = dataSet.getLabelsAsFloat(entry);
|
||||||
|
|
||||||
//context.delta = label.getValue() - context.predictions;
|
for (int i = 0; i < predicted.size(); i++) {
|
||||||
context.deltas = IntStream.range(0, predicted.size())
|
context.deltas[i] = expected.get(i) - predicted.get(i);
|
||||||
.mapToObj(i -> expected.get(i) - predicted.get(i))
|
}
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ public class SimpleLossStrategy implements AlgorithmStep {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum);
|
float loss = 0f;
|
||||||
|
for (float d : context.deltas) {
|
||||||
|
loss += d;
|
||||||
|
}
|
||||||
|
context.localLoss = loss;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
public class SimpleTrainingContext extends TrainingContext {
|
public class SimpleTrainingContext extends TrainingContext {
|
||||||
|
public SimpleTrainingContext(Model model, DataSet dataset) {
|
||||||
|
super(model, dataset);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,9 +24,7 @@ public class AdalineTraining implements Trainer {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
AdalineTrainingContext context = new AdalineTrainingContext();
|
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
|
||||||
context.dataset = dataset;
|
|
||||||
context.model = model;
|
|
||||||
context.learningRate = learningRate;
|
context.learningRate = learningRate;
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
GradientBackpropagationContext context =
|
GradientBackpropagationContext context =
|
||||||
new GradientBackpropagationContext(model, dataset, learningRate, dataset.size());
|
new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()/3);
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
new SimplePredictionStep(context),
|
new SimplePredictionStep(context),
|
||||||
@@ -28,14 +28,14 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.001F || ctx.epoch > epoch)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
ctx.globalLoss = 0.0F;
|
ctx.globalLoss = 0.0F;
|
||||||
})
|
})
|
||||||
.afterEpoch(ctx -> {
|
.afterEpoch(ctx -> {
|
||||||
ctx.globalLoss /= dataset.size();
|
ctx.globalLoss /= dataset.size();
|
||||||
})
|
})
|
||||||
.withVerbose(true,epoch/10)
|
.withVerbose(false,epoch/10)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,9 +24,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
|
||||||
context.dataset = dataset;
|
|
||||||
context.model = model;
|
|
||||||
context.learningRate = learningRate;
|
context.learningRate = learningRate;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ public class SimpleTraining implements Trainer {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
SimpleTrainingContext context = new SimpleTrainingContext();
|
SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = learningRate;
|
context.learningRate = learningRate;
|
||||||
|
|||||||
Reference in New Issue
Block a user