Just a regular commit

This commit is contained in:
2026-03-22 23:36:44 +01:00
parent 56f88bded3
commit 76bc791889
6 changed files with 57 additions and 31 deletions

View File

@@ -1,6 +1,7 @@
package com.naaturel.ANN; package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label; import com.naaturel.ANN.domain.model.dataset.Label;
@@ -8,7 +9,6 @@ import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight; import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.implementation.activationFunction.Heaviside;
import com.naaturel.ANN.implementation.activationFunction.Linear; import com.naaturel.ANN.implementation.activationFunction.Linear;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.AdalineTraining; import com.naaturel.ANN.implementation.training.AdalineTraining;
@@ -29,8 +29,8 @@ public class Main {
DataSet andDataSet = new DataSet(Map.ofEntries( DataSet andDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)), Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)), Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F)) Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
)); ));
@@ -65,11 +65,11 @@ public class Main {
Bias bias = new Bias(new Weight(0)); Bias bias = new Bias(new Weight(0));
Neuron n = new SimplePerceptron(syns, bias, new Linear()); Neuron n = new SimplePerceptron(syns, bias, new Linear());
AdalineTraining st = new AdalineTraining(); Trainer trainer = new AdalineTraining();
long start = System.currentTimeMillis(); long start = System.currentTimeMillis();
st.train(n, 0.03F, andDataSet); trainer.train(n, 0.03F, andDataSet);
long end = System.currentTimeMillis(); long end = System.currentTimeMillis();
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0); System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);

View File

@@ -0,0 +1,8 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.dataset.DataSet;
public interface Trainer {
void train(Neuron n, float learningRate, DataSet dataSet);
}

View File

@@ -27,4 +27,6 @@ public class Synapse {
} }
} }

View File

@@ -1,17 +1,15 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight; import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.ArrayList;
import java.util.List;
public class AdalineTraining { public class AdalineTraining implements Trainer {
public AdalineTraining(){ public AdalineTraining(){
@@ -19,15 +17,14 @@ public class AdalineTraining {
public void train(Neuron n, float learningRate, DataSet dataSet) { public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1; int epoch = 1;
int maxEpoch = 1000; int maxEpoch = 202;
float errorThreshold = 0.0F; float errorThreshold = 0.0F;
float mse; float mse;
do { do {
if(epoch > maxEpoch) break; if(epoch > maxEpoch) break;
mse = 0; mse = 0;
for(DataSetEntry entry : dataSet) {
for(DataSetEntry entry : dataSet) {
this.updateInputs(n, entry); this.updateInputs(n, entry);
float prediction = n.predict(); float prediction = n.predict();
float expectation = dataSet.getLabel(entry).getValue(); float expectation = dataSet.getLabel(entry).getValue();
@@ -49,23 +46,22 @@ public class AdalineTraining {
System.out.printf("predicted : %.2f, ", prediction); System.out.printf("predicted : %.2f, ", prediction);
System.out.printf("expected : %.2f, ", expectation); System.out.printf("expected : %.2f, ", expectation);
System.out.printf("delta : %.2f, ", delta); System.out.printf("delta : %.2f, ", delta);
System.out.printf("loss : %.2f\n", loss); System.out.printf("loss : %.5f\n", loss);
} }
mse /= dataSet.size();
System.out.printf("[Total error : %f]\n", mse); System.out.printf("[Total error : %f]\n", mse);
System.out.println("[Final weights]");
System.out.printf("Bias: %f\n", n.getBias().getWeight());
int i = 1;
for(Synapse syn : n.getSynapses()){
System.out.printf("Syn %d: %f\n", i, syn.getWeight());
i++;
}
epoch++; epoch++;
} while(mse > errorThreshold); } while(mse > errorThreshold);
} }
private List<Float> initCorrectorTerms(int number){
List<Float> res = new ArrayList<>();
for(int i = 0; i < number; i++){
res.add(0F);
}
return res;
}
private void updateInputs(Neuron n, DataSetEntry entry){ private void updateInputs(Neuron n, DataSetEntry entry){
int index = 0; int index = 0;
for(float value : entry){ for(float value : entry){
@@ -82,8 +78,4 @@ public class AdalineTraining {
return (float) Math.pow(delta, 2)/2; return (float) Math.pow(delta, 2)/2;
} }
private float calculateWeightCorrection(float value, float delta){
return value * delta;
}
} }

View File

@@ -1,6 +1,7 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Bias; import com.naaturel.ANN.domain.model.neuron.Bias;
@@ -9,9 +10,10 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight; import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
public class GradientDescentTraining { public class GradientDescentTraining implements Trainer {
public GradientDescentTraining(){ public GradientDescentTraining(){
@@ -19,8 +21,8 @@ public class GradientDescentTraining {
public void train(Neuron n, float learningRate, DataSet dataSet) { public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1; int epoch = 1;
int maxEpoch = 1000; int maxEpoch = 402;
float errorThreshold = 0.0F; float errorThreshold = 0F;
float mse; float mse;
do { do {
@@ -54,6 +56,7 @@ public class GradientDescentTraining {
System.out.printf("delta : %.2f, ", delta); System.out.printf("delta : %.2f, ", delta);
System.out.printf("loss : %.2f\n", loss); System.out.printf("loss : %.2f\n", loss);
} }
mse /= dataSet.size();
System.out.printf("[Total error : %f]\n", mse); System.out.printf("[Total error : %f]\n", mse);
float currentBias = n.getBias().getWeight(); float currentBias = n.getBias().getWeight();
@@ -69,6 +72,13 @@ public class GradientDescentTraining {
epoch++; epoch++;
} while(mse > errorThreshold); } while(mse > errorThreshold);
System.out.println("[Final weights]");
System.out.printf("Bias: %f\n", n.getBias().getWeight());
int i = 1;
for(Synapse syn : n.getSynapses()){
System.out.printf("Syn %d: %f\n", i, syn.getWeight());
i++;
}
} }
private List<Float> initCorrectorTerms(int number){ private List<Float> initCorrectorTerms(int number){
@@ -95,8 +105,21 @@ public class GradientDescentTraining {
return (float) Math.pow(delta, 2)/2; return (float) Math.pow(delta, 2)/2;
} }
private float calculateWeightCorrection(float value, float delta){ public float computeThreshold(DataSet dataSet) {
return value * delta; float sum = 0;
for (DataSetEntry entry : dataSet) {
sum += dataSet.getLabel(entry).getValue();
}
float mean = sum / dataSet.size();
float variance = 0;
for (DataSetEntry entry : dataSet) {
float diff = dataSet.getLabel(entry).getValue() - mean;
variance += diff * diff;
}
variance /= dataSet.size();
return variance;
} }
} }

View File

@@ -1,13 +1,14 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight; import com.naaturel.ANN.domain.model.neuron.Weight;
public class SimpleTraining { public class SimpleTraining implements Trainer {
public SimpleTraining() { public SimpleTraining() {