Minor perfomance improvements

This commit is contained in:
2026-03-30 22:14:33 +02:00
parent fd97d0853c
commit 881088df28
5 changed files with 46 additions and 34 deletions

View File

@@ -11,6 +11,7 @@ import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.io.Console;
import java.util.*;
public class Main {
@@ -40,19 +41,19 @@ public class Main {
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(syns, bias, new TanH());
Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH());
neurons.add(n);
}
Layer layer = new Layer(neurons);
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
layers.add(layer);
}
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.001F, 1000, network, dataset);
trainer.train(0.01F, 5000, network, dataset);
/*GraphVisualizer visualizer = new GraphVisualizer();
GraphVisualizer visualizer = new GraphVisualizer();
for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
@@ -71,7 +72,7 @@ public class Main {
}
visualizer.buildScatterGraph();*/
visualizer.buildScatterGraph();
}
}

View File

@@ -14,10 +14,10 @@ import java.util.function.Consumer;
*/
public class FullyConnectedNetwork implements Model {
private final List<Layer> layers;
private final Layer[] layers;
private final Map<Neuron, List<Neuron>> connectionMap;
private final Map<Neuron, Integer> layerIndexByNeuron;
public FullyConnectedNetwork(List<Layer> layers) {
public FullyConnectedNetwork(Layer[] layers) {
this.layers = layers;
this.connectionMap = this.createConnectionMap();
this.layerIndexByNeuron = this.createNeuronIndex();
@@ -53,17 +53,22 @@ public class FullyConnectedNetwork implements Model {
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
for(Layer l : this.layers){
l.forEachSynapse(consumer);
}
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
for(Layer l : this.layers){
l.forEachNeuron(consumer);
}
}
@Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
this.layers.getLast().forEachNeuron(consumer);
int lastIndex = this.layers.length-1;
this.layers[lastIndex].forEachNeuron(consumer);
}
@Override
@@ -74,16 +79,16 @@ public class FullyConnectedNetwork implements Model {
@Override
public int indexInLayerOf(Neuron n) {
int layerIndex = this.layerIndexByNeuron.get(n);
return this.layers.get(layerIndex).indexInLayerOf(n);
return this.layers[layerIndex].indexInLayerOf(n);
}
private Map<Neuron, List<Neuron>> createConnectionMap() {
Map<Neuron, List<Neuron>> res = new HashMap<>();
for (int i = 0; i < this.layers.size() - 1; i++) {
for (int i = 0; i < this.layers.length - 1; i++) {
List<Neuron> nextLayerNeurons = new ArrayList<>();
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
this.layers[i + 1].forEachNeuron(nextLayerNeurons::add);
this.layers[i].forEachNeuron(n -> res.put(n, nextLayerNeurons));
}
return res;
}
@@ -91,10 +96,10 @@ public class FullyConnectedNetwork implements Model {
private Map<Neuron, Integer> createNeuronIndex() {
Map<Neuron, Integer> res = new HashMap<>();
AtomicInteger index = new AtomicInteger(0);
this.layers.forEach(l -> {
for(Layer l : this.layers){
l.forEachNeuron(n -> res.put(n, index.get()));
index.incrementAndGet();
});
}
return res;
}
}

View File

@@ -10,10 +10,10 @@ import java.util.function.Consumer;
public class Layer implements Model {
private final List<Neuron> neurons;
private final Neuron[] neurons;
private final Map<Neuron, Integer> neuronIndex;
public Layer(List<Neuron> neurons) {
public Layer(Neuron[] neurons) {
this.neurons = neurons;
this.neuronIndex = createNeuronIndex();
}
@@ -39,7 +39,7 @@ public class Layer implements Model {
@Override
public int neuronCount() {
return this.neurons.size();
return this.neurons.length;
}
@Override
@@ -49,17 +49,21 @@ public class Layer implements Model {
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer);
for (Neuron n : this.neurons){
consumer.accept(n);
}
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
for (Neuron n : this.neurons){
n.forEachSynapse(consumer);
}
}
@Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer);
this.forEachNeuron(consumer);
}
@Override
@@ -70,7 +74,7 @@ public class Layer implements Model {
private Map<Neuron, Integer> createNeuronIndex() {
Map<Neuron, Integer> res = new HashMap<>();
int[] index = {0};
this.neurons.forEach(n -> {
this.forEachNeuron(n -> {
res.put(n, index[0]++);
});
return res;

View File

@@ -7,13 +7,13 @@ import java.util.function.Consumer;
public class Neuron implements Model {
protected List<Synapse> synapses;
protected Synapse[] synapses;
protected Bias bias;
protected ActivationFunction activationFunction;
protected Float output;
protected Float weightedSum;
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
public Neuron(Synapse[] synapses, Bias bias, ActivationFunction func){
this.synapses = synapses;
this.bias = bias;
this.activationFunction = func;
@@ -26,12 +26,12 @@ public class Neuron implements Model {
}
public void updateWeight(int index, Weight weight) {
this.synapses.get(index).setWeight(weight.getValue());
this.synapses[index].setWeight(weight.getValue());
}
protected void setInputs(List<Input> inputs){
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
Synapse syn = this.synapses.get(i);
for(int i = 0; i < inputs.size() && i < synapses.length; i++){
Synapse syn = this.synapses[i];
syn.setInput(inputs.get(i));
}
}
@@ -45,7 +45,7 @@ public class Neuron implements Model {
}
public float getWeight(int index){
return this.synapses.get(index).getWeight();
return this.synapses[index].getWeight();
}
public float getWeightedSum(){
@@ -63,7 +63,7 @@ public class Neuron implements Model {
@Override
public int synCount() {
return this.synapses.size()+1; //take the bias into account
return this.synapses.length+1; //take the bias into account
}
@Override
@@ -91,7 +91,9 @@ public class Neuron implements Model {
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
consumer.accept(this.bias);
this.synapses.forEach(consumer);
for (Synapse syn : this.synapses){
consumer.accept(syn);
}
}
@Override

View File

@@ -30,12 +30,12 @@ public class GradientBackpropagationTraining implements Trainer {
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch)
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
.beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F;
})
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
.withVerbose(true, 100)
.withVerbose(false, epoch/10)
.withTimeMeasurement(true)
.run(context);
}