Minor perfomance improvements

This commit is contained in:
2026-03-30 22:14:33 +02:00
parent fd97d0853c
commit 881088df28
5 changed files with 46 additions and 34 deletions

View File

@@ -11,6 +11,7 @@ import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.io.Console;
import java.util.*; import java.util.*;
public class Main { public class Main {
@@ -40,19 +41,19 @@ public class Main {
Bias bias = new Bias(new Weight()); Bias bias = new Bias(new Weight());
Neuron n = new Neuron(syns, bias, new TanH()); Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH());
neurons.add(n); neurons.add(n);
} }
Layer layer = new Layer(neurons); Layer layer = new Layer(neurons.toArray(new Neuron[0]));
layers.add(layer); layers.add(layer);
} }
FullyConnectedNetwork network = new FullyConnectedNetwork(layers); FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.001F, 1000, network, dataset); trainer.train(0.01F, 5000, network, dataset);
/*GraphVisualizer visualizer = new GraphVisualizer(); GraphVisualizer visualizer = new GraphVisualizer();
for (DataSetEntry entry : dataset) { for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry); List<Float> label = dataset.getLabelsAsFloat(entry);
@@ -71,7 +72,7 @@ public class Main {
} }
visualizer.buildScatterGraph();*/ visualizer.buildScatterGraph();
} }
} }

View File

@@ -14,10 +14,10 @@ import java.util.function.Consumer;
*/ */
public class FullyConnectedNetwork implements Model { public class FullyConnectedNetwork implements Model {
private final List<Layer> layers; private final Layer[] layers;
private final Map<Neuron, List<Neuron>> connectionMap; private final Map<Neuron, List<Neuron>> connectionMap;
private final Map<Neuron, Integer> layerIndexByNeuron; private final Map<Neuron, Integer> layerIndexByNeuron;
public FullyConnectedNetwork(List<Layer> layers) { public FullyConnectedNetwork(Layer[] layers) {
this.layers = layers; this.layers = layers;
this.connectionMap = this.createConnectionMap(); this.connectionMap = this.createConnectionMap();
this.layerIndexByNeuron = this.createNeuronIndex(); this.layerIndexByNeuron = this.createNeuronIndex();
@@ -53,17 +53,22 @@ public class FullyConnectedNetwork implements Model {
@Override @Override
public void forEachSynapse(Consumer<Synapse> consumer) { public void forEachSynapse(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.forEachSynapse(consumer)); for(Layer l : this.layers){
l.forEachSynapse(consumer);
}
} }
@Override @Override
public void forEachNeuron(Consumer<Neuron> consumer) { public void forEachNeuron(Consumer<Neuron> consumer) {
this.layers.forEach(layer -> layer.forEachNeuron(consumer)); for(Layer l : this.layers){
l.forEachNeuron(consumer);
}
} }
@Override @Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) { public void forEachOutputNeurons(Consumer<Neuron> consumer) {
this.layers.getLast().forEachNeuron(consumer); int lastIndex = this.layers.length-1;
this.layers[lastIndex].forEachNeuron(consumer);
} }
@Override @Override
@@ -74,16 +79,16 @@ public class FullyConnectedNetwork implements Model {
@Override @Override
public int indexInLayerOf(Neuron n) { public int indexInLayerOf(Neuron n) {
int layerIndex = this.layerIndexByNeuron.get(n); int layerIndex = this.layerIndexByNeuron.get(n);
return this.layers.get(layerIndex).indexInLayerOf(n); return this.layers[layerIndex].indexInLayerOf(n);
} }
private Map<Neuron, List<Neuron>> createConnectionMap() { private Map<Neuron, List<Neuron>> createConnectionMap() {
Map<Neuron, List<Neuron>> res = new HashMap<>(); Map<Neuron, List<Neuron>> res = new HashMap<>();
for (int i = 0; i < this.layers.size() - 1; i++) { for (int i = 0; i < this.layers.length - 1; i++) {
List<Neuron> nextLayerNeurons = new ArrayList<>(); List<Neuron> nextLayerNeurons = new ArrayList<>();
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add); this.layers[i + 1].forEachNeuron(nextLayerNeurons::add);
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons)); this.layers[i].forEachNeuron(n -> res.put(n, nextLayerNeurons));
} }
return res; return res;
} }
@@ -91,10 +96,10 @@ public class FullyConnectedNetwork implements Model {
private Map<Neuron, Integer> createNeuronIndex() { private Map<Neuron, Integer> createNeuronIndex() {
Map<Neuron, Integer> res = new HashMap<>(); Map<Neuron, Integer> res = new HashMap<>();
AtomicInteger index = new AtomicInteger(0); AtomicInteger index = new AtomicInteger(0);
this.layers.forEach(l -> { for(Layer l : this.layers){
l.forEachNeuron(n -> res.put(n, index.get())); l.forEachNeuron(n -> res.put(n, index.get()));
index.incrementAndGet(); index.incrementAndGet();
}); }
return res; return res;
} }
} }

View File

@@ -10,10 +10,10 @@ import java.util.function.Consumer;
public class Layer implements Model { public class Layer implements Model {
private final List<Neuron> neurons; private final Neuron[] neurons;
private final Map<Neuron, Integer> neuronIndex; private final Map<Neuron, Integer> neuronIndex;
public Layer(List<Neuron> neurons) { public Layer(Neuron[] neurons) {
this.neurons = neurons; this.neurons = neurons;
this.neuronIndex = createNeuronIndex(); this.neuronIndex = createNeuronIndex();
} }
@@ -39,7 +39,7 @@ public class Layer implements Model {
@Override @Override
public int neuronCount() { public int neuronCount() {
return this.neurons.size(); return this.neurons.length;
} }
@Override @Override
@@ -49,17 +49,21 @@ public class Layer implements Model {
@Override @Override
public void forEachNeuron(Consumer<Neuron> consumer) { public void forEachNeuron(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer); for (Neuron n : this.neurons){
consumer.accept(n);
}
} }
@Override @Override
public void forEachSynapse(Consumer<Synapse> consumer) { public void forEachSynapse(Consumer<Synapse> consumer) {
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); for (Neuron n : this.neurons){
n.forEachSynapse(consumer);
}
} }
@Override @Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) { public void forEachOutputNeurons(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer); this.forEachNeuron(consumer);
} }
@Override @Override
@@ -70,7 +74,7 @@ public class Layer implements Model {
private Map<Neuron, Integer> createNeuronIndex() { private Map<Neuron, Integer> createNeuronIndex() {
Map<Neuron, Integer> res = new HashMap<>(); Map<Neuron, Integer> res = new HashMap<>();
int[] index = {0}; int[] index = {0};
this.neurons.forEach(n -> { this.forEachNeuron(n -> {
res.put(n, index[0]++); res.put(n, index[0]++);
}); });
return res; return res;

View File

@@ -7,13 +7,13 @@ import java.util.function.Consumer;
public class Neuron implements Model { public class Neuron implements Model {
protected List<Synapse> synapses; protected Synapse[] synapses;
protected Bias bias; protected Bias bias;
protected ActivationFunction activationFunction; protected ActivationFunction activationFunction;
protected Float output; protected Float output;
protected Float weightedSum; protected Float weightedSum;
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){ public Neuron(Synapse[] synapses, Bias bias, ActivationFunction func){
this.synapses = synapses; this.synapses = synapses;
this.bias = bias; this.bias = bias;
this.activationFunction = func; this.activationFunction = func;
@@ -26,12 +26,12 @@ public class Neuron implements Model {
} }
public void updateWeight(int index, Weight weight) { public void updateWeight(int index, Weight weight) {
this.synapses.get(index).setWeight(weight.getValue()); this.synapses[index].setWeight(weight.getValue());
} }
protected void setInputs(List<Input> inputs){ protected void setInputs(List<Input> inputs){
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){ for(int i = 0; i < inputs.size() && i < synapses.length; i++){
Synapse syn = this.synapses.get(i); Synapse syn = this.synapses[i];
syn.setInput(inputs.get(i)); syn.setInput(inputs.get(i));
} }
} }
@@ -45,7 +45,7 @@ public class Neuron implements Model {
} }
public float getWeight(int index){ public float getWeight(int index){
return this.synapses.get(index).getWeight(); return this.synapses[index].getWeight();
} }
public float getWeightedSum(){ public float getWeightedSum(){
@@ -63,7 +63,7 @@ public class Neuron implements Model {
@Override @Override
public int synCount() { public int synCount() {
return this.synapses.size()+1; //take the bias into account return this.synapses.length+1; //take the bias into account
} }
@Override @Override
@@ -91,7 +91,9 @@ public class Neuron implements Model {
@Override @Override
public void forEachSynapse(Consumer<Synapse> consumer) { public void forEachSynapse(Consumer<Synapse> consumer) {
consumer.accept(this.bias); consumer.accept(this.bias);
this.synapses.forEach(consumer); for (Synapse syn : this.synapses){
consumer.accept(syn);
}
} }
@Override @Override

View File

@@ -30,12 +30,12 @@ public class GradientBackpropagationTraining implements Trainer {
); );
new TrainingPipeline(steps) new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.0001F || ctx.epoch > epoch) .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
.beforeEpoch(ctx -> { .beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F; ctx.globalLoss = 0.0F;
}) })
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size()) .afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
.withVerbose(true, 100) .withVerbose(false, epoch/10)
.withTimeMeasurement(true) .withTimeMeasurement(true)
.run(context); .run(context);
} }