Implement gradient backpropagation stub
This commit is contained in:
@@ -2,6 +2,9 @@ package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
@@ -35,7 +38,7 @@ public class Main {
|
||||
|
||||
Bias bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron n = new Neuron(syns, bias, new Linear());
|
||||
Neuron n = new Neuron(syns, bias, new Sigmoid(1));
|
||||
neurons.add(n);
|
||||
}
|
||||
Layer layer = new Layer(neurons);
|
||||
@@ -43,7 +46,7 @@ public class Main {
|
||||
}
|
||||
Network network = new Network(layers);
|
||||
|
||||
Trainer trainer = new GradientDescentTraining();
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(network, dataset);
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user