Added dataset anf fixed gradient descent algorithm
This commit is contained in:
@@ -19,6 +19,13 @@ public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
DataSet orDataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||
));
|
||||
|
||||
DataSet dataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
|
||||
@@ -51,7 +58,12 @@ public class Main {
|
||||
|
||||
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
||||
GradientDescentTraining st = new GradientDescentTraining();
|
||||
st.train(n, 0.0003F, dataSet);
|
||||
}
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
|
||||
st.train(n, 0.03F, orDataSet.toNormalized());
|
||||
|
||||
long end = System.currentTimeMillis();
|
||||
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,26 @@ public class DataSet implements Iterable<DataSetEntry>{
|
||||
return this.data.get(entry);
|
||||
}
|
||||
|
||||
public DataSet toNormalized() {
|
||||
List<DataSetEntry> entries = this.getData();
|
||||
|
||||
float maxAbs = entries.stream()
|
||||
.flatMap(e -> e.getData().stream())
|
||||
.map(Math::abs)
|
||||
.max(Float::compare)
|
||||
.orElse(1.0F);
|
||||
|
||||
Map<DataSetEntry, Label> normalized = new HashMap<>();
|
||||
for (DataSetEntry entry : entries) {
|
||||
List<Float> normalizedData = new ArrayList<>();
|
||||
for (float value : entry.getData()) {
|
||||
normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F);
|
||||
}
|
||||
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
||||
}
|
||||
|
||||
return new DataSet(normalized);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<DataSetEntry> iterator() {
|
||||
|
||||
@@ -25,7 +25,7 @@ public class SimplePerceptron extends Neuron implements Trainable {
|
||||
for(Synapse syn : super.synapses){
|
||||
res += syn.getWeight() * syn.getInput();
|
||||
}
|
||||
|
||||
res += this.bias.getWeight() * this.bias.getInput();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ public class GradientDescentTraining {
|
||||
float loss = this.calculateLoss(delta);
|
||||
|
||||
currentError += loss;
|
||||
|
||||
Bias b = n.getBias();
|
||||
biasCorrector += this.calculateWeightCorrection(learningRate, b.getInput(), delta);
|
||||
|
||||
@@ -47,6 +48,8 @@ public class GradientDescentTraining {
|
||||
c += this.calculateWeightCorrection(learningRate, syn.getInput(), delta);
|
||||
correctorTerms.set(i, c);
|
||||
}
|
||||
|
||||
if(epoch % 10 != 0) continue;
|
||||
System.out.printf("Epoch : %d ", epoch);
|
||||
System.out.printf("predicted : %.2f, ", prediction);
|
||||
System.out.printf("expected : %.2f, ", expectation);
|
||||
@@ -54,13 +57,15 @@ public class GradientDescentTraining {
|
||||
System.out.printf("loss : %.2f\n", loss);
|
||||
|
||||
}
|
||||
System.out.printf("[Total error : %.2f]\n", currentError);
|
||||
|
||||
System.out.printf("[Total error : %.3f]\n", currentError);
|
||||
biasCorrector += n.getBias().getWeight();
|
||||
n.updateBias(new Weight(biasCorrector));
|
||||
|
||||
for(int i = 0; i < correctorTerms.size(); i++){
|
||||
Synapse syn = n.getSynapse(i);
|
||||
float c = correctorTerms.get(i);
|
||||
syn.setWeight(syn.getWeight() + c);
|
||||
float c = syn.getWeight() + correctorTerms.get(i);
|
||||
syn.setWeight(c);
|
||||
}
|
||||
|
||||
epoch++;
|
||||
|
||||
Reference in New Issue
Block a user