Fix multi layer implementation
This commit is contained in:
@@ -17,12 +17,12 @@ public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
int nbrClass = 1;
|
||||
int nbrClass = 3;
|
||||
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
|
||||
|
||||
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
|
||||
int[] neuronPerLayer = new int[]{3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 37, dataset.getNbrLabels()};
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
List<Layer> layers = new ArrayList<>();
|
||||
@@ -40,7 +40,7 @@ public class Main {
|
||||
|
||||
Bias bias = new Bias(new Weight());
|
||||
|
||||
Neuron n = new Neuron(syns, bias, new Sigmoid(2));
|
||||
Neuron n = new Neuron(syns, bias, new TanH());
|
||||
neurons.add(n);
|
||||
}
|
||||
Layer layer = new Layer(neurons);
|
||||
@@ -50,7 +50,7 @@ public class Main {
|
||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.5F, network, dataset);
|
||||
trainer.train(0.001F, 1000, network, dataset);
|
||||
|
||||
/*GraphVisualizer visualizer = new GraphVisualizer();
|
||||
|
||||
@@ -59,7 +59,7 @@ public class Main {
|
||||
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
}
|
||||
|
||||
float min = -2F;
|
||||
float min = -3F;
|
||||
float max = 2F;
|
||||
float step = 0.01F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
|
||||
@@ -10,7 +10,7 @@ import java.util.function.Consumer;
|
||||
public interface Model {
|
||||
int synCount();
|
||||
int neuronCount();
|
||||
int indexOf(Neuron n);
|
||||
int indexInLayerOf(Neuron n);
|
||||
void forEachNeuron(Consumer<Neuron> consumer);
|
||||
void forEachSynapse(Consumer<Synapse> consumer);
|
||||
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
||||
|
||||
@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
public interface Trainer {
|
||||
void train(float learningRate, Model model, DataSet dataset);
|
||||
void train(float learningRate, int epoch, Model model, DataSet dataset);
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
@@ -13,13 +14,13 @@ import java.util.function.Consumer;
|
||||
*/
|
||||
public class FullyConnectedNetwork implements Model {
|
||||
|
||||
private final List<Layer> layers;;
|
||||
private final List<Layer> layers;
|
||||
private final Map<Neuron, List<Neuron>> connectionMap;
|
||||
private final Map<Neuron, Integer> neuronIndex;
|
||||
private final Map<Neuron, Integer> layerIndexByNeuron;
|
||||
public FullyConnectedNetwork(List<Layer> layers) {
|
||||
this.layers = layers;
|
||||
this.connectionMap = this.createConnectionMap();
|
||||
this.neuronIndex = this.createNeuronIndex();
|
||||
this.layerIndexByNeuron = this.createNeuronIndex();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -71,8 +72,9 @@ public class FullyConnectedNetwork implements Model {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int indexOf(Neuron n) {
|
||||
return this.neuronIndex.get(n);
|
||||
public int indexInLayerOf(Neuron n) {
|
||||
int layerIndex = this.layerIndexByNeuron.get(n);
|
||||
return this.layers.get(layerIndex).indexInLayerOf(n);
|
||||
}
|
||||
|
||||
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
||||
@@ -83,14 +85,16 @@ public class FullyConnectedNetwork implements Model {
|
||||
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
|
||||
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
private Map<Neuron, Integer> createNeuronIndex() {
|
||||
Map<Neuron, Integer> res = new HashMap<>();
|
||||
int[] index = {0};
|
||||
this.layers.forEach(l -> l.forEachNeuron(n -> res.put(n, index[0]++)));
|
||||
AtomicInteger index = new AtomicInteger(0);
|
||||
this.layers.forEach(l -> {
|
||||
l.forEachNeuron(n -> res.put(n, index.get()));
|
||||
index.incrementAndGet();
|
||||
});
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,15 +3,19 @@ package com.naaturel.ANN.domain.model.neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class Layer implements Model {
|
||||
|
||||
private final List<Neuron> neurons;
|
||||
private final Map<Neuron, Integer> neuronIndex;
|
||||
|
||||
public Layer(List<Neuron> neurons) {
|
||||
this.neurons = neurons;
|
||||
this.neuronIndex = createNeuronIndex();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -39,8 +43,8 @@ public class Layer implements Model {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int indexOf(Neuron n) {
|
||||
return this.neurons.indexOf(n);
|
||||
public int indexInLayerOf(Neuron n) {
|
||||
return this.neuronIndex.get(n);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -62,4 +66,14 @@ public class Layer implements Model {
|
||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
|
||||
}
|
||||
|
||||
private Map<Neuron, Integer> createNeuronIndex() {
|
||||
Map<Neuron, Integer> res = new HashMap<>();
|
||||
int[] index = {0};
|
||||
this.neurons.forEach(n -> {
|
||||
res.put(n, index[0]++);
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ public class Neuron implements Model {
|
||||
}
|
||||
|
||||
@Override
|
||||
public int indexOf(Neuron n) {
|
||||
public int indexInLayerOf(Neuron n) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,11 +20,13 @@ public class TrainingPipeline {
|
||||
private Consumer<TrainingContext> afterEpoch;
|
||||
private Predicate<TrainingContext> stopCondition;
|
||||
|
||||
private GraphVisualizer visualizer;
|
||||
private boolean verbose;
|
||||
private boolean visualization;
|
||||
private boolean timeMeasurement;
|
||||
|
||||
private GraphVisualizer visualizer;
|
||||
private int verboseDelay;
|
||||
|
||||
public TrainingPipeline(List<AlgorithmStep> steps) {
|
||||
this.steps = new ArrayList<>(steps);
|
||||
this.stopCondition = (ctx) -> false;
|
||||
@@ -47,8 +49,10 @@ public class TrainingPipeline {
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline withVerbose(boolean enabled) {
|
||||
public TrainingPipeline withVerbose(boolean enabled, int epochDelay) {
|
||||
if(epochDelay <= 0) throw new IllegalArgumentException("Epoch delay cannot lower or equal to 0");
|
||||
this.verbose = enabled;
|
||||
this.verboseDelay = epochDelay;
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -71,9 +75,10 @@ public class TrainingPipeline {
|
||||
this.beforeEpoch.accept(ctx);
|
||||
this.executeSteps(ctx);
|
||||
this.afterEpoch.accept(ctx);
|
||||
if(this.verbose) {
|
||||
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
||||
System.out.printf("[Global error] : %f\n", ctx.globalLoss);
|
||||
}
|
||||
ctx.epoch += 1;
|
||||
} while (!this.stopCondition.test(ctx));
|
||||
|
||||
if(this.timeMeasurement) {
|
||||
@@ -94,7 +99,7 @@ public class TrainingPipeline {
|
||||
step.run();
|
||||
}
|
||||
|
||||
if(this.verbose) {
|
||||
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
|
||||
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
||||
@@ -102,7 +107,6 @@ public class TrainingPipeline {
|
||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||
}
|
||||
}
|
||||
ctx.epoch += 1;
|
||||
}
|
||||
|
||||
private void visualize(TrainingContext ctx){
|
||||
|
||||
@@ -18,5 +18,6 @@ public class SquareLossStep implements AlgorithmStep {
|
||||
Stream<Float> deltaStream = this.context.deltas.stream();
|
||||
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
|
||||
this.context.localLoss /= 2;
|
||||
this.context.globalLoss += this.context.localLoss; //broke MSE en gradientDescentTraining
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.context.model.forEachOutputNeurons(n -> {
|
||||
this.context.model.forEachNeuron(n -> {
|
||||
n.forEachSynapse(syn -> {
|
||||
float lr = context.learningRate;
|
||||
float signal = context.errorSignals.get(n);
|
||||
|
||||
@@ -2,13 +2,8 @@ package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
public class ErrorSignalStep implements AlgorithmStep {
|
||||
@@ -20,23 +15,19 @@ public class ErrorSignalStep implements AlgorithmStep {
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.context.deltas = new ArrayList<>();
|
||||
this.context.errorSignals = new HashMap<>();
|
||||
this.calculateOutputLayerErrorSignals();
|
||||
|
||||
this.context.model.forEachNeuron(n -> calculateErrorSignalRecursive(n, this.context.errorSignals));
|
||||
this.context.model.forEachNeuron(n -> {
|
||||
calculateErrorSignalRecursive(n, this.context.errorSignals);
|
||||
});
|
||||
}
|
||||
|
||||
private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
|
||||
if (signals.containsKey(n)) return signals.get(n);
|
||||
|
||||
AtomicInteger connectedIndex = new AtomicInteger(0);
|
||||
int neuronIndex = this.context.model.indexInLayerOf(n);
|
||||
AtomicReference<Float> signalSum = new AtomicReference<>(0F);
|
||||
this.context.model.forEachNeuronConnectedTo(n, connected -> {
|
||||
int neuronIndex = this.context.model.indexOf(n);
|
||||
float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
|
||||
signalSum.set(signalSum.get() + weightedSignal);
|
||||
connectedIndex.incrementAndGet();
|
||||
});
|
||||
|
||||
float derivative = n.getActivationFunction().derivative(n.getOutput());
|
||||
@@ -44,22 +35,4 @@ public class ErrorSignalStep implements AlgorithmStep {
|
||||
signals.put(n, finalSignal);
|
||||
return finalSignal;
|
||||
}
|
||||
|
||||
private void calculateOutputLayerErrorSignals(){
|
||||
DataSetEntry entry = this.context.currentEntry;
|
||||
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
|
||||
AtomicInteger index = new AtomicInteger(0);
|
||||
|
||||
this.context.model.forEachOutputNeurons(n -> {
|
||||
float expected = expectations.get(index.get());
|
||||
float predicted = n.getOutput();
|
||||
float output = n.getOutput();
|
||||
float delta = expected - predicted;
|
||||
float signal = delta * n.getActivationFunction().derivative(output);
|
||||
|
||||
this.context.deltas.add(delta);
|
||||
this.context.errorSignals.put(n, signal);
|
||||
index.incrementAndGet();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class OutputLayerErrorStep implements AlgorithmStep {
|
||||
|
||||
private final GradientBackpropagationContext context;
|
||||
|
||||
public OutputLayerErrorStep(GradientBackpropagationContext context){
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
context.deltas = new ArrayList<>();
|
||||
DataSetEntry entry = this.context.currentEntry;
|
||||
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
|
||||
AtomicInteger index = new AtomicInteger(0);
|
||||
|
||||
context.errorSignals = new HashMap<>();
|
||||
this.context.model.forEachOutputNeurons(n -> {
|
||||
float expected = expectations.get(index.get());
|
||||
float predicted = n.getOutput();
|
||||
float output = n.getOutput();
|
||||
float delta = expected - predicted;
|
||||
float signal = delta * n.getActivationFunction().derivative(output);
|
||||
|
||||
this.context.deltas.add(delta);
|
||||
this.context.errorSignals.put(n, signal);
|
||||
index.incrementAndGet();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,7 @@ public class AdalineTraining implements Trainer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(float learningRate, Model model, DataSet dataset) {
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
AdalineTrainingContext context = new AdalineTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
@@ -38,11 +38,11 @@ public class AdalineTraining implements Trainer {
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 25)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
||||
.withTimeMeasurement(true)
|
||||
.withVerbose(true)
|
||||
.withVerbose(true, 1)
|
||||
.withVisualization(true, new GraphVisualizer())
|
||||
.run(context);
|
||||
}
|
||||
|
||||
@@ -8,14 +8,14 @@ import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
|
||||
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class GradientBackpropagationTraining implements Trainer {
|
||||
@Override
|
||||
public void train(float learningRate, Model model, DataSet dataset) {
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
GradientBackpropagationContext context = new GradientBackpropagationContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
@@ -23,18 +23,20 @@ public class GradientBackpropagationTraining implements Trainer {
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new OutputLayerErrorStep(context),
|
||||
new ErrorSignalStep(context),
|
||||
new BackpropagationCorrectionStep(context),
|
||||
new SquareLossStep(context)
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||
.afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size())
|
||||
.stopCondition(ctx -> ctx.epoch > 1000000)
|
||||
.withVerbose(false)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> {
|
||||
ctx.globalLoss = 0.0F;
|
||||
})
|
||||
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
||||
.withVerbose(true, 100)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ public class GradientDescentTraining implements Trainer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(float learningRate, Model model, DataSet dataset) {
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
@@ -38,7 +38,7 @@ public class GradientDescentTraining implements Trainer {
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 150)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> {
|
||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||
gdCtx.globalLoss = 0.0F;
|
||||
|
||||
@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(float learningRate, Model model, DataSet dataset) {
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
SimpleTrainingContext context = new SimpleTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
@@ -32,9 +32,9 @@ public class SimpleTraining implements Trainer {
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
pipeline
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 10)
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||
.withVerbose(true)
|
||||
.withVerbose(true, 1)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user