Just a regular commit

This commit is contained in:
2026-03-20 16:58:51 +01:00
parent a2a74566ba
commit 6742b18473
3 changed files with 96 additions and 10 deletions

View File

@@ -11,6 +11,7 @@ import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.implementation.activationFunction.Heaviside; import com.naaturel.ANN.implementation.activationFunction.Heaviside;
import com.naaturel.ANN.implementation.activationFunction.Linear; import com.naaturel.ANN.implementation.activationFunction.Linear;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.AdalineTraining;
import com.naaturel.ANN.implementation.training.GradientDescentTraining; import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import java.util.*; import java.util.*;
@@ -64,11 +65,11 @@ public class Main {
Bias bias = new Bias(new Weight(0)); Bias bias = new Bias(new Weight(0));
Neuron n = new SimplePerceptron(syns, bias, new Linear()); Neuron n = new SimplePerceptron(syns, bias, new Linear());
GradientDescentTraining st = new GradientDescentTraining(); AdalineTraining st = new AdalineTraining();
long start = System.currentTimeMillis(); long start = System.currentTimeMillis();
st.train(n, 0.2F, andDataSet); st.train(n, 0.03F, andDataSet);
long end = System.currentTimeMillis(); long end = System.currentTimeMillis();
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0); System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);

View File

@@ -1,4 +1,89 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.ArrayList;
import java.util.List;
/**
 * Trains a single neuron with the ADALINE (Widrow–Hoff / delta) rule:
 * the bias and every synapse weight are updated immediately after each
 * sample (online/stochastic update), unlike the batch accumulation used
 * by {@code GradientDescentTraining}.
 */
public class AdalineTraining {

    /** Stateless trainer; a single instance can be reused across neurons. */
    public AdalineTraining(){
    }

    /**
     * Runs training epochs until the summed epoch loss reaches the threshold
     * or the epoch budget is exhausted. The neuron is mutated in place.
     *
     * @param n            neuron whose bias and synapse weights are adjusted
     * @param learningRate step size applied to each delta-rule correction
     * @param dataSet      labeled samples, iterated once per epoch
     */
    public void train(Neuron n, float learningRate, DataSet dataSet) {
        int epoch = 1;
        int maxEpoch = 1000;
        float errorThreshold = 0.0F;
        // Sum of per-sample losses for the epoch. NOTE: this is a total,
        // not a mean — it is never divided by dataSet.size().
        float totalError;
        do {
            if(epoch > maxEpoch) break;
            totalError = 0;
            for(DataSetEntry entry : dataSet) {
                this.updateInputs(n, entry);
                float prediction = n.predict();
                float expectation = dataSet.getLabel(entry).getValue();
                float delta = this.calculateDelta(expectation, prediction);
                float loss = this.calculateLoss(delta);
                totalError += loss;
                // Per-sample (online) update: bias first, then each synapse.
                float currentBias = n.getBias().getWeight();
                float newBias = currentBias + (learningRate * delta * n.getBias().getInput());
                n.updateBias(new Weight(newBias));
                for(Synapse syn : n.getSynapses()){
                    float newWeight = syn.getWeight() + (learningRate * delta * syn.getInput());
                    syn.setWeight(newWeight);
                }
                System.out.printf("Epoch : %d ", epoch);
                System.out.printf("predicted : %.2f, ", prediction);
                System.out.printf("expected : %.2f, ", expectation);
                System.out.printf("delta : %.2f, ", delta);
                System.out.printf("loss : %.2f\n", loss);
            }
            System.out.printf("[Total error : %f]\n", totalError);
            epoch++;
        } while(totalError > errorThreshold);
    }

    // Copies the entry's feature values into the neuron's inputs, in order.
    private void updateInputs(Neuron n, DataSetEntry entry){
        int index = 0;
        for(float value : entry){
            n.setInput(index, new Input(value));
            index++;
        }
    }

    // Error signal of the delta rule: expected minus predicted output.
    private float calculateDelta(float expected, float predicted){
        return expected - predicted;
    }

    // Halved squared-error loss, so its gradient w.r.t. the output is -delta.
    private float calculateLoss(float delta){
        return (float) Math.pow(delta, 2)/2;
    }
}

View File

@@ -19,15 +19,15 @@ public class GradientDescentTraining {
public void train(Neuron n, float learningRate, DataSet dataSet) { public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1; int epoch = 1;
int maxEpoch = 200; int maxEpoch = 1000;
float errorThreshold = 0.125F; float errorThreshold = 0.0F;
float currentError; float mse;
do { do {
if(epoch > maxEpoch) break; if(epoch > maxEpoch) break;
float biasCorrector = 0; float biasCorrector = 0;
currentError = 0; mse = 0;
List<Float> correctorTerms = this.initCorrectorTerms(n.getSynCount()); List<Float> correctorTerms = this.initCorrectorTerms(n.getSynCount());
for(DataSetEntry entry : dataSet) { for(DataSetEntry entry : dataSet) {
@@ -37,7 +37,7 @@ public class GradientDescentTraining {
float delta = this.calculateDelta(expectation, prediction); float delta = this.calculateDelta(expectation, prediction);
float loss = this.calculateLoss(delta); float loss = this.calculateLoss(delta);
currentError += loss/dataSet.size(); mse += loss;
biasCorrector += learningRate * delta * n.getBias().getInput(); biasCorrector += learningRate * delta * n.getBias().getInput();
@@ -54,7 +54,7 @@ public class GradientDescentTraining {
System.out.printf("delta : %.2f, ", delta); System.out.printf("delta : %.2f, ", delta);
System.out.printf("loss : %.2f\n", loss); System.out.printf("loss : %.2f\n", loss);
} }
System.out.printf("[Total error : %.3f]\n", currentError); System.out.printf("[Total error : %f]\n", mse);
float currentBias = n.getBias().getWeight(); float currentBias = n.getBias().getWeight();
float newBias = currentBias + biasCorrector; float newBias = currentBias + biasCorrector;
@@ -67,7 +67,7 @@ public class GradientDescentTraining {
} }
epoch++; epoch++;
} while(currentError > errorThreshold); } while(mse > errorThreshold);
} }
@@ -92,7 +92,7 @@ public class GradientDescentTraining {
} }
private float calculateLoss(float delta){ private float calculateLoss(float delta){
return ((float) Math.pow(delta, 2))/2; return (float) Math.pow(delta, 2)/2;
} }
private float calculateWeightCorrection(float value, float delta){ private float calculateWeightCorrection(float value, float delta){