diff --git a/config.json b/config.json index 4ad6443..479e407 100644 --- a/config.json +++ b/config.json @@ -1,14 +1,14 @@ { "model": { "new": true, - "parameters": [2, 4, 2, 1], - "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/best-snapshot.json" + "parameters": [5, 2, 3], + "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json" }, "training" : { "learning_rate" : 0.0003, "max_epoch" : 5000 }, "dataset" : { - "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted.csv" + "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv" } } \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index d57d3a3..0c5d71e 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -8,6 +8,7 @@ import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.infrastructure.config.ConfigDto; import com.naaturel.ANN.infrastructure.config.ConfigLoader; import com.naaturel.ANN.infrastructure.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer; @@ -20,6 +21,8 @@ public class Main { public static void main(String[] args) throws Exception { + Scanner sc = new Scanner(System.in); + ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json"); boolean newModel = config.getModelProperty("new", Boolean.class); @@ -29,7 +32,12 @@ public class Main { float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue(); String datasetPath = config.getDatasetProperty("path", String.class); - int nbrClass = 5; + System.out.printf("How many classes ? [%d] : ", modelParameters[modelParameters.length-1]); + String input = sc.nextLine().trim(); + int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input); + System.out.println(); + + DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass); int nbrInput = dataset.getNbrInputs(); @@ -48,7 +56,8 @@ public class Main { snapshot.loadFromFile(modelPath); network = snapshot.getModel(); } - //plotGraph(dataset, network); + + plotGraph(dataset, network); new ModelVisualizer(network) .withWeights(true) @@ -84,23 +93,35 @@ public class Main { } private static void plotGraph(DataSet dataset, Model network){ + + if(dataset.getNbrInputs() != 2) return; + GraphVisualizer visualizer = new GraphVisualizer(); + float min = -10F; + float max = 10F; + float step = Math.abs(max - min)/300; - /*for (DataSetEntry entry : dataset) { + //plot labels + for (DataSetEntry entry : dataset) { List label = dataset.getLabelsAsFloat(entry); - label.forEach(l -> { - visualizer.addPoint("Label " + l, - entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); - }); - }*/ + int expectedClass = 0; + for (int i = 1; i < label.size(); i++) { + if (label.get(i) > label.get(expectedClass)) expectedClass = i; + } + visualizer.addPoint("Label " + expectedClass, + entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); + } - float min = -50F; - float max = 50F; - float step = 0.03F; + //plot predictions for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ - float[] predictions = network.predict(new float[]{x, y}); - visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y); + List predictions = new ArrayList<>(); + + for (float p : network.predict(new float[]{x, y})) predictions.add(p); + float highest = Collections.max(predictions); + int predictedClass = predictions.indexOf(highest); + + visualizer.addPoint("Predict " + predictedClass, x, y); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index b1c58bc..ce29457 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -84,7 +84,7 @@ public class TrainingPipeline { System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0); } System.out.printf("[Final global error] : %f\n", ctx.globalLoss); - //if(this.visualization) this.visualize(ctx); + System.out.println("Accuracy: " + accuracyCheck(ctx) + "%"); } private void executeSteps(TrainingContext ctx){ @@ -107,26 +107,28 @@ public class TrainingPipeline { } } - /*private void visualize(TrainingContext ctx){ - AtomicInteger neuronIndex = new AtomicInteger(0); - ctx.model.forEachNeuron(n -> { - List weights = new ArrayList<>(); - n.forEachSynapse(syn -> weights.add(syn.getWeight())); + private float accuracyCheck(TrainingContext ctx){ + int correct = 0; + for (DataSetEntry entry : ctx.dataset) { + float[] predictions = ctx.model.predict(entry.getDataAsFloat()); + List labels = ctx.dataset.getLabelsAsFloat(entry); - float b = weights.get(0); - float w1 = weights.get(1); - float w2 = weights.get(2); + // argmax of predictions + int predictedClass = 0; + for (int i = 1; i < predictions.length; i++) { + if (predictions[i] > predictions[predictedClass]) predictedClass = i; + } - this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3); - }); - int i = 0; - for(DataSetEntry entry : ctx.dataset){ - List inputs = entry.getData(); - this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue()); - this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F); - i++; + // argmax of labels + int expectedClass = 0; + for (int i = 1; i < labels.size(); i++) { + if (labels.get(i) > labels.get(expectedClass)) expectedClass = i; + } + + if (predictedClass == expectedClass) correct++; } - this.visualizer.buildLineGraph(); - }*/ + + return 100f * correct / ctx.dataset.size(); + } } diff --git a/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSetEntry.java b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSetEntry.java index 99941ce..598889f 100644 --- a/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSetEntry.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSetEntry.java @@ -17,6 +17,16 @@ public class DataSetEntry implements Iterable { } + public float[] getDataAsFloat() { + float[] res = new float[data.size()]; + for(int i = 0; i < data.size(); i++){ + res[i] = data.get(i).getValue(); + } + return res; + } + + + @Override public int hashCode() { return Objects.hash(this.data);