diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 723da24..58de78f 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -40,19 +40,19 @@ public class Main { Bias bias = new Bias(new Weight()); - Neuron n = new Neuron(syns, bias, new TanH()); + Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH()); neurons.add(n); } - Layer layer = new Layer(neurons); + Layer layer = new Layer(neurons.toArray(new Neuron[0])); layers.add(layer); } - FullyConnectedNetwork network = new FullyConnectedNetwork(layers); + FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0])); Trainer trainer = new GradientBackpropagationTraining(); - trainer.train(0.001F, 1000, network, dataset); + trainer.train(0.01F, 5000, network, dataset); - /*GraphVisualizer visualizer = new GraphVisualizer(); + GraphVisualizer visualizer = new GraphVisualizer(); for (DataSetEntry entry : dataset) { List label = dataset.getLabelsAsFloat(entry); @@ -71,7 +71,7 @@ public class Main { } - visualizer.buildScatterGraph();*/ + visualizer.buildScatterGraph(); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java index f2a63b8..dfd51ef 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java @@ -14,10 +14,10 @@ import java.util.function.Consumer; */ public class FullyConnectedNetwork implements Model { - private final List layers; + private final Layer[] layers; private final Map> 
connectionMap; private final Map layerIndexByNeuron; - public FullyConnectedNetwork(List layers) { + public FullyConnectedNetwork(Layer[] layers) { this.layers = layers; this.connectionMap = this.createConnectionMap(); this.layerIndexByNeuron = this.createNeuronIndex(); @@ -53,17 +53,22 @@ public class FullyConnectedNetwork implements Model { @Override public void forEachSynapse(Consumer consumer) { - this.layers.forEach(layer -> layer.forEachSynapse(consumer)); + for(Layer l : this.layers){ + l.forEachSynapse(consumer); + } } @Override public void forEachNeuron(Consumer consumer) { - this.layers.forEach(layer -> layer.forEachNeuron(consumer)); + for(Layer l : this.layers){ + l.forEachNeuron(consumer); + } } @Override public void forEachOutputNeurons(Consumer consumer) { - this.layers.getLast().forEachNeuron(consumer); + int lastIndex = this.layers.length-1; + this.layers[lastIndex].forEachNeuron(consumer); } @Override @@ -74,16 +79,16 @@ public class FullyConnectedNetwork implements Model { @Override public int indexInLayerOf(Neuron n) { int layerIndex = this.layerIndexByNeuron.get(n); - return this.layers.get(layerIndex).indexInLayerOf(n); + return this.layers[layerIndex].indexInLayerOf(n); } private Map> createConnectionMap() { Map> res = new HashMap<>(); - for (int i = 0; i < this.layers.size() - 1; i++) { + for (int i = 0; i < this.layers.length - 1; i++) { List nextLayerNeurons = new ArrayList<>(); - this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add); - this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons)); + this.layers[i + 1].forEachNeuron(nextLayerNeurons::add); + this.layers[i].forEachNeuron(n -> res.put(n, nextLayerNeurons)); } return res; } @@ -91,10 +96,10 @@ public class FullyConnectedNetwork implements Model { private Map createNeuronIndex() { Map res = new HashMap<>(); AtomicInteger index = new AtomicInteger(0); - this.layers.forEach(l -> { + for(Layer l : this.layers){ l.forEachNeuron(n -> res.put(n, index.get())); 
index.incrementAndGet(); - }); + } return res; } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index 36c487d..3d3039f 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -10,10 +10,10 @@ import java.util.function.Consumer; public class Layer implements Model { - private final List neurons; + private final Neuron[] neurons; private final Map neuronIndex; - public Layer(List neurons) { + public Layer(Neuron[] neurons) { this.neurons = neurons; this.neuronIndex = createNeuronIndex(); } @@ -39,7 +39,7 @@ public class Layer implements Model { @Override public int neuronCount() { - return this.neurons.size(); + return this.neurons.length; } @Override @@ -49,17 +49,21 @@ public class Layer implements Model { @Override public void forEachNeuron(Consumer consumer) { - this.neurons.forEach(consumer); + for (Neuron n : this.neurons){ + consumer.accept(n); + } } @Override public void forEachSynapse(Consumer consumer) { - this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); + for (Neuron n : this.neurons){ + n.forEachSynapse(consumer); + } } @Override public void forEachOutputNeurons(Consumer consumer) { - this.neurons.forEach(consumer); + this.forEachNeuron(consumer); } @Override @@ -70,7 +74,7 @@ public class Layer implements Model { private Map createNeuronIndex() { Map res = new HashMap<>(); int[] index = {0}; - this.neurons.forEach(n -> { + this.forEachNeuron(n -> { res.put(n, index[0]++); }); return res; diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java index 1ca4327..587febc 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -7,13 +7,13 @@ import java.util.function.Consumer; public class 
Neuron implements Model { - protected List synapses; + protected Synapse[] synapses; protected Bias bias; protected ActivationFunction activationFunction; protected Float output; protected Float weightedSum; - public Neuron(List synapses, Bias bias, ActivationFunction func){ + public Neuron(Synapse[] synapses, Bias bias, ActivationFunction func){ this.synapses = synapses; this.bias = bias; this.activationFunction = func; @@ -26,12 +26,12 @@ public class Neuron implements Model { } public void updateWeight(int index, Weight weight) { - this.synapses.get(index).setWeight(weight.getValue()); + this.synapses[index].setWeight(weight.getValue()); } protected void setInputs(List inputs){ - for(int i = 0; i < inputs.size() && i < synapses.size(); i++){ - Synapse syn = this.synapses.get(i); + for(int i = 0; i < inputs.size() && i < synapses.length; i++){ + Synapse syn = this.synapses[i]; syn.setInput(inputs.get(i)); } } @@ -45,7 +45,7 @@ public class Neuron implements Model { } public float getWeight(int index){ - return this.synapses.get(index).getWeight(); + return this.synapses[index].getWeight(); } public float getWeightedSum(){ @@ -63,7 +63,7 @@ public class Neuron implements Model { @Override public int synCount() { - return this.synapses.size()+1; //take the bias into account + return this.synapses.length+1; //take the bias into account } @Override @@ -91,7 +91,9 @@ public class Neuron implements Model { @Override public void forEachSynapse(Consumer consumer) { consumer.accept(this.bias); - this.synapses.forEach(consumer); + for (Synapse syn : this.synapses){ + consumer.accept(syn); + } } @Override diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index 833ae01..2383b95 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ 
b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -30,12 +30,12 @@ public class GradientBackpropagationTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch) + .stopCondition(ctx -> ctx.globalLoss <= 1e-6F || ctx.epoch > epoch) .beforeEpoch(ctx -> { ctx.globalLoss = 0.0F; }) .afterEpoch(ctx -> ctx.globalLoss /= dataset.size()) - .withVerbose(true, 100) + .withVerbose(false, Math.max(1, epoch / 10)) .withTimeMeasurement(true) .run(context); }