Fix plotting
This commit is contained in:
@@ -1,14 +1,14 @@
|
|||||||
{
|
{
|
||||||
"model": {
|
"model": {
|
||||||
"new": true,
|
"new": true,
|
||||||
"parameters": [1],
|
"parameters": [10, 5, 1],
|
||||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||||
},
|
},
|
||||||
"training" : {
|
"training" : {
|
||||||
"learning_rate" : 1.0,
|
"learning_rate" : 0.002,
|
||||||
"max_epoch" : 5000
|
"max_epoch" : 5000
|
||||||
},
|
},
|
||||||
"dataset" : {
|
"dataset" : {
|
||||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"
|
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -21,7 +21,7 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
|
|
||||||
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"};
|
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient backpropagation"};
|
||||||
|
|
||||||
Scanner sc = new Scanner(System.in);
|
Scanner sc = new Scanner(System.in);
|
||||||
for (int i = 0; i < types.length; i++) {
|
for (int i = 0; i < types.length; i++) {
|
||||||
@@ -77,22 +77,22 @@ public class Main {
|
|||||||
if(dataset.getNbrInputs() != 2) return;
|
if(dataset.getNbrInputs() != 2) return;
|
||||||
|
|
||||||
GraphVisualizer visualizer = new GraphVisualizer();
|
GraphVisualizer visualizer = new GraphVisualizer();
|
||||||
float min = -10F;
|
float min = 0F;
|
||||||
float max = 10F;
|
float max = 15F;
|
||||||
float step = Math.abs(max - min)/300;
|
float step = Math.abs(max - min)/300;
|
||||||
|
|
||||||
boolean isOneHot = dataset.getNbrLabels() > 1;
|
boolean isOneHot = dataset.getNbrLabels() > 1;
|
||||||
|
|
||||||
//plot labels
|
//plot labels
|
||||||
for (DataSetEntry entry : dataset) {
|
for (DataSetEntry entry : dataset) {
|
||||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
List<Float> labels = dataset.getLabelsAsFloat(entry);
|
||||||
int expectedClass = 0;
|
int expectedClass = 0;
|
||||||
if (isOneHot) {
|
if (isOneHot) {
|
||||||
for (int i = 1; i < label.size(); i++) {
|
for (int i = 1; i < labels.size(); i++) {
|
||||||
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
if (labels.get(i) > labels.get(expectedClass)) expectedClass = i;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
expectedClass = Math.round(label.getFirst());
|
expectedClass = Math.round(labels.getFirst());
|
||||||
}
|
}
|
||||||
visualizer.addPoint("Label " + expectedClass,
|
visualizer.addPoint("Label " + expectedClass,
|
||||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||||
@@ -105,12 +105,7 @@ public class Main {
|
|||||||
List<Float> predictions = new ArrayList<>();
|
List<Float> predictions = new ArrayList<>();
|
||||||
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||||
|
|
||||||
int predictedClass;
|
int predictedClass = dataset.getPredictedClass(predictions);
|
||||||
if(isOneHot){
|
|
||||||
predictedClass = predictions.indexOf(Collections.max(predictions));
|
|
||||||
} else {
|
|
||||||
predictedClass = Math.round(predictions.getFirst());
|
|
||||||
}
|
|
||||||
|
|
||||||
visualizer.addPoint("Class " + predictedClass, x, y);
|
visualizer.addPoint("Class " + predictedClass, x, y);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ public class DataSet implements Iterable<DataSetEntry>{
|
|||||||
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
|
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public int size() {
|
public int size() {
|
||||||
return data.size();
|
return data.size();
|
||||||
}
|
}
|
||||||
@@ -57,6 +56,25 @@ public class DataSet implements Iterable<DataSetEntry>{
|
|||||||
return this.data.get(entry).getValues();
|
return this.data.get(entry).getValues();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int getPredictedClass(List<Float> values) {
|
||||||
|
if (nbrLabels == 1) {
|
||||||
|
// single value — round to nearest known class
|
||||||
|
float value = values.getFirst();
|
||||||
|
return data.values().stream()
|
||||||
|
.map(l -> l.getValues().getFirst())
|
||||||
|
.distinct()
|
||||||
|
.min(Comparator.comparingDouble(l -> Math.abs(l - value)))
|
||||||
|
.map(Math::round)
|
||||||
|
.orElseThrow();
|
||||||
|
} else {
|
||||||
|
// one-hot — return index of max value
|
||||||
|
int max = 0;
|
||||||
|
for (int i = 1; i < values.size(); i++)
|
||||||
|
if (values.get(i) > values.get(max)) max = i;
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public DataSet toNormalized() {
|
public DataSet toNormalized() {
|
||||||
List<DataSetEntry> entries = this.getData();
|
List<DataSetEntry> entries = this.getData();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user