Implement gradient backpropagation stub
This commit is contained in:
@@ -2,6 +2,9 @@ package com.naaturel.ANN;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||||
|
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||||
|
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
@@ -35,7 +38,7 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight(0));
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron n = new Neuron(syns, bias, new Linear());
|
Neuron n = new Neuron(syns, bias, new Sigmoid(1));
|
||||||
neurons.add(n);
|
neurons.add(n);
|
||||||
}
|
}
|
||||||
Layer layer = new Layer(neurons);
|
Layer layer = new Layer(neurons);
|
||||||
@@ -43,7 +46,7 @@ public class Main {
|
|||||||
}
|
}
|
||||||
Network network = new Network(layers);
|
Network network = new Network(layers);
|
||||||
|
|
||||||
Trainer trainer = new GradientDescentTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(network, dataset);
|
trainer.train(network, dataset);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.naaturel.ANN.implementation.multiLayers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
|
|
||||||
|
public class Sigmoid implements ActivationFunction {
|
||||||
|
|
||||||
|
private float steepness;
|
||||||
|
|
||||||
|
public Sigmoid(float steepness) {
|
||||||
|
this.steepness = steepness;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float accept(Neuron n) {
|
||||||
|
return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.naaturel.ANN.implementation.multiLayers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
|
|
||||||
|
public class TanH implements ActivationFunction {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float accept(Neuron n) {
|
||||||
|
//For educational purpose. Math.tanh() could have been used here
|
||||||
|
float weightedSum = n.calculateWeightedSum();
|
||||||
|
double exp = Math.exp(weightedSum);
|
||||||
|
double res = (exp-(1/exp))/(exp+(1/exp));
|
||||||
|
return (float)(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
|
public class GradientBackpropagationTraining implements Trainer {
|
||||||
|
@Override
|
||||||
|
public void train(Model model, DataSet dataset) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -50,7 +50,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
context.globalLoss /= context.dataset.size();
|
context.globalLoss /= context.dataset.size();
|
||||||
new GradientDescentCorrectionStrategy(context).apply();
|
new GradientDescentCorrectionStrategy(context).apply();
|
||||||
})
|
})
|
||||||
.withVerbose(true)
|
//.withVerbose(true)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.withVisualization(true, new GraphVisualizer())
|
.withVisualization(true, new GraphVisualizer())
|
||||||
.run(context);
|
.run(context);
|
||||||
|
|||||||
Reference in New Issue
Block a user