Reworked synapses data structure

This commit is contained in:
2026-04-01 22:48:06 +02:00
parent 4441b149f9
commit 5ddf6dc580
13 changed files with 77 additions and 94 deletions

View File

@@ -26,7 +26,7 @@ public class Main {
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
int[] neuronPerLayer = new int[]{10, 5, 10, dataset.getNbrLabels()};
int[] neuronPerLayer = new int[]{100, 100, 50, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
@@ -34,7 +34,7 @@ public class Main {
System.out.println(network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.01F, 2000, network, dataset);
trainer.train(0.001F, 2000, network, dataset);
plotGraph(dataset, network);
}

View File

@@ -12,7 +12,7 @@ public interface Model {
int neuronCount();
int indexInLayerOf(Neuron n);
void forEachNeuron(Consumer<Neuron> consumer);
void forEachSynapse(Consumer<Synapse> consumer);
//void forEachSynapse(Consumer<Synapse> consumer);
void forEachOutputNeurons(Consumer<Neuron> consumer);
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
List<Float> predict(List<Input> inputs);

View File

@@ -51,13 +51,6 @@ public class FullyConnectedNetwork implements Model {
return res;
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
for(Layer l : this.layers){
l.forEachSynapse(consumer);
}
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
for(Layer l : this.layers){

View File

@@ -54,12 +54,12 @@ public class Layer implements Model {
}
}
@Override
/*@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
for (Neuron n : this.neurons){
n.forEachSynapse(consumer);
}
}
}*/
@Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) {

View File

@@ -7,40 +7,37 @@ import java.util.function.Consumer;
public class Neuron implements Model {
protected int id;
protected final Synapse[] synapses;
protected Bias bias;
protected ActivationFunction activationFunction;
protected Float output;
protected Float weightedSum;
protected final float[] weights;
protected final float[] inputs;
private final int id;
private float output;
private final float[] weights;
private final float[] inputs;
private final ActivationFunction activationFunction;
public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
this.id = id;
this.synapses = synapses;
this.bias = bias;
this.activationFunction = func;
this.output = null;
this.weightedSum = null;
weights = new float[synapses.length];
inputs = new float[synapses.length];
weights = new float[synapses.length+1]; //takes the bias into account
inputs = new float[synapses.length+1]; //takes the bias into account
weights[0] = bias.getWeight();
inputs[0] = bias.getInput();
for (int i = 0; i < synapses.length; i++){
weights[i+1] = synapses[i].getWeight();
inputs[i+1] = synapses[i].getInput();
}
}
public void updateBias(Weight weight) {
this.bias.setWeight(weight.getValue());
public void setWeight(int index, float value) {
this.weights[index] = value;
}
public void updateWeight(int index, Weight weight) {
this.synapses[index].setWeight(weight.getValue());
public float getWeight(int index) {
return this.weights[index];
}
protected void setInputs(List<Input> inputs){
for(int i = 0; i < inputs.size() && i < synapses.length; i++){
Synapse syn = this.synapses[i];
syn.setInput(inputs.get(i));
}
public float getInput(int index) {
return this.inputs[index];
}
public ActivationFunction getActivationFunction(){
@@ -51,21 +48,13 @@ public class Neuron implements Model {
return this.output;
}
public float getWeight(int index){
return this.synapses[index].getWeight();
}
public float getWeightedSum(){
return this.weightedSum;
}
public float calculateWeightedSum() {
this.weightedSum = 0F;
this.weightedSum += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : this.synapses){
this.weightedSum += syn.getWeight() * syn.getInput();
int count = synCount();
float weightedSum = 0F;
for (int i = 0; i < count; i++){
weightedSum += weights[i] * inputs[i];
}
return this.weightedSum;
return weightedSum;
}
public int getId(){
@@ -74,7 +63,7 @@ public class Neuron implements Model {
@Override
public int synCount() {
return this.synapses.length+1; //take the bias into account
return this.weights.length;
}
@Override
@@ -99,14 +88,6 @@ public class Neuron implements Model {
consumer.accept(this);
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
consumer.accept(this.bias);
for (Synapse syn : this.synapses){
consumer.accept(syn);
}
}
@Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
consumer.accept(this);
@@ -116,4 +97,11 @@ public class Neuron implements Model {
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
    // Not supported on a single neuron: this call always fails, and neither argument is used.
    final String reason = "Neurons have no connection with themselves";
    throw new UnsupportedOperationException(reason);
}
/**
 * Copies the given input values into this neuron's {@code inputs} buffer.
 *
 * <p>Index 0 of the buffer holds the bias input, so the i-th value is written to
 * index {@code i + 1}. Values that do not fit in the remaining slots are ignored
 * instead of overflowing the buffer.
 *
 * @param values the input values to store, in order
 */
private void setInputs(List<Input> values){
    // The buffer has inputs.length - 1 non-bias slots; cap the copy there so that an
    // oversized list cannot trigger an ArrayIndexOutOfBoundsException.
    int limit = Math.min(values.size(), inputs.length - 1);
    for(int i = 0; i < limit; i++){
        inputs[i+1] = values.get(i).getValue();
    }
}
}

View File

@@ -86,7 +86,7 @@ public class TrainingPipeline {
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
}
if(this.visualization) this.visualize(ctx);
//if(this.visualization) this.visualize(ctx);
}
private void executeSteps(TrainingContext ctx){
@@ -109,7 +109,7 @@ public class TrainingPipeline {
}
}
private void visualize(TrainingContext ctx){
/*private void visualize(TrainingContext ctx){
AtomicInteger neuronIndex = new AtomicInteger(0);
ctx.model.forEachNeuron(n -> {
List<Float> weights = new ArrayList<>();
@@ -129,6 +129,6 @@ public class TrainingPipeline {
i++;
}
this.visualizer.buildLineGraph();
}
}*/
}

View File

@@ -14,12 +14,14 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStep {
@Override
public void run() {
AtomicInteger i = new AtomicInteger(0);
context.model.forEachSynapse(syn -> {
float corrector = context.correctorTerms.get(i.get());
float c = syn.getWeight() + corrector;
syn.setWeight(c);
i.incrementAndGet();
int[] globalSynIndex = {0};
context.model.forEachNeuron(n -> {
for(int i = 0; i < n.synCount(); i++){
float corrector = context.correctorTerms.get(globalSynIndex[0]);
float c = n.getWeight(i) + corrector;
n.setWeight(i, c);
globalSynIndex[0]++;
}
});
}
}

View File

@@ -22,13 +22,12 @@ public class GradientDescentErrorStrategy implements AlgorithmStep {
context.model.forEachNeuron(neuron -> {
float correspondingDelta = context.deltas[neuronIndex.get()];
neuron.forEachSynapse(syn -> {
for(int i = 0; i < neuron.synCount(); i++){
float corrector = context.correctorTerms.get(synIndex.get());
corrector += context.learningRate * correspondingDelta * syn.getInput();
corrector += context.learningRate * correspondingDelta * neuron.getInput(i);
context.correctorTerms.set(synIndex.get(), corrector);
synIndex.incrementAndGet();
});
}
neuronIndex.incrementAndGet();
});

View File

@@ -21,11 +21,11 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
int[] synIndex = {0};
context.model.forEachNeuron(n -> {
float signal = context.errorSignals[n.getId()];
n.forEachSynapse(syn -> {
inputs[synIndex[0]] = syn.getInput();
for (int i = 0; i < n.synCount(); i++){
inputs[synIndex[0]] = n.getInput(i);
signals[synIndex[0]] = signal;
synIndex[0]++;
});
}
});
float lr = context.learningRate;
@@ -44,13 +44,13 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
}
private void syncWeights() {
int[] i = {0};
int[] synIndex = {0};
context.model.forEachNeuron(n -> {
n.forEachSynapse(syn -> {
syn.setWeight(syn.getWeight() + context.correctionBuffer[i[0]]);
context.correctionBuffer[i[0]] = 0f;
i[0]++;
});
for (int i = 0; i < n.synCount(); i++) {
n.setWeight(i, n.getWeight(i) + context.correctionBuffer[synIndex[0]]);
context.correctionBuffer[synIndex[0]] = 0f;
synIndex[0]++;
}
});
}
}

View File

@@ -13,7 +13,7 @@ public class OutputLayerErrorStep implements AlgorithmStep {
public OutputLayerErrorStep(GradientBackpropagationContext context){
this.context = context;
this.expectations = new float[context.model.neuronCount()];
this.expectations = new float[context.dataset.getNbrLabels()];
}
@Override

View File

@@ -18,17 +18,16 @@ public class SimpleCorrectionStep implements AlgorithmStep {
public void run() {
if(context.expectations.equals(context.predictions)) return;
AtomicInteger neuronIndex = new AtomicInteger(0);
AtomicInteger synIndex = new AtomicInteger(0);
context.model.forEachNeuron(neuron -> {
float correspondingDelta = context.deltas[neuronIndex.get()];
neuron.forEachSynapse(syn -> {
float currentW = syn.getWeight();
float currentInput = syn.getInput();
for(int i = 0; i < neuron.synCount(); i++){
float currentW = neuron.getWeight(i);
float currentInput = neuron.getInput(i);
float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
syn.setWeight(newValue);
synIndex.incrementAndGet();
});
neuron.setWeight(i, newValue);
}
neuronIndex.incrementAndGet();
});
}

View File

@@ -17,7 +17,7 @@ public class GradientBackpropagationTraining implements Trainer {
@Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
GradientBackpropagationContext context =
new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()/3);
new GradientBackpropagationContext(model, dataset, learningRate, dataset.size());
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
@@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer {
.afterEpoch(ctx -> {
ctx.globalLoss /= dataset.size();
})
.withVerbose(true,epoch/10)
.withVerbose(false,epoch/10)
.withTimeMeasurement(true)
.run(context);
}

View File

@@ -41,7 +41,9 @@ public class GradientDescentTraining implements Trainer {
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
gdCtx.globalLoss = 0.0F;
gdCtx.correctorTerms.clear();
gdCtx.model.forEachSynapse(syn -> gdCtx.correctorTerms.add(0F));
for(int i = 0; i < gdCtx.model.synCount(); i++){
gdCtx.correctorTerms.add(0F);
}
})
.afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size();