Implement batch size
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
package com.naaturel.ANN;
|
package com.naaturel.ANN;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Network;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||||
@@ -22,11 +24,21 @@ public class Main {
|
|||||||
int nbrClass = 1;
|
int nbrClass = 1;
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
||||||
|
|
||||||
int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()};
|
int[] neuronPerLayer = new int[]{50, 50, 50, dataset.getNbrLabels()};
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
|
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
||||||
|
|
||||||
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
|
trainer.train(0.01F, 2000, network, dataset);
|
||||||
|
|
||||||
|
//plotGraph(dataset, network);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||||
List<Layer> layers = new ArrayList<>();
|
List<Layer> layers = new ArrayList<>();
|
||||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||||
|
|
||||||
@@ -49,26 +61,27 @@ public class Main {
|
|||||||
layers.add(layer);
|
layers.add(layer);
|
||||||
}
|
}
|
||||||
|
|
||||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||||
|
}
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
|
||||||
trainer.train(0.0005F, 15000, network, dataset);
|
|
||||||
|
|
||||||
|
private static void plotGraph(DataSet dataset, Model network){
|
||||||
GraphVisualizer visualizer = new GraphVisualizer();
|
GraphVisualizer visualizer = new GraphVisualizer();
|
||||||
|
|
||||||
for (DataSetEntry entry : dataset) {
|
for (DataSetEntry entry : dataset) {
|
||||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||||
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
label.forEach(l -> {
|
||||||
|
visualizer.addPoint("Label " + l,
|
||||||
|
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
float min = 0F;
|
float min = -5F;
|
||||||
float max = 15F;
|
float max = 5F;
|
||||||
float step = 0.03F;
|
float step = 0.01F;
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
|
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
|
||||||
float predSeries = prediction > 0.5F ? 1 : -1;
|
visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y);
|
||||||
visualizer.addPoint(Float.toString(predSeries), x, y);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,4 @@ public class Synapse {
|
|||||||
public void setWeight(float value){
|
public void setWeight(float value){
|
||||||
this.weight.setValue(value);
|
this.weight.setValue(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package com.naaturel.ANN.implementation.multiLayers;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
|
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
public class BackpropagationCorrectionStep implements AlgorithmStep {
|
public class BackpropagationCorrectionStep implements AlgorithmStep {
|
||||||
|
|
||||||
private GradientBackpropagationContext context;
|
private GradientBackpropagationContext context;
|
||||||
@@ -12,13 +14,29 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
|
||||||
|
AtomicInteger synIndex = new AtomicInteger(0);
|
||||||
this.context.model.forEachNeuron(n -> {
|
this.context.model.forEachNeuron(n -> {
|
||||||
|
float signal = context.errorSignals.get(n);
|
||||||
n.forEachSynapse(syn -> {
|
n.forEachSynapse(syn -> {
|
||||||
float lr = context.learningRate;
|
float lr = context.learningRate;
|
||||||
float signal = context.errorSignals.get(n);
|
float corrector = lr * signal * syn.getInput();
|
||||||
float newWeight = syn.getWeight() + (lr * signal * syn.getInput());
|
float existingCorrector = context.correctionBuffer[synIndex.get()];
|
||||||
|
float newCorrector = existingCorrector + corrector;
|
||||||
|
|
||||||
|
if(context.currentSample >= context.batchSize){
|
||||||
|
float newWeight = syn.getWeight() + newCorrector;
|
||||||
syn.setWeight(newWeight);
|
syn.setWeight(newWeight);
|
||||||
|
context.correctionBuffer[synIndex.get()] = 0;
|
||||||
|
} else {
|
||||||
|
context.correctionBuffer[synIndex.get()] = newCorrector;
|
||||||
|
}
|
||||||
|
synIndex.incrementAndGet();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
if(context.currentSample >= context.batchSize) {
|
||||||
|
context.currentSample = 0;
|
||||||
|
}
|
||||||
|
context.currentSample += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package com.naaturel.ANN.implementation.multiLayers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
|
|
||||||
|
public class BatchAccumulatorStep implements AlgorithmStep {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,14 +1,29 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
package com.naaturel.ANN.implementation.multiLayers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public class GradientBackpropagationContext extends TrainingContext {
|
public class GradientBackpropagationContext extends TrainingContext {
|
||||||
|
|
||||||
public Map<Neuron, Float> errorSignals;
|
public final Map<Neuron, Float> errorSignals;
|
||||||
|
public final float[] correctionBuffer;
|
||||||
|
|
||||||
public GradientBackpropagationContext(){
|
public int currentSample;
|
||||||
|
public int batchSize;
|
||||||
|
|
||||||
|
public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){
|
||||||
|
this.model = model;
|
||||||
|
this.dataset = dataSet;
|
||||||
|
this.learningRate = learningRate;
|
||||||
|
this.batchSize = batchSize;
|
||||||
|
|
||||||
|
this.errorSignals = new HashMap<>();
|
||||||
|
this.correctionBuffer = new float[model.synCount()];
|
||||||
|
this.currentSample = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ public class OutputLayerErrorStep implements AlgorithmStep {
|
|||||||
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
|
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
|
||||||
AtomicInteger index = new AtomicInteger(0);
|
AtomicInteger index = new AtomicInteger(0);
|
||||||
|
|
||||||
context.errorSignals = new HashMap<>();
|
context.errorSignals.clear();
|
||||||
this.context.model.forEachOutputNeurons(n -> {
|
this.context.model.forEachOutputNeurons(n -> {
|
||||||
float expected = expectations.get(index.get());
|
float expected = expectations.get(index.get());
|
||||||
float predicted = n.getOutput();
|
float predicted = n.getOutput();
|
||||||
|
|||||||
@@ -16,10 +16,8 @@ import java.util.List;
|
|||||||
public class GradientBackpropagationTraining implements Trainer {
|
public class GradientBackpropagationTraining implements Trainer {
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||||
GradientBackpropagationContext context = new GradientBackpropagationContext();
|
GradientBackpropagationContext context =
|
||||||
context.dataset = dataset;
|
new GradientBackpropagationContext(model, dataset, learningRate, dataset.size());
|
||||||
context.model = model;
|
|
||||||
context.learningRate = learningRate;
|
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
new SimplePredictionStep(context),
|
new SimplePredictionStep(context),
|
||||||
@@ -34,7 +32,9 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
ctx.globalLoss = 0.0F;
|
ctx.globalLoss = 0.0F;
|
||||||
})
|
})
|
||||||
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
.afterEpoch(ctx -> {
|
||||||
|
ctx.globalLoss /= dataset.size();
|
||||||
|
})
|
||||||
.withVerbose(true,epoch/10)
|
.withVerbose(true,epoch/10)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
|
|||||||
@@ -19,10 +19,3 @@
|
|||||||
4,7,-1
|
4,7,-1
|
||||||
4,9,1
|
4,9,1
|
||||||
4,10,1
|
4,10,1
|
||||||
2,6,-1
|
|
||||||
7,7,-1
|
|
||||||
5,9,1
|
|
||||||
9,10,1
|
|
||||||
7,1,-1
|
|
||||||
5,0,1
|
|
||||||
9,5,1
|
|
||||||
|
Reference in New Issue
Block a user