Just a regular commit
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package com.naaturel.ANN;
|
package com.naaturel.ANN;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
@@ -8,7 +9,6 @@ import com.naaturel.ANN.domain.model.neuron.Bias;
|
|||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
import com.naaturel.ANN.implementation.activationFunction.Heaviside;
|
|
||||||
import com.naaturel.ANN.implementation.activationFunction.Linear;
|
import com.naaturel.ANN.implementation.activationFunction.Linear;
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
||||||
@@ -29,8 +29,8 @@ public class Main {
|
|||||||
|
|
||||||
DataSet andDataSet = new DataSet(Map.ofEntries(
|
DataSet andDataSet = new DataSet(Map.ofEntries(
|
||||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
|
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
|
||||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
|
|
||||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
|
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
|
||||||
|
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
|
||||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||||
));
|
));
|
||||||
|
|
||||||
@@ -65,11 +65,11 @@ public class Main {
|
|||||||
Bias bias = new Bias(new Weight(0));
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
||||||
AdalineTraining st = new AdalineTraining();
|
Trainer trainer = new AdalineTraining();
|
||||||
|
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
|
|
||||||
st.train(n, 0.03F, andDataSet);
|
trainer.train(n, 0.03F, andDataSet);
|
||||||
|
|
||||||
long end = System.currentTimeMillis();
|
long end = System.currentTimeMillis();
|
||||||
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
|
||||||
|
public interface Trainer {
|
||||||
|
|
||||||
|
void train(Neuron n, float learningRate, DataSet dataSet);
|
||||||
|
}
|
||||||
@@ -27,4 +27,6 @@ public class Synapse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,15 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class AdalineTraining {
|
public class AdalineTraining implements Trainer {
|
||||||
|
|
||||||
public AdalineTraining(){
|
public AdalineTraining(){
|
||||||
|
|
||||||
@@ -19,15 +17,14 @@ public class AdalineTraining {
|
|||||||
|
|
||||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int maxEpoch = 1000;
|
int maxEpoch = 202;
|
||||||
float errorThreshold = 0.0F;
|
float errorThreshold = 0.0F;
|
||||||
float mse;
|
float mse;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
if(epoch > maxEpoch) break;
|
if(epoch > maxEpoch) break;
|
||||||
mse = 0;
|
mse = 0;
|
||||||
|
for(DataSetEntry entry : dataSet) {
|
||||||
for(DataSetEntry entry : dataSet) {
|
|
||||||
this.updateInputs(n, entry);
|
this.updateInputs(n, entry);
|
||||||
float prediction = n.predict();
|
float prediction = n.predict();
|
||||||
float expectation = dataSet.getLabel(entry).getValue();
|
float expectation = dataSet.getLabel(entry).getValue();
|
||||||
@@ -49,23 +46,22 @@ public class AdalineTraining {
|
|||||||
System.out.printf("predicted : %.2f, ", prediction);
|
System.out.printf("predicted : %.2f, ", prediction);
|
||||||
System.out.printf("expected : %.2f, ", expectation);
|
System.out.printf("expected : %.2f, ", expectation);
|
||||||
System.out.printf("delta : %.2f, ", delta);
|
System.out.printf("delta : %.2f, ", delta);
|
||||||
System.out.printf("loss : %.2f\n", loss);
|
System.out.printf("loss : %.5f\n", loss);
|
||||||
}
|
}
|
||||||
|
mse /= dataSet.size();
|
||||||
System.out.printf("[Total error : %f]\n", mse);
|
System.out.printf("[Total error : %f]\n", mse);
|
||||||
|
System.out.println("[Final weights]");
|
||||||
|
System.out.printf("Bias: %f\n", n.getBias().getWeight());
|
||||||
|
int i = 1;
|
||||||
|
for(Synapse syn : n.getSynapses()){
|
||||||
|
System.out.printf("Syn %d: %f\n", i, syn.getWeight());
|
||||||
|
i++;
|
||||||
|
}
|
||||||
epoch++;
|
epoch++;
|
||||||
} while(mse > errorThreshold);
|
} while(mse > errorThreshold);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Float> initCorrectorTerms(int number){
|
|
||||||
List<Float> res = new ArrayList<>();
|
|
||||||
for(int i = 0; i < number; i++){
|
|
||||||
res.add(0F);
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateInputs(Neuron n, DataSetEntry entry){
|
private void updateInputs(Neuron n, DataSetEntry entry){
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for(float value : entry){
|
for(float value : entry){
|
||||||
@@ -82,8 +78,4 @@ public class AdalineTraining {
|
|||||||
return (float) Math.pow(delta, 2)/2;
|
return (float) Math.pow(delta, 2)/2;
|
||||||
}
|
}
|
||||||
|
|
||||||
private float calculateWeightCorrection(float value, float delta){
|
|
||||||
return value * delta;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||||
@@ -9,9 +10,10 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
|
|||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class GradientDescentTraining {
|
public class GradientDescentTraining implements Trainer {
|
||||||
|
|
||||||
public GradientDescentTraining(){
|
public GradientDescentTraining(){
|
||||||
|
|
||||||
@@ -19,8 +21,8 @@ public class GradientDescentTraining {
|
|||||||
|
|
||||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int maxEpoch = 1000;
|
int maxEpoch = 402;
|
||||||
float errorThreshold = 0.0F;
|
float errorThreshold = 0F;
|
||||||
float mse;
|
float mse;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
@@ -54,6 +56,7 @@ public class GradientDescentTraining {
|
|||||||
System.out.printf("delta : %.2f, ", delta);
|
System.out.printf("delta : %.2f, ", delta);
|
||||||
System.out.printf("loss : %.2f\n", loss);
|
System.out.printf("loss : %.2f\n", loss);
|
||||||
}
|
}
|
||||||
|
mse /= dataSet.size();
|
||||||
System.out.printf("[Total error : %f]\n", mse);
|
System.out.printf("[Total error : %f]\n", mse);
|
||||||
|
|
||||||
float currentBias = n.getBias().getWeight();
|
float currentBias = n.getBias().getWeight();
|
||||||
@@ -69,6 +72,13 @@ public class GradientDescentTraining {
|
|||||||
epoch++;
|
epoch++;
|
||||||
} while(mse > errorThreshold);
|
} while(mse > errorThreshold);
|
||||||
|
|
||||||
|
System.out.println("[Final weights]");
|
||||||
|
System.out.printf("Bias: %f\n", n.getBias().getWeight());
|
||||||
|
int i = 1;
|
||||||
|
for(Synapse syn : n.getSynapses()){
|
||||||
|
System.out.printf("Syn %d: %f\n", i, syn.getWeight());
|
||||||
|
i++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Float> initCorrectorTerms(int number){
|
private List<Float> initCorrectorTerms(int number){
|
||||||
@@ -95,8 +105,21 @@ public class GradientDescentTraining {
|
|||||||
return (float) Math.pow(delta, 2)/2;
|
return (float) Math.pow(delta, 2)/2;
|
||||||
}
|
}
|
||||||
|
|
||||||
private float calculateWeightCorrection(float value, float delta){
|
public float computeThreshold(DataSet dataSet) {
|
||||||
return value * delta;
|
float sum = 0;
|
||||||
|
for (DataSetEntry entry : dataSet) {
|
||||||
|
sum += dataSet.getLabel(entry).getValue();
|
||||||
|
}
|
||||||
|
float mean = sum / dataSet.size();
|
||||||
|
|
||||||
|
float variance = 0;
|
||||||
|
for (DataSetEntry entry : dataSet) {
|
||||||
|
float diff = dataSet.getLabel(entry).getValue() - mean;
|
||||||
|
variance += diff * diff;
|
||||||
|
}
|
||||||
|
variance /= dataSet.size();
|
||||||
|
|
||||||
|
return variance;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
|
|
||||||
public class SimpleTraining {
|
public class SimpleTraining implements Trainer {
|
||||||
|
|
||||||
public SimpleTraining() {
|
public SimpleTraining() {
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user