Just a regular commit
This commit is contained in:
@@ -11,6 +11,7 @@ import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.implementation.activationFunction.Heaviside;
|
||||
import com.naaturel.ANN.implementation.activationFunction.Linear;
|
||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||
|
||||
import java.util.*;
|
||||
@@ -64,11 +65,11 @@ public class Main {
|
||||
Bias bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
||||
GradientDescentTraining st = new GradientDescentTraining();
|
||||
AdalineTraining st = new AdalineTraining();
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
|
||||
st.train(n, 0.2F, andDataSet);
|
||||
st.train(n, 0.03F, andDataSet);
|
||||
|
||||
long end = System.currentTimeMillis();
|
||||
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
||||
|
||||
@@ -1,4 +1,89 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class AdalineTraining {
|
||||
|
||||
public AdalineTraining(){
|
||||
|
||||
}
|
||||
|
||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int maxEpoch = 1000;
|
||||
float errorThreshold = 0.0F;
|
||||
float mse;
|
||||
|
||||
do {
|
||||
if(epoch > maxEpoch) break;
|
||||
mse = 0;
|
||||
|
||||
for(DataSetEntry entry : dataSet) {
|
||||
this.updateInputs(n, entry);
|
||||
float prediction = n.predict();
|
||||
float expectation = dataSet.getLabel(entry).getValue();
|
||||
float delta = this.calculateDelta(expectation, prediction);
|
||||
float loss = this.calculateLoss(delta);
|
||||
|
||||
mse += loss;
|
||||
|
||||
float currentBias = n.getBias().getWeight();
|
||||
float biasCorrector = currentBias + (learningRate * delta * n.getBias().getInput());
|
||||
n.updateBias(new Weight(biasCorrector));
|
||||
|
||||
for(Synapse syn : n.getSynapses()){
|
||||
float synCorrector = syn.getWeight() + (learningRate * delta * syn.getInput());
|
||||
syn.setWeight(synCorrector);
|
||||
}
|
||||
|
||||
System.out.printf("Epoch : %d ", epoch);
|
||||
System.out.printf("predicted : %.2f, ", prediction);
|
||||
System.out.printf("expected : %.2f, ", expectation);
|
||||
System.out.printf("delta : %.2f, ", delta);
|
||||
System.out.printf("loss : %.2f\n", loss);
|
||||
}
|
||||
System.out.printf("[Total error : %f]\n", mse);
|
||||
|
||||
epoch++;
|
||||
} while(mse > errorThreshold);
|
||||
|
||||
}
|
||||
|
||||
private List<Float> initCorrectorTerms(int number){
|
||||
List<Float> res = new ArrayList<>();
|
||||
for(int i = 0; i < number; i++){
|
||||
res.add(0F);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private void updateInputs(Neuron n, DataSetEntry entry){
|
||||
int index = 0;
|
||||
for(float value : entry){
|
||||
n.setInput(index, new Input(value));
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
private float calculateDelta(float expected, float predicted){
|
||||
return expected - predicted;
|
||||
}
|
||||
|
||||
private float calculateLoss(float delta){
|
||||
return (float) Math.pow(delta, 2)/2;
|
||||
}
|
||||
|
||||
private float calculateWeightCorrection(float value, float delta){
|
||||
return value * delta;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -19,15 +19,15 @@ public class GradientDescentTraining {
|
||||
|
||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int maxEpoch = 200;
|
||||
float errorThreshold = 0.125F;
|
||||
float currentError;
|
||||
int maxEpoch = 1000;
|
||||
float errorThreshold = 0.0F;
|
||||
float mse;
|
||||
|
||||
do {
|
||||
if(epoch > maxEpoch) break;
|
||||
|
||||
float biasCorrector = 0;
|
||||
currentError = 0;
|
||||
mse = 0;
|
||||
List<Float> correctorTerms = this.initCorrectorTerms(n.getSynCount());
|
||||
|
||||
for(DataSetEntry entry : dataSet) {
|
||||
@@ -37,7 +37,7 @@ public class GradientDescentTraining {
|
||||
float delta = this.calculateDelta(expectation, prediction);
|
||||
float loss = this.calculateLoss(delta);
|
||||
|
||||
currentError += loss/dataSet.size();
|
||||
mse += loss;
|
||||
|
||||
biasCorrector += learningRate * delta * n.getBias().getInput();
|
||||
|
||||
@@ -54,7 +54,7 @@ public class GradientDescentTraining {
|
||||
System.out.printf("delta : %.2f, ", delta);
|
||||
System.out.printf("loss : %.2f\n", loss);
|
||||
}
|
||||
System.out.printf("[Total error : %.3f]\n", currentError);
|
||||
System.out.printf("[Total error : %f]\n", mse);
|
||||
|
||||
float currentBias = n.getBias().getWeight();
|
||||
float newBias = currentBias + biasCorrector;
|
||||
@@ -67,7 +67,7 @@ public class GradientDescentTraining {
|
||||
}
|
||||
|
||||
epoch++;
|
||||
} while(currentError > errorThreshold);
|
||||
} while(mse > errorThreshold);
|
||||
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ public class GradientDescentTraining {
|
||||
}
|
||||
|
||||
private float calculateLoss(float delta){
|
||||
return ((float) Math.pow(delta, 2))/2;
|
||||
return (float) Math.pow(delta, 2)/2;
|
||||
}
|
||||
|
||||
private float calculateWeightCorrection(float value, float delta){
|
||||
|
||||
Reference in New Issue
Block a user