Fix plotting
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
{
|
||||
"model": {
|
||||
"new": true,
|
||||
"parameters": [1],
|
||||
"parameters": [10, 5, 1],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||
},
|
||||
"training" : {
|
||||
"learning_rate" : 1.0,
|
||||
"learning_rate" : 0.002,
|
||||
"max_epoch" : 5000
|
||||
},
|
||||
"dataset" : {
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv"
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ public class Main {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
|
||||
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"};
|
||||
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient backpropagation"};
|
||||
|
||||
Scanner sc = new Scanner(System.in);
|
||||
for (int i = 0; i < types.length; i++) {
|
||||
@@ -77,22 +77,22 @@ public class Main {
|
||||
if(dataset.getNbrInputs() != 2) return;
|
||||
|
||||
GraphVisualizer visualizer = new GraphVisualizer();
|
||||
float min = -10F;
|
||||
float max = 10F;
|
||||
float min = 0F;
|
||||
float max = 15F;
|
||||
float step = Math.abs(max - min)/300;
|
||||
|
||||
boolean isOneHot = dataset.getNbrLabels() > 1;
|
||||
|
||||
//plot labels
|
||||
for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
List<Float> labels = dataset.getLabelsAsFloat(entry);
|
||||
int expectedClass = 0;
|
||||
if (isOneHot) {
|
||||
for (int i = 1; i < label.size(); i++) {
|
||||
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
||||
for (int i = 1; i < labels.size(); i++) {
|
||||
if (labels.get(i) > labels.get(expectedClass)) expectedClass = i;
|
||||
}
|
||||
} else {
|
||||
expectedClass = Math.round(label.getFirst());
|
||||
expectedClass = Math.round(labels.getFirst());
|
||||
}
|
||||
visualizer.addPoint("Label " + expectedClass,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
@@ -105,12 +105,7 @@ public class Main {
|
||||
List<Float> predictions = new ArrayList<>();
|
||||
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||
|
||||
int predictedClass;
|
||||
if(isOneHot){
|
||||
predictedClass = predictions.indexOf(Collections.max(predictions));
|
||||
} else {
|
||||
predictedClass = Math.round(predictions.getFirst());
|
||||
}
|
||||
int predictedClass = dataset.getPredictedClass(predictions);
|
||||
|
||||
visualizer.addPoint("Class " + predictedClass, x, y);
|
||||
}
|
||||
|
||||
@@ -36,7 +36,6 @@ public class DataSet implements Iterable<DataSetEntry>{
|
||||
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
|
||||
}
|
||||
|
||||
|
||||
public int size() {
|
||||
return data.size();
|
||||
}
|
||||
@@ -57,6 +56,25 @@ public class DataSet implements Iterable<DataSetEntry>{
|
||||
return this.data.get(entry).getValues();
|
||||
}
|
||||
|
||||
public int getPredictedClass(List<Float> values) {
|
||||
if (nbrLabels == 1) {
|
||||
// single value — round to nearest known class
|
||||
float value = values.getFirst();
|
||||
return data.values().stream()
|
||||
.map(l -> l.getValues().getFirst())
|
||||
.distinct()
|
||||
.min(Comparator.comparingDouble(l -> Math.abs(l - value)))
|
||||
.map(Math::round)
|
||||
.orElseThrow();
|
||||
} else {
|
||||
// one-hot — return index of max value
|
||||
int max = 0;
|
||||
for (int i = 1; i < values.size(); i++)
|
||||
if (values.get(i) > values.get(max)) max = i;
|
||||
return max;
|
||||
}
|
||||
}
|
||||
|
||||
public DataSet toNormalized() {
|
||||
List<DataSetEntry> entries = this.getData();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user