Fully fixed gradient descent

This commit is contained in:
2026-03-18 13:15:01 +01:00
parent 7dc55a370f
commit a2a74566ba
3 changed files with 25 additions and 16 deletions

View File

@@ -26,6 +26,13 @@ public class Main {
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F)) Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
)); ));
DataSet andDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet dataSet = new DataSet(Map.ofEntries( DataSet dataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)), Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)), Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
@@ -51,17 +58,17 @@ public class Main {
)); ));
List<Synapse> syns = new ArrayList<>(); List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight())); syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight())); syns.add(new Synapse(new Input(0), new Weight(0)));
Bias bias = new Bias(new Weight()); Bias bias = new Bias(new Weight(0));
Neuron n = new SimplePerceptron(syns, bias, new Linear()); Neuron n = new SimplePerceptron(syns, bias, new Linear());
GradientDescentTraining st = new GradientDescentTraining(); GradientDescentTraining st = new GradientDescentTraining();
long start = System.currentTimeMillis(); long start = System.currentTimeMillis();
st.train(n, 0.03F, orDataSet.toNormalized()); st.train(n, 0.2F, andDataSet);
long end = System.currentTimeMillis(); long end = System.currentTimeMillis();
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0); System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);

View File

@@ -14,6 +14,10 @@ public class DataSet implements Iterable<DataSetEntry>{
this.data = data; this.data = data;
} }
/**
 * Reports how many entries this data set currently holds.
 *
 * @return the element count of the backing {@code data} collection
 */
public int size() {
    final int entryCount = this.data.size();
    return entryCount;
}
public List<DataSetEntry> getData(){ public List<DataSetEntry> getData(){
return new ArrayList<>(this.data.keySet()); return new ArrayList<>(this.data.keySet());
} }

View File

@@ -19,7 +19,7 @@ public class GradientDescentTraining {
public void train(Neuron n, float learningRate, DataSet dataSet) { public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1; int epoch = 1;
int maxEpoch = 10000; int maxEpoch = 200;
float errorThreshold = 0.125F; float errorThreshold = 0.125F;
float currentError; float currentError;
@@ -37,30 +37,28 @@ public class GradientDescentTraining {
float delta = this.calculateDelta(expectation, prediction); float delta = this.calculateDelta(expectation, prediction);
float loss = this.calculateLoss(delta); float loss = this.calculateLoss(delta);
currentError += loss; currentError += loss/dataSet.size();
Bias b = n.getBias(); biasCorrector += learningRate * delta * n.getBias().getInput();
biasCorrector += this.calculateWeightCorrection(learningRate, b.getInput(), delta);
for(int i = 0; i < correctorTerms.size(); i++){ for(int i = 0; i < correctorTerms.size(); i++){
Synapse syn = n.getSynapse(i); Synapse syn = n.getSynapse(i);
float c = correctorTerms.get(i); float c = correctorTerms.get(i);
c += this.calculateWeightCorrection(learningRate, syn.getInput(), delta); c += learningRate * delta * syn.getInput();
correctorTerms.set(i, c); correctorTerms.set(i, c);
} }
if(epoch % 10 != 0) continue;
System.out.printf("Epoch : %d ", epoch); System.out.printf("Epoch : %d ", epoch);
System.out.printf("predicted : %.2f, ", prediction); System.out.printf("predicted : %.2f, ", prediction);
System.out.printf("expected : %.2f, ", expectation); System.out.printf("expected : %.2f, ", expectation);
System.out.printf("delta : %.2f, ", delta); System.out.printf("delta : %.2f, ", delta);
System.out.printf("loss : %.2f\n", loss); System.out.printf("loss : %.2f\n", loss);
} }
System.out.printf("[Total error : %.3f]\n", currentError); System.out.printf("[Total error : %.3f]\n", currentError);
biasCorrector += n.getBias().getWeight();
n.updateBias(new Weight(biasCorrector)); float currentBias = n.getBias().getWeight();
float newBias = currentBias + biasCorrector;
n.updateBias(new Weight(newBias));
for(int i = 0; i < correctorTerms.size(); i++){ for(int i = 0; i < correctorTerms.size(); i++){
Synapse syn = n.getSynapse(i); Synapse syn = n.getSynapse(i);
@@ -97,8 +95,8 @@ public class GradientDescentTraining {
return ((float) Math.pow(delta, 2))/2; return ((float) Math.pow(delta, 2))/2;
} }
private float calculateWeightCorrection(float lr, float value, float delta){ private float calculateWeightCorrection(float value, float delta){
return lr * value * delta; return value * delta;
} }
} }