Implement gradient backpropagation stub
This commit is contained in:
@@ -2,6 +2,9 @@ package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
@@ -35,7 +38,7 @@ public class Main {
|
||||
|
||||
Bias bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron n = new Neuron(syns, bias, new Linear());
|
||||
Neuron n = new Neuron(syns, bias, new Sigmoid(1));
|
||||
neurons.add(n);
|
||||
}
|
||||
Layer layer = new Layer(neurons);
|
||||
@@ -43,7 +46,7 @@ public class Main {
|
||||
}
|
||||
Network network = new Network(layers);
|
||||
|
||||
Trainer trainer = new GradientDescentTraining();
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(network, dataset);
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
public class Sigmoid implements ActivationFunction {
|
||||
|
||||
private float steepness;
|
||||
|
||||
public Sigmoid(float steepness) {
|
||||
this.steepness = steepness;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
public class TanH implements ActivationFunction {
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
//For educational purpose. Math.tanh() could have been used here
|
||||
float weightedSum = n.calculateWeightedSum();
|
||||
double exp = Math.exp(weightedSum);
|
||||
double res = (exp-(1/exp))/(exp+(1/exp));
|
||||
return (float)(res);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
public class GradientBackpropagationTraining implements Trainer {

    /**
     * Trains the given model on the dataset using gradient backpropagation.
     *
     * <p>NOTE(review): this is currently an empty stub — calling it performs
     * no training and leaves the model unchanged. TODO: implement the
     * forward pass, loss computation, and backward weight updates.
     */
    @Override
    public void train(Model model, DataSet dataset) {

    }
}
|
||||
@@ -50,7 +50,7 @@ public class GradientDescentTraining implements Trainer {
|
||||
context.globalLoss /= context.dataset.size();
|
||||
new GradientDescentCorrectionStrategy(context).apply();
|
||||
})
|
||||
.withVerbose(true)
|
||||
//.withVerbose(true)
|
||||
.withTimeMeasurement(true)
|
||||
.withVisualization(true, new GraphVisualizer())
|
||||
.run(context);
|
||||
|
||||
Reference in New Issue
Block a user