Add plotting support for both one-hot vector and single integer classes
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
{
|
||||
"model": {
|
||||
"new": true,
|
||||
"parameters": [5, 2, 3],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||
"parameters": [5, 5, 1],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json"
|
||||
},
|
||||
"training" : {
|
||||
"learning_rate" : 0.0003,
|
||||
"learning_rate" : 0.03,
|
||||
"max_epoch" : 5000
|
||||
},
|
||||
"dataset" : {
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv"
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv"
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,6 @@ public class Main {
|
||||
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
|
||||
System.out.println();
|
||||
|
||||
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
@@ -101,13 +100,19 @@ public class Main {
|
||||
float max = 10F;
|
||||
float step = Math.abs(max - min)/300;
|
||||
|
||||
boolean isOneHot = dataset.getNbrLabels() > 1;
|
||||
|
||||
//plot labels
|
||||
for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
int expectedClass = 0;
|
||||
if (isOneHot) {
|
||||
for (int i = 1; i < label.size(); i++) {
|
||||
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
||||
}
|
||||
} else {
|
||||
expectedClass = Math.round(label.getFirst());
|
||||
}
|
||||
visualizer.addPoint("Label " + expectedClass,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
}
|
||||
@@ -116,12 +121,16 @@ public class Main {
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
List<Float> predictions = new ArrayList<>();
|
||||
|
||||
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||
float highest = Collections.max(predictions);
|
||||
int predictedClass = predictions.indexOf(highest);
|
||||
|
||||
visualizer.addPoint("Predict " + predictedClass, x, y);
|
||||
int predictedClass;
|
||||
if(isOneHot){
|
||||
predictedClass = predictions.indexOf(Collections.max(predictions));
|
||||
} else {
|
||||
predictedClass = Math.round(predictions.getFirst());
|
||||
}
|
||||
|
||||
visualizer.addPoint("Class " + predictedClass, x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user