Reworked synapses data structure

This commit is contained in:
2026-04-01 22:48:06 +02:00
parent 4441b149f9
commit 5ddf6dc580
13 changed files with 77 additions and 94 deletions

View File

@@ -26,7 +26,7 @@ public class Main {
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
int[] neuronPerLayer = new int[]{10, 5, 10, dataset.getNbrLabels()}; int[] neuronPerLayer = new int[]{100, 100, 50, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput); FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
@@ -34,7 +34,7 @@ public class Main {
System.out.println(network.synCount()); System.out.println(network.synCount());
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.01F, 2000, network, dataset); trainer.train(0.001F, 2000, network, dataset);
plotGraph(dataset, network); plotGraph(dataset, network);
} }

View File

@@ -12,7 +12,7 @@ public interface Model {
int neuronCount(); int neuronCount();
int indexInLayerOf(Neuron n); int indexInLayerOf(Neuron n);
void forEachNeuron(Consumer<Neuron> consumer); void forEachNeuron(Consumer<Neuron> consumer);
void forEachSynapse(Consumer<Synapse> consumer); //void forEachSynapse(Consumer<Synapse> consumer);
void forEachOutputNeurons(Consumer<Neuron> consumer); void forEachOutputNeurons(Consumer<Neuron> consumer);
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer); void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
List<Float> predict(List<Input> inputs); List<Float> predict(List<Input> inputs);

View File

@@ -51,13 +51,6 @@ public class FullyConnectedNetwork implements Model {
return res; return res;
} }
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
for(Layer l : this.layers){
l.forEachSynapse(consumer);
}
}
@Override @Override
public void forEachNeuron(Consumer<Neuron> consumer) { public void forEachNeuron(Consumer<Neuron> consumer) {
for(Layer l : this.layers){ for(Layer l : this.layers){

View File

@@ -54,12 +54,12 @@ public class Layer implements Model {
} }
} }
@Override /*@Override
public void forEachSynapse(Consumer<Synapse> consumer) { public void forEachSynapse(Consumer<Synapse> consumer) {
for (Neuron n : this.neurons){ for (Neuron n : this.neurons){
n.forEachSynapse(consumer); n.forEachSynapse(consumer);
} }
} }*/
@Override @Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) { public void forEachOutputNeurons(Consumer<Neuron> consumer) {

View File

@@ -7,40 +7,37 @@ import java.util.function.Consumer;
public class Neuron implements Model { public class Neuron implements Model {
protected int id; private final int id;
protected final Synapse[] synapses; private float output;
protected Bias bias; private final float[] weights;
protected ActivationFunction activationFunction; private final float[] inputs;
protected Float output; private final ActivationFunction activationFunction;
protected Float weightedSum;
protected final float[] weights;
protected final float[] inputs;
public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){ public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
this.id = id; this.id = id;
this.synapses = synapses;
this.bias = bias;
this.activationFunction = func; this.activationFunction = func;
this.output = null;
this.weightedSum = null;
weights = new float[synapses.length]; weights = new float[synapses.length+1]; //takes the bias into account
inputs = new float[synapses.length]; inputs = new float[synapses.length+1]; //takes the bias into account
weights[0] = bias.getWeight();
inputs[0] = bias.getInput();
for (int i = 0; i < synapses.length; i++){
weights[i+1] = synapses[i].getWeight();
inputs[i+1] = synapses[i].getInput();
}
} }
public void updateBias(Weight weight) { public void setWeight(int index, float value) {
this.bias.setWeight(weight.getValue()); this.weights[index] = value;
} }
public void updateWeight(int index, Weight weight) { public float getWeight(int index) {
this.synapses[index].setWeight(weight.getValue()); return this.weights[index];
} }
protected void setInputs(List<Input> inputs){ public float getInput(int index) {
for(int i = 0; i < inputs.size() && i < synapses.length; i++){ return this.inputs[index];
Synapse syn = this.synapses[i];
syn.setInput(inputs.get(i));
}
} }
public ActivationFunction getActivationFunction(){ public ActivationFunction getActivationFunction(){
@@ -51,21 +48,13 @@ public class Neuron implements Model {
return this.output; return this.output;
} }
public float getWeight(int index){
return this.synapses[index].getWeight();
}
public float getWeightedSum(){
return this.weightedSum;
}
public float calculateWeightedSum() { public float calculateWeightedSum() {
this.weightedSum = 0F; int count = synCount();
this.weightedSum += this.bias.getWeight() * this.bias.getInput(); float weightedSum = 0F;
for(Synapse syn : this.synapses){ for (int i = 0; i < count; i++){
this.weightedSum += syn.getWeight() * syn.getInput(); weightedSum += weights[i] * inputs[i];
} }
return this.weightedSum; return weightedSum;
} }
public int getId(){ public int getId(){
@@ -74,7 +63,7 @@ public class Neuron implements Model {
@Override @Override
public int synCount() { public int synCount() {
return this.synapses.length+1; //take the bias into account return this.weights.length;
} }
@Override @Override
@@ -99,14 +88,6 @@ public class Neuron implements Model {
consumer.accept(this); consumer.accept(this);
} }
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
consumer.accept(this.bias);
for (Synapse syn : this.synapses){
consumer.accept(syn);
}
}
@Override @Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) { public void forEachOutputNeurons(Consumer<Neuron> consumer) {
consumer.accept(this); consumer.accept(this);
@@ -116,4 +97,11 @@ public class Neuron implements Model {
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) { public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
throw new UnsupportedOperationException("Neurons have no connection with themselves"); throw new UnsupportedOperationException("Neurons have no connection with themselves");
} }
/**
 * Copies the given input values into this neuron's input buffer.
 *
 * <p>Index 0 of the buffer is filled from the bias in the constructor, so the
 * value at position {@code i} of {@code values} is stored at index {@code i+1}.
 * Values that do not fit in the buffer are ignored instead of causing an
 * {@link ArrayIndexOutOfBoundsException}.
 *
 * @param values the input values to store, in synapse order
 */
private void setInputs(List<Input> values){
    // One slot is taken by the bias, leaving inputs.length - 1 writable slots.
    int limit = Math.min(values.size(), inputs.length - 1);
    for(int i = 0; i < limit; i++){
        inputs[i+1] = values.get(i).getValue();
    }
}
} }

View File

@@ -86,7 +86,7 @@ public class TrainingPipeline {
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0); System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
} }
if(this.visualization) this.visualize(ctx); //if(this.visualization) this.visualize(ctx);
} }
private void executeSteps(TrainingContext ctx){ private void executeSteps(TrainingContext ctx){
@@ -109,7 +109,7 @@ public class TrainingPipeline {
} }
} }
private void visualize(TrainingContext ctx){ /*private void visualize(TrainingContext ctx){
AtomicInteger neuronIndex = new AtomicInteger(0); AtomicInteger neuronIndex = new AtomicInteger(0);
ctx.model.forEachNeuron(n -> { ctx.model.forEachNeuron(n -> {
List<Float> weights = new ArrayList<>(); List<Float> weights = new ArrayList<>();
@@ -129,6 +129,6 @@ public class TrainingPipeline {
i++; i++;
} }
this.visualizer.buildLineGraph(); this.visualizer.buildLineGraph();
} }*/
} }

View File

@@ -14,12 +14,14 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStep {
@Override @Override
public void run() { public void run() {
AtomicInteger i = new AtomicInteger(0); int[] globalSynIndex = {0};
context.model.forEachSynapse(syn -> { context.model.forEachNeuron(n -> {
float corrector = context.correctorTerms.get(i.get()); for(int i = 0; i < n.synCount(); i++){
float c = syn.getWeight() + corrector; float corrector = context.correctorTerms.get(globalSynIndex[0]);
syn.setWeight(c); float c = n.getWeight(i) + corrector;
i.incrementAndGet(); n.setWeight(i, c);
globalSynIndex[0]++;
}
}); });
} }
} }

View File

@@ -22,13 +22,12 @@ public class GradientDescentErrorStrategy implements AlgorithmStep {
context.model.forEachNeuron(neuron -> { context.model.forEachNeuron(neuron -> {
float correspondingDelta = context.deltas[neuronIndex.get()]; float correspondingDelta = context.deltas[neuronIndex.get()];
neuron.forEachSynapse(syn -> { for(int i = 0; i < neuron.synCount(); i++){
float corrector = context.correctorTerms.get(synIndex.get()); float corrector = context.correctorTerms.get(synIndex.get());
corrector += context.learningRate * correspondingDelta * syn.getInput(); corrector += context.learningRate * correspondingDelta * neuron.getInput(i);
context.correctorTerms.set(synIndex.get(), corrector); context.correctorTerms.set(synIndex.get(), corrector);
synIndex.incrementAndGet(); synIndex.incrementAndGet();
}); }
neuronIndex.incrementAndGet(); neuronIndex.incrementAndGet();
}); });

View File

@@ -21,11 +21,11 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
int[] synIndex = {0}; int[] synIndex = {0};
context.model.forEachNeuron(n -> { context.model.forEachNeuron(n -> {
float signal = context.errorSignals[n.getId()]; float signal = context.errorSignals[n.getId()];
n.forEachSynapse(syn -> { for (int i = 0; i < n.synCount(); i++){
inputs[synIndex[0]] = syn.getInput(); inputs[synIndex[0]] = n.getInput(i);
signals[synIndex[0]] = signal; signals[synIndex[0]] = signal;
synIndex[0]++; synIndex[0]++;
}); }
}); });
float lr = context.learningRate; float lr = context.learningRate;
@@ -44,13 +44,13 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
} }
private void syncWeights() { private void syncWeights() {
int[] i = {0}; int[] synIndex = {0};
context.model.forEachNeuron(n -> { context.model.forEachNeuron(n -> {
n.forEachSynapse(syn -> { for (int i = 0; i < n.synCount(); i++) {
syn.setWeight(syn.getWeight() + context.correctionBuffer[i[0]]); n.setWeight(i, n.getWeight(i) + context.correctionBuffer[synIndex[0]]);
context.correctionBuffer[i[0]] = 0f; context.correctionBuffer[synIndex[0]] = 0f;
i[0]++; synIndex[0]++;
}); }
}); });
} }
} }

View File

@@ -13,7 +13,7 @@ public class OutputLayerErrorStep implements AlgorithmStep {
public OutputLayerErrorStep(GradientBackpropagationContext context){ public OutputLayerErrorStep(GradientBackpropagationContext context){
this.context = context; this.context = context;
this.expectations = new float[context.model.neuronCount()]; this.expectations = new float[context.dataset.getNbrLabels()];
} }
@Override @Override

View File

@@ -18,17 +18,16 @@ public class SimpleCorrectionStep implements AlgorithmStep {
public void run() { public void run() {
if(context.expectations.equals(context.predictions)) return; if(context.expectations.equals(context.predictions)) return;
AtomicInteger neuronIndex = new AtomicInteger(0); AtomicInteger neuronIndex = new AtomicInteger(0);
AtomicInteger synIndex = new AtomicInteger(0);
context.model.forEachNeuron(neuron -> { context.model.forEachNeuron(neuron -> {
float correspondingDelta = context.deltas[neuronIndex.get()]; float correspondingDelta = context.deltas[neuronIndex.get()];
neuron.forEachSynapse(syn -> {
float currentW = syn.getWeight(); for(int i = 0; i < neuron.synCount(); i++){
float currentInput = syn.getInput(); float currentW = neuron.getWeight(i);
float currentInput = neuron.getInput(i);
float newValue = currentW + (context.learningRate * correspondingDelta * currentInput); float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
syn.setWeight(newValue); neuron.setWeight(i, newValue);
synIndex.incrementAndGet(); }
});
neuronIndex.incrementAndGet(); neuronIndex.incrementAndGet();
}); });
} }

View File

@@ -17,7 +17,7 @@ public class GradientBackpropagationTraining implements Trainer {
@Override @Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) { public void train(float learningRate, int epoch, Model model, DataSet dataset) {
GradientBackpropagationContext context = GradientBackpropagationContext context =
new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()/3); new GradientBackpropagationContext(model, dataset, learningRate, dataset.size());
List<AlgorithmStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context), new SimplePredictionStep(context),
@@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer {
.afterEpoch(ctx -> { .afterEpoch(ctx -> {
ctx.globalLoss /= dataset.size(); ctx.globalLoss /= dataset.size();
}) })
.withVerbose(true,epoch/10) .withVerbose(false,epoch/10)
.withTimeMeasurement(true) .withTimeMeasurement(true)
.run(context); .run(context);
} }

View File

@@ -41,7 +41,9 @@ public class GradientDescentTraining implements Trainer {
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx; GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
gdCtx.globalLoss = 0.0F; gdCtx.globalLoss = 0.0F;
gdCtx.correctorTerms.clear(); gdCtx.correctorTerms.clear();
gdCtx.model.forEachSynapse(syn -> gdCtx.correctorTerms.add(0F)); for(int i = 0; i < gdCtx.model.synCount(); i++){
gdCtx.correctorTerms.add(0F);
}
}) })
.afterEpoch(ctx -> { .afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size(); context.globalLoss /= context.dataset.size();