Add plotting support for both one-hot vector and single integer classes
This commit is contained in:
@@ -1,14 +1,14 @@
|
|||||||
{
|
{
|
||||||
"model": {
|
"model": {
|
||||||
"new": true,
|
"new": true,
|
||||||
"parameters": [5, 2, 3],
|
"parameters": [5, 5, 1],
|
||||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json"
|
||||||
},
|
},
|
||||||
"training" : {
|
"training" : {
|
||||||
"learning_rate" : 0.0003,
|
"learning_rate" : 0.03,
|
||||||
"max_epoch" : 5000
|
"max_epoch" : 5000
|
||||||
},
|
},
|
||||||
"dataset" : {
|
"dataset" : {
|
||||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv"
|
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -37,7 +37,6 @@ public class Main {
|
|||||||
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
|
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
|
||||||
System.out.println();
|
System.out.println();
|
||||||
|
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
@@ -101,13 +100,19 @@ public class Main {
|
|||||||
float max = 10F;
|
float max = 10F;
|
||||||
float step = Math.abs(max - min)/300;
|
float step = Math.abs(max - min)/300;
|
||||||
|
|
||||||
|
boolean isOneHot = dataset.getNbrLabels() > 1;
|
||||||
|
|
||||||
//plot labels
|
//plot labels
|
||||||
for (DataSetEntry entry : dataset) {
|
for (DataSetEntry entry : dataset) {
|
||||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||||
int expectedClass = 0;
|
int expectedClass = 0;
|
||||||
|
if (isOneHot) {
|
||||||
for (int i = 1; i < label.size(); i++) {
|
for (int i = 1; i < label.size(); i++) {
|
||||||
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
expectedClass = Math.round(label.getFirst());
|
||||||
|
}
|
||||||
visualizer.addPoint("Label " + expectedClass,
|
visualizer.addPoint("Label " + expectedClass,
|
||||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||||
}
|
}
|
||||||
@@ -116,12 +121,16 @@ public class Main {
|
|||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
List<Float> predictions = new ArrayList<>();
|
List<Float> predictions = new ArrayList<>();
|
||||||
|
|
||||||
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||||
float highest = Collections.max(predictions);
|
|
||||||
int predictedClass = predictions.indexOf(highest);
|
|
||||||
|
|
||||||
visualizer.addPoint("Predict " + predictedClass, x, y);
|
int predictedClass;
|
||||||
|
if(isOneHot){
|
||||||
|
predictedClass = predictions.indexOf(Collections.max(predictions));
|
||||||
|
} else {
|
||||||
|
predictedClass = Math.round(predictions.getFirst());
|
||||||
|
}
|
||||||
|
|
||||||
|
visualizer.addPoint("Class " + predictedClass, x, y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user