Minor changes
This commit is contained in:
@@ -8,6 +8,7 @@ import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
@@ -20,6 +21,8 @@ public class Main {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
|
||||
Scanner sc = new Scanner(System.in);
|
||||
|
||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||
|
||||
boolean newModel = config.getModelProperty("new", Boolean.class);
|
||||
@@ -29,7 +32,12 @@ public class Main {
|
||||
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
||||
String datasetPath = config.getDatasetProperty("path", String.class);
|
||||
|
||||
int nbrClass = 5;
|
||||
System.out.printf("How many classes ? [%d] : ", modelParameters[modelParameters.length-1]);
|
||||
String input = sc.nextLine().trim();
|
||||
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
|
||||
System.out.println();
|
||||
|
||||
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
@@ -48,7 +56,8 @@ public class Main {
|
||||
snapshot.loadFromFile(modelPath);
|
||||
network = snapshot.getModel();
|
||||
}
|
||||
//plotGraph(dataset, network);
|
||||
|
||||
plotGraph(dataset, network);
|
||||
|
||||
new ModelVisualizer(network)
|
||||
.withWeights(true)
|
||||
@@ -84,23 +93,35 @@ public class Main {
|
||||
}
|
||||
|
||||
private static void plotGraph(DataSet dataset, Model network){
|
||||
|
||||
if(dataset.getNbrInputs() != 2) return;
|
||||
|
||||
GraphVisualizer visualizer = new GraphVisualizer();
|
||||
float min = -10F;
|
||||
float max = 10F;
|
||||
float step = Math.abs(max - min)/300;
|
||||
|
||||
/*for (DataSetEntry entry : dataset) {
|
||||
//plot labels
|
||||
for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
label.forEach(l -> {
|
||||
visualizer.addPoint("Label " + l,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
});
|
||||
}*/
|
||||
int expectedClass = 0;
|
||||
for (int i = 1; i < label.size(); i++) {
|
||||
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
||||
}
|
||||
visualizer.addPoint("Label " + expectedClass,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
}
|
||||
|
||||
float min = -50F;
|
||||
float max = 50F;
|
||||
float step = 0.03F;
|
||||
//plot predictions
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
float[] predictions = network.predict(new float[]{x, y});
|
||||
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
|
||||
List<Float> predictions = new ArrayList<>();
|
||||
|
||||
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||
float highest = Collections.max(predictions);
|
||||
int predictedClass = predictions.indexOf(highest);
|
||||
|
||||
visualizer.addPoint("Predict " + predictedClass, x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user