diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 02e5de2..ee77513 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -26,6 +26,13 @@ public class Main { Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F)) )); + DataSet andDataSet = new DataSet(Map.ofEntries( + Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)), + Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)), + Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)), + Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F)) + )); + DataSet dataSet = new DataSet(Map.ofEntries( Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)), Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)), @@ -51,17 +58,17 @@ public class Main { )); List syns = new ArrayList<>(); - syns.add(new Synapse(new Input(0), new Weight())); - syns.add(new Synapse(new Input(0), new Weight())); + syns.add(new Synapse(new Input(0), new Weight(0))); + syns.add(new Synapse(new Input(0), new Weight(0))); - Bias bias = new Bias(new Weight()); + Bias bias = new Bias(new Weight(0)); Neuron n = new SimplePerceptron(syns, bias, new Linear()); GradientDescentTraining st = new GradientDescentTraining(); long start = System.currentTimeMillis(); - st.train(n, 0.03F, orDataSet.toNormalized()); + st.train(n, 0.2F, andDataSet); long end = System.currentTimeMillis(); System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0); diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java index 88073bb..d0df51e 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java @@ -14,6 +14,10 @@ public class DataSet implements Iterable{ this.data = data; } + public int size() { + return data.size(); + } + public List getData(){ return new ArrayList<>(this.data.keySet()); } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 843d731..42022e3 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -19,7 +19,7 @@ public class GradientDescentTraining { public void train(Neuron n, float learningRate, DataSet dataSet) { int epoch = 1; - int maxEpoch = 10000; + int maxEpoch = 200; float errorThreshold = 0.125F; float currentError; @@ -37,30 +37,28 @@ public class GradientDescentTraining { float delta = this.calculateDelta(expectation, prediction); float loss = this.calculateLoss(delta); - currentError += loss; + currentError += loss/dataSet.size(); - Bias b = n.getBias(); - biasCorrector += this.calculateWeightCorrection(learningRate, b.getInput(), delta); + biasCorrector += learningRate * delta * n.getBias().getInput(); for(int i = 0; i < correctorTerms.size(); i++){ Synapse syn = n.getSynapse(i); float c = correctorTerms.get(i); - c += this.calculateWeightCorrection(learningRate, syn.getInput(), delta); + c += learningRate * delta * syn.getInput(); correctorTerms.set(i, c); } - if(epoch % 10 != 0) continue; System.out.printf("Epoch : %d ", epoch); System.out.printf("predicted : %.2f, ", prediction); System.out.printf("expected : %.2f, ", expectation); System.out.printf("delta : %.2f, ", delta); System.out.printf("loss : %.2f\n", loss); - } - System.out.printf("[Total error : %.3f]\n", currentError); - biasCorrector += n.getBias().getWeight(); - n.updateBias(new Weight(biasCorrector)); + + float currentBias = n.getBias().getWeight(); + float newBias = currentBias + biasCorrector; + n.updateBias(new Weight(newBias)); for(int i = 0; i < correctorTerms.size(); i++){ Synapse syn = n.getSynapse(i); @@ -97,8 +95,8 @@ public class GradientDescentTraining { return ((float) Math.pow(delta, 2))/2; } - private float calculateWeightCorrection(float lr, float value, float delta){ - return lr * value * delta; + private float calculateWeightCorrection(float value, float delta){ + return value * delta; } }