Minor perfomance improvements
This commit is contained in:
@@ -11,6 +11,7 @@ import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
|||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||||
|
|
||||||
|
import java.io.Console;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
@@ -40,19 +41,19 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight());
|
Bias bias = new Bias(new Weight());
|
||||||
|
|
||||||
Neuron n = new Neuron(syns, bias, new TanH());
|
Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH());
|
||||||
neurons.add(n);
|
neurons.add(n);
|
||||||
}
|
}
|
||||||
Layer layer = new Layer(neurons);
|
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
||||||
layers.add(layer);
|
layers.add(layer);
|
||||||
}
|
}
|
||||||
|
|
||||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||||
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(0.001F, 1000, network, dataset);
|
trainer.train(0.01F, 5000, network, dataset);
|
||||||
|
|
||||||
/*GraphVisualizer visualizer = new GraphVisualizer();
|
GraphVisualizer visualizer = new GraphVisualizer();
|
||||||
|
|
||||||
for (DataSetEntry entry : dataset) {
|
for (DataSetEntry entry : dataset) {
|
||||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||||
@@ -71,7 +72,7 @@ public class Main {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
visualizer.buildScatterGraph();*/
|
visualizer.buildScatterGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ import java.util.function.Consumer;
|
|||||||
*/
|
*/
|
||||||
public class FullyConnectedNetwork implements Model {
|
public class FullyConnectedNetwork implements Model {
|
||||||
|
|
||||||
private final List<Layer> layers;
|
private final Layer[] layers;
|
||||||
private final Map<Neuron, List<Neuron>> connectionMap;
|
private final Map<Neuron, List<Neuron>> connectionMap;
|
||||||
private final Map<Neuron, Integer> layerIndexByNeuron;
|
private final Map<Neuron, Integer> layerIndexByNeuron;
|
||||||
public FullyConnectedNetwork(List<Layer> layers) {
|
public FullyConnectedNetwork(Layer[] layers) {
|
||||||
this.layers = layers;
|
this.layers = layers;
|
||||||
this.connectionMap = this.createConnectionMap();
|
this.connectionMap = this.createConnectionMap();
|
||||||
this.layerIndexByNeuron = this.createNeuronIndex();
|
this.layerIndexByNeuron = this.createNeuronIndex();
|
||||||
@@ -53,17 +53,22 @@ public class FullyConnectedNetwork implements Model {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
for(Layer l : this.layers){
|
||||||
|
l.forEachSynapse(consumer);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||||
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
|
for(Layer l : this.layers){
|
||||||
|
l.forEachNeuron(consumer);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||||
this.layers.getLast().forEachNeuron(consumer);
|
int lastIndex = this.layers.length-1;
|
||||||
|
this.layers[lastIndex].forEachNeuron(consumer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -74,16 +79,16 @@ public class FullyConnectedNetwork implements Model {
|
|||||||
@Override
|
@Override
|
||||||
public int indexInLayerOf(Neuron n) {
|
public int indexInLayerOf(Neuron n) {
|
||||||
int layerIndex = this.layerIndexByNeuron.get(n);
|
int layerIndex = this.layerIndexByNeuron.get(n);
|
||||||
return this.layers.get(layerIndex).indexInLayerOf(n);
|
return this.layers[layerIndex].indexInLayerOf(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
||||||
Map<Neuron, List<Neuron>> res = new HashMap<>();
|
Map<Neuron, List<Neuron>> res = new HashMap<>();
|
||||||
|
|
||||||
for (int i = 0; i < this.layers.size() - 1; i++) {
|
for (int i = 0; i < this.layers.length - 1; i++) {
|
||||||
List<Neuron> nextLayerNeurons = new ArrayList<>();
|
List<Neuron> nextLayerNeurons = new ArrayList<>();
|
||||||
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
|
this.layers[i + 1].forEachNeuron(nextLayerNeurons::add);
|
||||||
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
this.layers[i].forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@@ -91,10 +96,10 @@ public class FullyConnectedNetwork implements Model {
|
|||||||
private Map<Neuron, Integer> createNeuronIndex() {
|
private Map<Neuron, Integer> createNeuronIndex() {
|
||||||
Map<Neuron, Integer> res = new HashMap<>();
|
Map<Neuron, Integer> res = new HashMap<>();
|
||||||
AtomicInteger index = new AtomicInteger(0);
|
AtomicInteger index = new AtomicInteger(0);
|
||||||
this.layers.forEach(l -> {
|
for(Layer l : this.layers){
|
||||||
l.forEachNeuron(n -> res.put(n, index.get()));
|
l.forEachNeuron(n -> res.put(n, index.get()));
|
||||||
index.incrementAndGet();
|
index.incrementAndGet();
|
||||||
});
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ import java.util.function.Consumer;
|
|||||||
|
|
||||||
public class Layer implements Model {
|
public class Layer implements Model {
|
||||||
|
|
||||||
private final List<Neuron> neurons;
|
private final Neuron[] neurons;
|
||||||
private final Map<Neuron, Integer> neuronIndex;
|
private final Map<Neuron, Integer> neuronIndex;
|
||||||
|
|
||||||
public Layer(List<Neuron> neurons) {
|
public Layer(Neuron[] neurons) {
|
||||||
this.neurons = neurons;
|
this.neurons = neurons;
|
||||||
this.neuronIndex = createNeuronIndex();
|
this.neuronIndex = createNeuronIndex();
|
||||||
}
|
}
|
||||||
@@ -39,7 +39,7 @@ public class Layer implements Model {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int neuronCount() {
|
public int neuronCount() {
|
||||||
return this.neurons.size();
|
return this.neurons.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -49,17 +49,21 @@ public class Layer implements Model {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||||
this.neurons.forEach(consumer);
|
for (Neuron n : this.neurons){
|
||||||
|
consumer.accept(n);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
for (Neuron n : this.neurons){
|
||||||
|
n.forEachSynapse(consumer);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||||
this.neurons.forEach(consumer);
|
this.forEachNeuron(consumer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -70,7 +74,7 @@ public class Layer implements Model {
|
|||||||
private Map<Neuron, Integer> createNeuronIndex() {
|
private Map<Neuron, Integer> createNeuronIndex() {
|
||||||
Map<Neuron, Integer> res = new HashMap<>();
|
Map<Neuron, Integer> res = new HashMap<>();
|
||||||
int[] index = {0};
|
int[] index = {0};
|
||||||
this.neurons.forEach(n -> {
|
this.forEachNeuron(n -> {
|
||||||
res.put(n, index[0]++);
|
res.put(n, index[0]++);
|
||||||
});
|
});
|
||||||
return res;
|
return res;
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import java.util.function.Consumer;
|
|||||||
|
|
||||||
public class Neuron implements Model {
|
public class Neuron implements Model {
|
||||||
|
|
||||||
protected List<Synapse> synapses;
|
protected Synapse[] synapses;
|
||||||
protected Bias bias;
|
protected Bias bias;
|
||||||
protected ActivationFunction activationFunction;
|
protected ActivationFunction activationFunction;
|
||||||
protected Float output;
|
protected Float output;
|
||||||
protected Float weightedSum;
|
protected Float weightedSum;
|
||||||
|
|
||||||
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
|
public Neuron(Synapse[] synapses, Bias bias, ActivationFunction func){
|
||||||
this.synapses = synapses;
|
this.synapses = synapses;
|
||||||
this.bias = bias;
|
this.bias = bias;
|
||||||
this.activationFunction = func;
|
this.activationFunction = func;
|
||||||
@@ -26,12 +26,12 @@ public class Neuron implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void updateWeight(int index, Weight weight) {
|
public void updateWeight(int index, Weight weight) {
|
||||||
this.synapses.get(index).setWeight(weight.getValue());
|
this.synapses[index].setWeight(weight.getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void setInputs(List<Input> inputs){
|
protected void setInputs(List<Input> inputs){
|
||||||
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
|
for(int i = 0; i < inputs.size() && i < synapses.length; i++){
|
||||||
Synapse syn = this.synapses.get(i);
|
Synapse syn = this.synapses[i];
|
||||||
syn.setInput(inputs.get(i));
|
syn.setInput(inputs.get(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -45,7 +45,7 @@ public class Neuron implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public float getWeight(int index){
|
public float getWeight(int index){
|
||||||
return this.synapses.get(index).getWeight();
|
return this.synapses[index].getWeight();
|
||||||
}
|
}
|
||||||
|
|
||||||
public float getWeightedSum(){
|
public float getWeightedSum(){
|
||||||
@@ -63,7 +63,7 @@ public class Neuron implements Model {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int synCount() {
|
public int synCount() {
|
||||||
return this.synapses.size()+1; //take the bias into account
|
return this.synapses.length+1; //take the bias into account
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -91,7 +91,9 @@ public class Neuron implements Model {
|
|||||||
@Override
|
@Override
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
consumer.accept(this.bias);
|
consumer.accept(this.bias);
|
||||||
this.synapses.forEach(consumer);
|
for (Synapse syn : this.synapses){
|
||||||
|
consumer.accept(syn);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -30,12 +30,12 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
ctx.globalLoss = 0.0F;
|
ctx.globalLoss = 0.0F;
|
||||||
})
|
})
|
||||||
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
||||||
.withVerbose(true, 100)
|
.withVerbose(false, epoch/10)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user