Fully fixed gradient descent
This commit is contained in:
@@ -26,6 +26,13 @@ public class Main {
|
|||||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||||
));
|
));
|
||||||
|
|
||||||
|
DataSet andDataSet = new DataSet(Map.ofEntries(
|
||||||
|
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
|
||||||
|
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
|
||||||
|
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
|
||||||
|
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||||
|
));
|
||||||
|
|
||||||
DataSet dataSet = new DataSet(Map.ofEntries(
|
DataSet dataSet = new DataSet(Map.ofEntries(
|
||||||
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
|
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
|
||||||
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
|
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
|
||||||
@@ -51,17 +58,17 @@ public class Main {
|
|||||||
));
|
));
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
List<Synapse> syns = new ArrayList<>();
|
||||||
syns.add(new Synapse(new Input(0), new Weight()));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
syns.add(new Synapse(new Input(0), new Weight()));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
|
||||||
Bias bias = new Bias(new Weight());
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
||||||
GradientDescentTraining st = new GradientDescentTraining();
|
GradientDescentTraining st = new GradientDescentTraining();
|
||||||
|
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
|
|
||||||
st.train(n, 0.03F, orDataSet.toNormalized());
|
st.train(n, 0.2F, andDataSet);
|
||||||
|
|
||||||
long end = System.currentTimeMillis();
|
long end = System.currentTimeMillis();
|
||||||
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ public class DataSet implements Iterable<DataSetEntry>{
|
|||||||
this.data = data;
|
this.data = data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int size() {
|
||||||
|
return data.size();
|
||||||
|
}
|
||||||
|
|
||||||
public List<DataSetEntry> getData(){
|
public List<DataSetEntry> getData(){
|
||||||
return new ArrayList<>(this.data.keySet());
|
return new ArrayList<>(this.data.keySet());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ public class GradientDescentTraining {
|
|||||||
|
|
||||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int maxEpoch = 10000;
|
int maxEpoch = 200;
|
||||||
float errorThreshold = 0.125F;
|
float errorThreshold = 0.125F;
|
||||||
float currentError;
|
float currentError;
|
||||||
|
|
||||||
@@ -37,30 +37,28 @@ public class GradientDescentTraining {
|
|||||||
float delta = this.calculateDelta(expectation, prediction);
|
float delta = this.calculateDelta(expectation, prediction);
|
||||||
float loss = this.calculateLoss(delta);
|
float loss = this.calculateLoss(delta);
|
||||||
|
|
||||||
currentError += loss;
|
currentError += loss/dataSet.size();
|
||||||
|
|
||||||
Bias b = n.getBias();
|
biasCorrector += learningRate * delta * n.getBias().getInput();
|
||||||
biasCorrector += this.calculateWeightCorrection(learningRate, b.getInput(), delta);
|
|
||||||
|
|
||||||
for(int i = 0; i < correctorTerms.size(); i++){
|
for(int i = 0; i < correctorTerms.size(); i++){
|
||||||
Synapse syn = n.getSynapse(i);
|
Synapse syn = n.getSynapse(i);
|
||||||
float c = correctorTerms.get(i);
|
float c = correctorTerms.get(i);
|
||||||
c += this.calculateWeightCorrection(learningRate, syn.getInput(), delta);
|
c += learningRate * delta * syn.getInput();
|
||||||
correctorTerms.set(i, c);
|
correctorTerms.set(i, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(epoch % 10 != 0) continue;
|
|
||||||
System.out.printf("Epoch : %d ", epoch);
|
System.out.printf("Epoch : %d ", epoch);
|
||||||
System.out.printf("predicted : %.2f, ", prediction);
|
System.out.printf("predicted : %.2f, ", prediction);
|
||||||
System.out.printf("expected : %.2f, ", expectation);
|
System.out.printf("expected : %.2f, ", expectation);
|
||||||
System.out.printf("delta : %.2f, ", delta);
|
System.out.printf("delta : %.2f, ", delta);
|
||||||
System.out.printf("loss : %.2f\n", loss);
|
System.out.printf("loss : %.2f\n", loss);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.printf("[Total error : %.3f]\n", currentError);
|
System.out.printf("[Total error : %.3f]\n", currentError);
|
||||||
biasCorrector += n.getBias().getWeight();
|
|
||||||
n.updateBias(new Weight(biasCorrector));
|
float currentBias = n.getBias().getWeight();
|
||||||
|
float newBias = currentBias + biasCorrector;
|
||||||
|
n.updateBias(new Weight(newBias));
|
||||||
|
|
||||||
for(int i = 0; i < correctorTerms.size(); i++){
|
for(int i = 0; i < correctorTerms.size(); i++){
|
||||||
Synapse syn = n.getSynapse(i);
|
Synapse syn = n.getSynapse(i);
|
||||||
@@ -97,8 +95,8 @@ public class GradientDescentTraining {
|
|||||||
return ((float) Math.pow(delta, 2))/2;
|
return ((float) Math.pow(delta, 2))/2;
|
||||||
}
|
}
|
||||||
|
|
||||||
private float calculateWeightCorrection(float lr, float value, float delta){
|
private float calculateWeightCorrection(float value, float delta){
|
||||||
return lr * value * delta;
|
return value * delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user