Fix multi layer implementation
This commit is contained in:
@@ -17,12 +17,12 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args){
|
public static void main(String[] args){
|
||||||
|
|
||||||
int nbrClass = 1;
|
int nbrClass = 3;
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
|
||||||
|
|
||||||
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
|
int[] neuronPerLayer = new int[]{3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 37, dataset.getNbrLabels()};
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
List<Layer> layers = new ArrayList<>();
|
List<Layer> layers = new ArrayList<>();
|
||||||
@@ -40,7 +40,7 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight());
|
Bias bias = new Bias(new Weight());
|
||||||
|
|
||||||
Neuron n = new Neuron(syns, bias, new Sigmoid(2));
|
Neuron n = new Neuron(syns, bias, new TanH());
|
||||||
neurons.add(n);
|
neurons.add(n);
|
||||||
}
|
}
|
||||||
Layer layer = new Layer(neurons);
|
Layer layer = new Layer(neurons);
|
||||||
@@ -50,7 +50,7 @@ public class Main {
|
|||||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
||||||
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(0.5F, network, dataset);
|
trainer.train(0.001F, 1000, network, dataset);
|
||||||
|
|
||||||
/*GraphVisualizer visualizer = new GraphVisualizer();
|
/*GraphVisualizer visualizer = new GraphVisualizer();
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ public class Main {
|
|||||||
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
float min = -2F;
|
float min = -3F;
|
||||||
float max = 2F;
|
float max = 2F;
|
||||||
float step = 0.01F;
|
float step = 0.01F;
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import java.util.function.Consumer;
|
|||||||
public interface Model {
|
public interface Model {
|
||||||
int synCount();
|
int synCount();
|
||||||
int neuronCount();
|
int neuronCount();
|
||||||
int indexOf(Neuron n);
|
int indexInLayerOf(Neuron n);
|
||||||
void forEachNeuron(Consumer<Neuron> consumer);
|
void forEachNeuron(Consumer<Neuron> consumer);
|
||||||
void forEachSynapse(Consumer<Synapse> consumer);
|
void forEachSynapse(Consumer<Synapse> consumer);
|
||||||
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
|
|||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
public interface Trainer {
|
public interface Trainer {
|
||||||
void train(float learningRate, Model model, DataSet dataset);
|
void train(float learningRate, int epoch, Model model, DataSet dataset);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import java.util.ArrayList;
|
|||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -13,13 +14,13 @@ import java.util.function.Consumer;
|
|||||||
*/
|
*/
|
||||||
public class FullyConnectedNetwork implements Model {
|
public class FullyConnectedNetwork implements Model {
|
||||||
|
|
||||||
private final List<Layer> layers;;
|
private final List<Layer> layers;
|
||||||
private final Map<Neuron, List<Neuron>> connectionMap;
|
private final Map<Neuron, List<Neuron>> connectionMap;
|
||||||
private final Map<Neuron, Integer> neuronIndex;
|
private final Map<Neuron, Integer> layerIndexByNeuron;
|
||||||
public FullyConnectedNetwork(List<Layer> layers) {
|
public FullyConnectedNetwork(List<Layer> layers) {
|
||||||
this.layers = layers;
|
this.layers = layers;
|
||||||
this.connectionMap = this.createConnectionMap();
|
this.connectionMap = this.createConnectionMap();
|
||||||
this.neuronIndex = this.createNeuronIndex();
|
this.layerIndexByNeuron = this.createNeuronIndex();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -71,8 +72,9 @@ public class FullyConnectedNetwork implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int indexOf(Neuron n) {
|
public int indexInLayerOf(Neuron n) {
|
||||||
return this.neuronIndex.get(n);
|
int layerIndex = this.layerIndexByNeuron.get(n);
|
||||||
|
return this.layers.get(layerIndex).indexInLayerOf(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
||||||
@@ -83,14 +85,16 @@ public class FullyConnectedNetwork implements Model {
|
|||||||
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
|
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
|
||||||
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<Neuron, Integer> createNeuronIndex() {
|
private Map<Neuron, Integer> createNeuronIndex() {
|
||||||
Map<Neuron, Integer> res = new HashMap<>();
|
Map<Neuron, Integer> res = new HashMap<>();
|
||||||
int[] index = {0};
|
AtomicInteger index = new AtomicInteger(0);
|
||||||
this.layers.forEach(l -> l.forEachNeuron(n -> res.put(n, index[0]++)));
|
this.layers.forEach(l -> {
|
||||||
|
l.forEachNeuron(n -> res.put(n, index.get()));
|
||||||
|
index.incrementAndGet();
|
||||||
|
});
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,15 +3,19 @@ package com.naaturel.ANN.domain.model.neuron;
|
|||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
public class Layer implements Model {
|
public class Layer implements Model {
|
||||||
|
|
||||||
private final List<Neuron> neurons;
|
private final List<Neuron> neurons;
|
||||||
|
private final Map<Neuron, Integer> neuronIndex;
|
||||||
|
|
||||||
public Layer(List<Neuron> neurons) {
|
public Layer(List<Neuron> neurons) {
|
||||||
this.neurons = neurons;
|
this.neurons = neurons;
|
||||||
|
this.neuronIndex = createNeuronIndex();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -39,8 +43,8 @@ public class Layer implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int indexOf(Neuron n) {
|
public int indexInLayerOf(Neuron n) {
|
||||||
return this.neurons.indexOf(n);
|
return this.neuronIndex.get(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -62,4 +66,14 @@ public class Layer implements Model {
|
|||||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||||
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
|
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Map<Neuron, Integer> createNeuronIndex() {
|
||||||
|
Map<Neuron, Integer> res = new HashMap<>();
|
||||||
|
int[] index = {0};
|
||||||
|
this.neurons.forEach(n -> {
|
||||||
|
res.put(n, index[0]++);
|
||||||
|
});
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ public class Neuron implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int indexOf(Neuron n) {
|
public int indexInLayerOf(Neuron n) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,11 +20,13 @@ public class TrainingPipeline {
|
|||||||
private Consumer<TrainingContext> afterEpoch;
|
private Consumer<TrainingContext> afterEpoch;
|
||||||
private Predicate<TrainingContext> stopCondition;
|
private Predicate<TrainingContext> stopCondition;
|
||||||
|
|
||||||
private GraphVisualizer visualizer;
|
|
||||||
private boolean verbose;
|
private boolean verbose;
|
||||||
private boolean visualization;
|
private boolean visualization;
|
||||||
private boolean timeMeasurement;
|
private boolean timeMeasurement;
|
||||||
|
|
||||||
|
private GraphVisualizer visualizer;
|
||||||
|
private int verboseDelay;
|
||||||
|
|
||||||
public TrainingPipeline(List<AlgorithmStep> steps) {
|
public TrainingPipeline(List<AlgorithmStep> steps) {
|
||||||
this.steps = new ArrayList<>(steps);
|
this.steps = new ArrayList<>(steps);
|
||||||
this.stopCondition = (ctx) -> false;
|
this.stopCondition = (ctx) -> false;
|
||||||
@@ -47,8 +49,10 @@ public class TrainingPipeline {
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainingPipeline withVerbose(boolean enabled) {
|
public TrainingPipeline withVerbose(boolean enabled, int epochDelay) {
|
||||||
|
if(epochDelay <= 0) throw new IllegalArgumentException("Epoch delay cannot lower or equal to 0");
|
||||||
this.verbose = enabled;
|
this.verbose = enabled;
|
||||||
|
this.verboseDelay = epochDelay;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,9 +75,10 @@ public class TrainingPipeline {
|
|||||||
this.beforeEpoch.accept(ctx);
|
this.beforeEpoch.accept(ctx);
|
||||||
this.executeSteps(ctx);
|
this.executeSteps(ctx);
|
||||||
this.afterEpoch.accept(ctx);
|
this.afterEpoch.accept(ctx);
|
||||||
if(this.verbose) {
|
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
||||||
System.out.printf("[Global error] : %f\n", ctx.globalLoss);
|
System.out.printf("[Global error] : %f\n", ctx.globalLoss);
|
||||||
}
|
}
|
||||||
|
ctx.epoch += 1;
|
||||||
} while (!this.stopCondition.test(ctx));
|
} while (!this.stopCondition.test(ctx));
|
||||||
|
|
||||||
if(this.timeMeasurement) {
|
if(this.timeMeasurement) {
|
||||||
@@ -94,7 +99,7 @@ public class TrainingPipeline {
|
|||||||
step.run();
|
step.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
if(this.verbose) {
|
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
||||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||||
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
|
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
|
||||||
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
||||||
@@ -102,7 +107,6 @@ public class TrainingPipeline {
|
|||||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx.epoch += 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void visualize(TrainingContext ctx){
|
private void visualize(TrainingContext ctx){
|
||||||
|
|||||||
@@ -18,5 +18,6 @@ public class SquareLossStep implements AlgorithmStep {
|
|||||||
Stream<Float> deltaStream = this.context.deltas.stream();
|
Stream<Float> deltaStream = this.context.deltas.stream();
|
||||||
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
|
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
|
||||||
this.context.localLoss /= 2;
|
this.context.localLoss /= 2;
|
||||||
|
this.context.globalLoss += this.context.localLoss; //broke MSE en gradientDescentTraining
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
this.context.model.forEachOutputNeurons(n -> {
|
this.context.model.forEachNeuron(n -> {
|
||||||
n.forEachSynapse(syn -> {
|
n.forEachSynapse(syn -> {
|
||||||
float lr = context.learningRate;
|
float lr = context.learningRate;
|
||||||
float signal = context.errorSignals.get(n);
|
float signal = context.errorSignals.get(n);
|
||||||
|
|||||||
@@ -2,13 +2,8 @@ package com.naaturel.ANN.implementation.multiLayers;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
public class ErrorSignalStep implements AlgorithmStep {
|
public class ErrorSignalStep implements AlgorithmStep {
|
||||||
@@ -20,23 +15,19 @@ public class ErrorSignalStep implements AlgorithmStep {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
this.context.deltas = new ArrayList<>();
|
this.context.model.forEachNeuron(n -> {
|
||||||
this.context.errorSignals = new HashMap<>();
|
calculateErrorSignalRecursive(n, this.context.errorSignals);
|
||||||
this.calculateOutputLayerErrorSignals();
|
});
|
||||||
|
|
||||||
this.context.model.forEachNeuron(n -> calculateErrorSignalRecursive(n, this.context.errorSignals));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
|
private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
|
||||||
if (signals.containsKey(n)) return signals.get(n);
|
if (signals.containsKey(n)) return signals.get(n);
|
||||||
|
|
||||||
AtomicInteger connectedIndex = new AtomicInteger(0);
|
int neuronIndex = this.context.model.indexInLayerOf(n);
|
||||||
AtomicReference<Float> signalSum = new AtomicReference<>(0F);
|
AtomicReference<Float> signalSum = new AtomicReference<>(0F);
|
||||||
this.context.model.forEachNeuronConnectedTo(n, connected -> {
|
this.context.model.forEachNeuronConnectedTo(n, connected -> {
|
||||||
int neuronIndex = this.context.model.indexOf(n);
|
|
||||||
float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
|
float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
|
||||||
signalSum.set(signalSum.get() + weightedSignal);
|
signalSum.set(signalSum.get() + weightedSignal);
|
||||||
connectedIndex.incrementAndGet();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
float derivative = n.getActivationFunction().derivative(n.getOutput());
|
float derivative = n.getActivationFunction().derivative(n.getOutput());
|
||||||
@@ -44,22 +35,4 @@ public class ErrorSignalStep implements AlgorithmStep {
|
|||||||
signals.put(n, finalSignal);
|
signals.put(n, finalSignal);
|
||||||
return finalSignal;
|
return finalSignal;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void calculateOutputLayerErrorSignals(){
|
|
||||||
DataSetEntry entry = this.context.currentEntry;
|
|
||||||
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
|
|
||||||
AtomicInteger index = new AtomicInteger(0);
|
|
||||||
|
|
||||||
this.context.model.forEachOutputNeurons(n -> {
|
|
||||||
float expected = expectations.get(index.get());
|
|
||||||
float predicted = n.getOutput();
|
|
||||||
float output = n.getOutput();
|
|
||||||
float delta = expected - predicted;
|
|
||||||
float signal = delta * n.getActivationFunction().derivative(output);
|
|
||||||
|
|
||||||
this.context.deltas.add(delta);
|
|
||||||
this.context.errorSignals.put(n, signal);
|
|
||||||
index.incrementAndGet();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package com.naaturel.ANN.implementation.multiLayers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
public class OutputLayerErrorStep implements AlgorithmStep {
|
||||||
|
|
||||||
|
private final GradientBackpropagationContext context;
|
||||||
|
|
||||||
|
public OutputLayerErrorStep(GradientBackpropagationContext context){
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
context.deltas = new ArrayList<>();
|
||||||
|
DataSetEntry entry = this.context.currentEntry;
|
||||||
|
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
|
||||||
|
AtomicInteger index = new AtomicInteger(0);
|
||||||
|
|
||||||
|
context.errorSignals = new HashMap<>();
|
||||||
|
this.context.model.forEachOutputNeurons(n -> {
|
||||||
|
float expected = expectations.get(index.get());
|
||||||
|
float predicted = n.getOutput();
|
||||||
|
float output = n.getOutput();
|
||||||
|
float delta = expected - predicted;
|
||||||
|
float signal = delta * n.getActivationFunction().derivative(output);
|
||||||
|
|
||||||
|
this.context.deltas.add(delta);
|
||||||
|
this.context.errorSignals.put(n, signal);
|
||||||
|
index.incrementAndGet();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,7 +23,7 @@ public class AdalineTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
AdalineTrainingContext context = new AdalineTrainingContext();
|
AdalineTrainingContext context = new AdalineTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
@@ -38,11 +38,11 @@ public class AdalineTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 25)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||||
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.withVerbose(true)
|
.withVerbose(true, 1)
|
||||||
.withVisualization(true, new GraphVisualizer())
|
.withVisualization(true, new GraphVisualizer())
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,14 +8,14 @@ import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
|||||||
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
|
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
|
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
|
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
|
||||||
|
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class GradientBackpropagationTraining implements Trainer {
|
public class GradientBackpropagationTraining implements Trainer {
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
GradientBackpropagationContext context = new GradientBackpropagationContext();
|
GradientBackpropagationContext context = new GradientBackpropagationContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
@@ -23,18 +23,20 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
new SimplePredictionStep(context),
|
new SimplePredictionStep(context),
|
||||||
|
new OutputLayerErrorStep(context),
|
||||||
new ErrorSignalStep(context),
|
new ErrorSignalStep(context),
|
||||||
new BackpropagationCorrectionStep(context),
|
new BackpropagationCorrectionStep(context),
|
||||||
new SquareLossStep(context)
|
new SquareLossStep(context)
|
||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch)
|
||||||
.afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size())
|
.beforeEpoch(ctx -> {
|
||||||
.stopCondition(ctx -> ctx.epoch > 1000000)
|
ctx.globalLoss = 0.0F;
|
||||||
.withVerbose(false)
|
})
|
||||||
|
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
||||||
|
.withVerbose(true, 100)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
@@ -38,7 +38,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 150)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > epoch)
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||||
gdCtx.globalLoss = 0.0F;
|
gdCtx.globalLoss = 0.0F;
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
SimpleTrainingContext context = new SimpleTrainingContext();
|
SimpleTrainingContext context = new SimpleTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
@@ -32,9 +32,9 @@ public class SimpleTraining implements Trainer {
|
|||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
pipeline
|
pipeline
|
||||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 10)
|
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > epoch)
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||||
.withVerbose(true)
|
.withVerbose(true, 1)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user