Implement multi-layer network
@@ -3,6 +3,7 @@ package com.naaturel.ANN;
 import com.naaturel.ANN.domain.model.neuron.Neuron;
 import com.naaturel.ANN.domain.abstraction.Trainer;
 import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
+import com.naaturel.ANN.implementation.multiLayers.TanH;
 import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
 import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
@@ -14,33 +15,36 @@ public class Main {

     public static void main(String[] args){

-        int nbrInput = 2;
-        int nbrClass = 3;
+        int nbrInput = 25;
+        int nbrClass = 4;

-        int nbrLayers = 2;
+        int[] neuronPerLayer = new int[]{10, nbrClass};

         DataSet dataset = new DatasetExtractor()
-                .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass);
+                .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv", nbrClass);

         List<Layer> layers = new ArrayList<>();
-        for(int i = 0; i < nbrLayers; i++){
+        for (int i = 0; i < neuronPerLayer.length; i++){

             List<Neuron> neurons = new ArrayList<>();
-            for (int j=0; j < nbrClass; j++){
+            for (int j = 0; j < neuronPerLayer[i]; j++){

+                int nbrSyn = i == 0 ? nbrInput : neuronPerLayer[i-1];
+
                 List<Synapse> syns = new ArrayList<>();
-                for (int k=0; k < nbrInput; k++){
-                    syns.add(new Synapse(new Input(0), new Weight(0)));
+                for (int k=0; k < nbrSyn; k++){
+                    syns.add(new Synapse(new Input(0), new Weight()));
                 }

-                Bias bias = new Bias(new Weight(0));
+                Bias bias = new Bias(new Weight());

-                Neuron n = new Neuron(syns, bias, new Sigmoid(1));
+                Neuron n = new Neuron(syns, bias, new TanH());
                 neurons.add(n);
             }
             Layer layer = new Layer(neurons);
             layers.add(layer);
         }

         FullyConnectedNetwork network = new FullyConnectedNetwork(layers);

         Trainer trainer = new GradientBackpropagationTraining();
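With these values the loop above builds a 25-10-4 fully connected network: each first-layer neuron gets nbrInput synapses, and each neuron in a later layer gets one synapse per neuron in the previous layer. As a back-of-the-envelope size check (not part of the commit), the hidden layer carries 10 * (25 + 1) = 260 weights counting biases and the output layer 4 * (10 + 1) = 44, i.e. 304 trainable parameters in total.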
@@ -10,6 +10,7 @@ import java.util.function.Consumer;
 public interface Model {
     int synCount();
     int neuronCount();
+    int indexOf(Neuron n);
     void forEachNeuron(Consumer<Neuron> consumer);
     void forEachSynapse(Consumer<Synapse> consumer);
     void forEachOutputNeurons(Consumer<Neuron> consumer);
@@ -1,7 +1,6 @@
 package com.naaturel.ANN.domain.model.neuron;

 import com.naaturel.ANN.domain.abstraction.Model;
-import com.naaturel.ANN.domain.abstraction.Network;

 import java.util.ArrayList;
 import java.util.HashMap;
@@ -16,10 +15,11 @@ public class FullyConnectedNetwork implements Model {

     private final List<Layer> layers;
     private final Map<Neuron, List<Neuron>> connectionMap;
+    private final Map<Neuron, Integer> neuronIndex;

     public FullyConnectedNetwork(List<Layer> layers) {
         this.layers = layers;
         this.connectionMap = this.createConnectionMap();
+        this.neuronIndex = this.createNeuronIndex();
     }

     @Override
@@ -70,6 +70,11 @@ public class FullyConnectedNetwork implements Model {
         this.connectionMap.get(n).forEach(consumer);
     }

+    @Override
+    public int indexOf(Neuron n) {
+        return this.neuronIndex.get(n);
+    }
+
     private Map<Neuron, List<Neuron>> createConnectionMap() {
         Map<Neuron, List<Neuron>> res = new HashMap<>();

@@ -81,4 +86,11 @@ public class FullyConnectedNetwork implements Model {
         return res;
     }

+    private Map<Neuron, Integer> createNeuronIndex() {
+        Map<Neuron, Integer> res = new HashMap<>();
+        int[] index = {0};
+        this.layers.forEach(l -> l.forEachNeuron(n -> res.put(n, index[0]++)));
+        return res;
+    }
 }
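createNeuronIndex assigns every neuron a stable, network-wide position in layer order, making indexOf(Neuron) an O(1) map lookup; the new ErrorSignalStep below uses that index to locate the synapse weight connecting two neurons.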
@@ -38,6 +38,11 @@ public class Layer implements Model {
         return this.neurons.size();
     }

+    @Override
+    public int indexOf(Neuron n) {
+        return this.neurons.indexOf(n);
+    }
+
     @Override
     public void forEachNeuron(Consumer<Neuron> consumer) {
         this.neurons.forEach(consumer);
@@ -11,12 +11,14 @@ public class Neuron implements Model {
     protected Bias bias;
     protected ActivationFunction activationFunction;
     protected Float output;
+    protected Float weightedSum;

     public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
         this.synapses = synapses;
         this.bias = bias;
         this.activationFunction = func;
-        this.output = 0F;
+        this.output = null;
+        this.weightedSum = null;
     }

     public void updateBias(Weight weight) {
@@ -42,13 +44,22 @@ public class Neuron implements Model {
         return this.output;
     }

+    public float getWeight(int index){
+        return this.synapses.get(index).getWeight();
+    }
+
+    public float getWeightedSum(){
+        return this.weightedSum;
+    }
+
     public float calculateWeightedSum() {
         float res = 0;
         res += this.bias.getWeight() * this.bias.getInput();
         for(Synapse syn : this.synapses){
             res += syn.getWeight() * syn.getInput();
         }
-        return res;
+        this.weightedSum = res;
+        return this.weightedSum;
     }

     @Override
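calculateWeightedSum now caches its result: it computes z = b_w * b_x + sum_i(w_i * x_i) and stores it in weightedSum, presumably so the new getWeightedSum() accessor can serve the training steps without re-running the forward pass.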
@@ -61,6 +72,11 @@ public class Neuron implements Model {
         return 1;
     }

+    @Override
+    public int indexOf(Neuron n) {
+        return 0;
+    }
+
     @Override
     public List<Float> predict(List<Input> inputs) {
         this.setInputs(inputs);
@@ -2,7 +2,6 @@ package com.naaturel.ANN.domain.model.training;

 import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
 import com.naaturel.ANN.domain.abstraction.TrainingContext;
 import com.naaturel.ANN.domain.abstraction.TrainingStep;
 import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
 import com.naaturel.ANN.domain.model.neuron.Input;
 import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
@@ -0,0 +1,24 @@
+package com.naaturel.ANN.implementation.multiLayers;
+
+import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
+
+public class BackpropagationCorrectionStep implements AlgorithmStep {
+
+    private GradientBackpropagationContext context;
+
+    public BackpropagationCorrectionStep(GradientBackpropagationContext context){
+        this.context = context;
+    }
+
+    @Override
+    public void run() {
+        this.context.model.forEachOutputNeurons(n -> {
+            n.forEachSynapse(syn -> {
+                float lr = context.learningRate;
+                float signal = context.errorSignals.get(n);
+                float newWeight = syn.getWeight() + (lr * signal * syn.getInput());
+                syn.setWeight(newWeight);
+            });
+        });
+    }
+}
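The update applied here is the classic delta rule, w' = w + lr * delta * x, but run() only visits the output layer via forEachOutputNeurons. A sketch of the same rule applied network-wide (a hypothetical variant, not part of this commit; it assumes forEachNeuron also reaches hidden neurons and that errorSignals holds a delta for each of them) would be:

    // Hypothetical variant: apply the delta rule to every neuron's
    // synapses instead of only the output layer's.
    this.context.model.forEachNeuron(n -> {
        float signal = context.errorSignals.get(n);   // delta for this neuron
        n.forEachSynapse(syn ->
                syn.setWeight(syn.getWeight() + context.learningRate * signal * syn.getInput()));
    });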
@@ -0,0 +1,65 @@
+package com.naaturel.ANN.implementation.multiLayers;
+
+import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
+import com.naaturel.ANN.domain.model.neuron.Neuron;
+import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+
+public class ErrorSignalStep implements AlgorithmStep {
+
+    private GradientBackpropagationContext context;
+
+    public ErrorSignalStep(GradientBackpropagationContext context) {
+        this.context = context;
+    }
+
+    @Override
+    public void run() {
+        this.context.deltas = new ArrayList<>();
+        this.context.errorSignals = new HashMap<>();
+        this.calculateOutputLayerErrorSignals();
+
+        this.context.model.forEachNeuron(n -> calculateErrorSignalRecursive(n, this.context.errorSignals));
+    }
+
+    private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
+        if (signals.containsKey(n)) return signals.get(n);
+
+        AtomicInteger connectedIndex = new AtomicInteger(0);
+        AtomicReference<Float> signalSum = new AtomicReference<>(0F);
+        this.context.model.forEachNeuronConnectedTo(n, connected -> {
+            int neuronIndex = this.context.model.indexOf(n);
+            float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
+            signalSum.set(signalSum.get() + weightedSignal);
+            connectedIndex.incrementAndGet();
+        });
+
+        float derivative = n.getActivationFunction().derivative(n.getOutput());
+        float finalSignal = derivative * signalSum.get();
+        signals.put(n, finalSignal);
+        return finalSignal;
+    }
+
+    private void calculateOutputLayerErrorSignals(){
+        DataSetEntry entry = this.context.currentEntry;
+        List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
+        AtomicInteger index = new AtomicInteger(0);
+
+        this.context.model.forEachOutputNeurons(n -> {
+            float expected = expectations.get(index.get());
+            float predicted = n.getOutput();
+            float output = n.getOutput();
+            float delta = expected - predicted;
+            float signal = delta * n.getActivationFunction().derivative(output);
+
+            this.context.deltas.add(delta);
+            this.context.errorSignals.put(n, signal);
+            index.incrementAndGet();
+        });
+    }
+}
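For reference, these two computations are the textbook backpropagation error signals: for an output neuron, delta = (t - y) * phi'(y); for a hidden neuron j, delta_j = phi'(y_j) * sum_k(delta_k * w_kj), where k ranges over the neurons j feeds into and w_kj is the weight neuron k applies to j's output. Note that phi' is evaluated at the neuron's output here, which assumes the activation implementations express their derivative in terms of the output, as TanH conveniently can (1 - y^2). calculateErrorSignalRecursive memoises each delta in the signals map, so every neuron is computed once even though the traversal is recursive.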
@@ -1,15 +1,14 @@
 package com.naaturel.ANN.implementation.multiLayers;

 import com.naaturel.ANN.domain.abstraction.TrainingContext;
+import com.naaturel.ANN.domain.model.neuron.Neuron;

-import java.util.ArrayList;
-import java.util.List;
+import java.util.Map;

 public class GradientBackpropagationContext extends TrainingContext {

-    public List<Float> hiddenDeltas;
+    public Map<Neuron, Float> errorSignals;

     public GradientBackpropagationContext(){
     }

 }
@@ -1,24 +0,0 @@
-package com.naaturel.ANN.implementation.multiLayers;
-
-import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
-import com.naaturel.ANN.domain.model.neuron.Neuron;
-
-public class GradientBackpropagationStep implements AlgorithmStep {
-
-    private GradientBackpropagationContext context;
-
-    public GradientBackpropagationStep(GradientBackpropagationContext context) {
-        this.context = context;
-    }
-
-    @Override
-    public void run() {
-
-
-    }
-
-    private float calculateDeltaRecursive(Neuron n){
-
-    }
-
-}
@@ -4,9 +4,10 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
 import com.naaturel.ANN.domain.abstraction.Model;
 import com.naaturel.ANN.domain.abstraction.Trainer;
 import com.naaturel.ANN.domain.model.training.TrainingPipeline;
 import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
+import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
 import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
-import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
+import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
@@ -19,17 +20,19 @@ public class GradientBackpropagationTraining implements Trainer {
         GradientBackpropagationContext context = new GradientBackpropagationContext();
         context.dataset = dataset;
         context.model = model;
-        context.learningRate = 0.001F;
+        context.learningRate = 0.1F;

         List<AlgorithmStep> steps = List.of(
                 new SimplePredictionStep(context),
-                new SimpleDeltaStep(context),
-                new GradientBackpropagationStep(context)
+                new ErrorSignalStep(context),
+                new BackpropagationCorrectionStep(context),
+                new SquareLossStep(context)
         );

         new TrainingPipeline(steps)
-                .stopCondition(ctx -> false)
+                .stopCondition(ctx -> ctx.epoch == 250)
                 .withVerbose(true)
+                .withTimeMeasurement(true)
                 .run(context);
     }
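Tying this back to Main (a hypothetical usage sketch, not part of this commit; it assumes Trainer exposes a train(Model, DataSet) method matching the context fields populated above):

    // Hypothetical wiring: run predict -> error signals -> correction -> loss
    // until the stop condition fires at epoch 250.
    Trainer trainer = new GradientBackpropagationTraining();
    trainer.train(network, dataset);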
@@ -11,7 +11,6 @@ import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrection
 import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
-import com.naaturel.ANN.implementation.training.steps.*;
 import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;

 import java.util.ArrayList;
@@ -6,7 +6,6 @@ import com.naaturel.ANN.domain.abstraction.Trainer;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
 import com.naaturel.ANN.implementation.simplePerceptron.*;
 import com.naaturel.ANN.domain.model.training.TrainingPipeline;
-import com.naaturel.ANN.implementation.training.steps.*;

 import java.util.List;