Minor perfomance improvements
This commit is contained in:
@@ -11,6 +11,7 @@ import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||
|
||||
import java.io.Console;
|
||||
import java.util.*;
|
||||
|
||||
public class Main {
|
||||
@@ -40,19 +41,19 @@ public class Main {
|
||||
|
||||
Bias bias = new Bias(new Weight());
|
||||
|
||||
Neuron n = new Neuron(syns, bias, new TanH());
|
||||
Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH());
|
||||
neurons.add(n);
|
||||
}
|
||||
Layer layer = new Layer(neurons);
|
||||
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
||||
layers.add(layer);
|
||||
}
|
||||
|
||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.001F, 1000, network, dataset);
|
||||
trainer.train(0.01F, 5000, network, dataset);
|
||||
|
||||
/*GraphVisualizer visualizer = new GraphVisualizer();
|
||||
GraphVisualizer visualizer = new GraphVisualizer();
|
||||
|
||||
for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
@@ -71,7 +72,7 @@ public class Main {
|
||||
}
|
||||
|
||||
|
||||
visualizer.buildScatterGraph();*/
|
||||
visualizer.buildScatterGraph();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ import java.util.function.Consumer;
|
||||
*/
|
||||
public class FullyConnectedNetwork implements Model {
|
||||
|
||||
private final List<Layer> layers;
|
||||
private final Layer[] layers;
|
||||
private final Map<Neuron, List<Neuron>> connectionMap;
|
||||
private final Map<Neuron, Integer> layerIndexByNeuron;
|
||||
public FullyConnectedNetwork(List<Layer> layers) {
|
||||
public FullyConnectedNetwork(Layer[] layers) {
|
||||
this.layers = layers;
|
||||
this.connectionMap = this.createConnectionMap();
|
||||
this.layerIndexByNeuron = this.createNeuronIndex();
|
||||
@@ -53,17 +53,22 @@ public class FullyConnectedNetwork implements Model {
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
||||
for(Layer l : this.layers){
|
||||
l.forEachSynapse(consumer);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
|
||||
for(Layer l : this.layers){
|
||||
l.forEachNeuron(consumer);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
this.layers.getLast().forEachNeuron(consumer);
|
||||
int lastIndex = this.layers.length-1;
|
||||
this.layers[lastIndex].forEachNeuron(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -74,16 +79,16 @@ public class FullyConnectedNetwork implements Model {
|
||||
@Override
|
||||
public int indexInLayerOf(Neuron n) {
|
||||
int layerIndex = this.layerIndexByNeuron.get(n);
|
||||
return this.layers.get(layerIndex).indexInLayerOf(n);
|
||||
return this.layers[layerIndex].indexInLayerOf(n);
|
||||
}
|
||||
|
||||
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
||||
Map<Neuron, List<Neuron>> res = new HashMap<>();
|
||||
|
||||
for (int i = 0; i < this.layers.size() - 1; i++) {
|
||||
for (int i = 0; i < this.layers.length - 1; i++) {
|
||||
List<Neuron> nextLayerNeurons = new ArrayList<>();
|
||||
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
|
||||
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
||||
this.layers[i + 1].forEachNeuron(nextLayerNeurons::add);
|
||||
this.layers[i].forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -91,10 +96,10 @@ public class FullyConnectedNetwork implements Model {
|
||||
private Map<Neuron, Integer> createNeuronIndex() {
|
||||
Map<Neuron, Integer> res = new HashMap<>();
|
||||
AtomicInteger index = new AtomicInteger(0);
|
||||
this.layers.forEach(l -> {
|
||||
for(Layer l : this.layers){
|
||||
l.forEachNeuron(n -> res.put(n, index.get()));
|
||||
index.incrementAndGet();
|
||||
});
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ import java.util.function.Consumer;
|
||||
|
||||
public class Layer implements Model {
|
||||
|
||||
private final List<Neuron> neurons;
|
||||
private final Neuron[] neurons;
|
||||
private final Map<Neuron, Integer> neuronIndex;
|
||||
|
||||
public Layer(List<Neuron> neurons) {
|
||||
public Layer(Neuron[] neurons) {
|
||||
this.neurons = neurons;
|
||||
this.neuronIndex = createNeuronIndex();
|
||||
}
|
||||
@@ -39,7 +39,7 @@ public class Layer implements Model {
|
||||
|
||||
@Override
|
||||
public int neuronCount() {
|
||||
return this.neurons.size();
|
||||
return this.neurons.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -49,17 +49,21 @@ public class Layer implements Model {
|
||||
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
this.neurons.forEach(consumer);
|
||||
for (Neuron n : this.neurons){
|
||||
consumer.accept(n);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
||||
for (Neuron n : this.neurons){
|
||||
n.forEachSynapse(consumer);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
this.neurons.forEach(consumer);
|
||||
this.forEachNeuron(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -70,7 +74,7 @@ public class Layer implements Model {
|
||||
private Map<Neuron, Integer> createNeuronIndex() {
|
||||
Map<Neuron, Integer> res = new HashMap<>();
|
||||
int[] index = {0};
|
||||
this.neurons.forEach(n -> {
|
||||
this.forEachNeuron(n -> {
|
||||
res.put(n, index[0]++);
|
||||
});
|
||||
return res;
|
||||
|
||||
@@ -7,13 +7,13 @@ import java.util.function.Consumer;
|
||||
|
||||
public class Neuron implements Model {
|
||||
|
||||
protected List<Synapse> synapses;
|
||||
protected Synapse[] synapses;
|
||||
protected Bias bias;
|
||||
protected ActivationFunction activationFunction;
|
||||
protected Float output;
|
||||
protected Float weightedSum;
|
||||
|
||||
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
|
||||
public Neuron(Synapse[] synapses, Bias bias, ActivationFunction func){
|
||||
this.synapses = synapses;
|
||||
this.bias = bias;
|
||||
this.activationFunction = func;
|
||||
@@ -26,12 +26,12 @@ public class Neuron implements Model {
|
||||
}
|
||||
|
||||
public void updateWeight(int index, Weight weight) {
|
||||
this.synapses.get(index).setWeight(weight.getValue());
|
||||
this.synapses[index].setWeight(weight.getValue());
|
||||
}
|
||||
|
||||
protected void setInputs(List<Input> inputs){
|
||||
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
|
||||
Synapse syn = this.synapses.get(i);
|
||||
for(int i = 0; i < inputs.size() && i < synapses.length; i++){
|
||||
Synapse syn = this.synapses[i];
|
||||
syn.setInput(inputs.get(i));
|
||||
}
|
||||
}
|
||||
@@ -45,7 +45,7 @@ public class Neuron implements Model {
|
||||
}
|
||||
|
||||
public float getWeight(int index){
|
||||
return this.synapses.get(index).getWeight();
|
||||
return this.synapses[index].getWeight();
|
||||
}
|
||||
|
||||
public float getWeightedSum(){
|
||||
@@ -63,7 +63,7 @@ public class Neuron implements Model {
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
return this.synapses.size()+1; //take the bias into account
|
||||
return this.synapses.length+1; //take the bias into account
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -91,7 +91,9 @@ public class Neuron implements Model {
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
consumer.accept(this.bias);
|
||||
this.synapses.forEach(consumer);
|
||||
for (Synapse syn : this.synapses){
|
||||
consumer.accept(syn);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -30,12 +30,12 @@ public class GradientBackpropagationTraining implements Trainer {
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> {
|
||||
ctx.globalLoss = 0.0F;
|
||||
})
|
||||
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
||||
.withVerbose(true, 100)
|
||||
.withVerbose(false, epoch/10)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user