Minor changes
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
{
|
||||
"model": {
|
||||
"new": true,
|
||||
"parameters": [2, 4, 2, 1],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/best-snapshot.json"
|
||||
"parameters": [5, 2, 3],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||
},
|
||||
"training" : {
|
||||
"learning_rate" : 0.0003,
|
||||
"max_epoch" : 5000
|
||||
},
|
||||
"dataset" : {
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted.csv"
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv"
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
@@ -20,6 +21,8 @@ public class Main {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
|
||||
Scanner sc = new Scanner(System.in);
|
||||
|
||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||
|
||||
boolean newModel = config.getModelProperty("new", Boolean.class);
|
||||
@@ -29,7 +32,12 @@ public class Main {
|
||||
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
||||
String datasetPath = config.getDatasetProperty("path", String.class);
|
||||
|
||||
int nbrClass = 5;
|
||||
System.out.printf("How many classes ? [%d] : ", modelParameters[modelParameters.length-1]);
|
||||
String input = sc.nextLine().trim();
|
||||
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
|
||||
System.out.println();
|
||||
|
||||
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
@@ -48,7 +56,8 @@ public class Main {
|
||||
snapshot.loadFromFile(modelPath);
|
||||
network = snapshot.getModel();
|
||||
}
|
||||
//plotGraph(dataset, network);
|
||||
|
||||
plotGraph(dataset, network);
|
||||
|
||||
new ModelVisualizer(network)
|
||||
.withWeights(true)
|
||||
@@ -84,23 +93,35 @@ public class Main {
|
||||
}
|
||||
|
||||
private static void plotGraph(DataSet dataset, Model network){
|
||||
|
||||
if(dataset.getNbrInputs() != 2) return;
|
||||
|
||||
GraphVisualizer visualizer = new GraphVisualizer();
|
||||
float min = -10F;
|
||||
float max = 10F;
|
||||
float step = Math.abs(max - min)/300;
|
||||
|
||||
/*for (DataSetEntry entry : dataset) {
|
||||
//plot labels
|
||||
for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
label.forEach(l -> {
|
||||
visualizer.addPoint("Label " + l,
|
||||
int expectedClass = 0;
|
||||
for (int i = 1; i < label.size(); i++) {
|
||||
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
||||
}
|
||||
visualizer.addPoint("Label " + expectedClass,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
});
|
||||
}*/
|
||||
}
|
||||
|
||||
float min = -50F;
|
||||
float max = 50F;
|
||||
float step = 0.03F;
|
||||
//plot predictions
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
float[] predictions = network.predict(new float[]{x, y});
|
||||
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
|
||||
List<Float> predictions = new ArrayList<>();
|
||||
|
||||
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||
float highest = Collections.max(predictions);
|
||||
int predictedClass = predictions.indexOf(highest);
|
||||
|
||||
visualizer.addPoint("Predict " + predictedClass, x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ public class TrainingPipeline {
|
||||
System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0);
|
||||
}
|
||||
System.out.printf("[Final global error] : %f\n", ctx.globalLoss);
|
||||
//if(this.visualization) this.visualize(ctx);
|
||||
System.out.println("Accuracy: " + accuracyCheck(ctx) + "%");
|
||||
}
|
||||
|
||||
private void executeSteps(TrainingContext ctx){
|
||||
@@ -107,26 +107,28 @@ public class TrainingPipeline {
|
||||
}
|
||||
}
|
||||
|
||||
/*private void visualize(TrainingContext ctx){
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
ctx.model.forEachNeuron(n -> {
|
||||
List<Float> weights = new ArrayList<>();
|
||||
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
||||
private float accuracyCheck(TrainingContext ctx){
|
||||
int correct = 0;
|
||||
for (DataSetEntry entry : ctx.dataset) {
|
||||
float[] predictions = ctx.model.predict(entry.getDataAsFloat());
|
||||
List<Float> labels = ctx.dataset.getLabelsAsFloat(entry);
|
||||
|
||||
float b = weights.get(0);
|
||||
float w1 = weights.get(1);
|
||||
float w2 = weights.get(2);
|
||||
|
||||
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
||||
});
|
||||
int i = 0;
|
||||
for(DataSetEntry entry : ctx.dataset){
|
||||
List<Input> inputs = entry.getData();
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
||||
i++;
|
||||
// argmax of predictions
|
||||
int predictedClass = 0;
|
||||
for (int i = 1; i < predictions.length; i++) {
|
||||
if (predictions[i] > predictions[predictedClass]) predictedClass = i;
|
||||
}
|
||||
|
||||
// argmax of labels
|
||||
int expectedClass = 0;
|
||||
for (int i = 1; i < labels.size(); i++) {
|
||||
if (labels.get(i) > labels.get(expectedClass)) expectedClass = i;
|
||||
}
|
||||
|
||||
if (predictedClass == expectedClass) correct++;
|
||||
}
|
||||
|
||||
return 100f * correct / ctx.dataset.size();
|
||||
}
|
||||
this.visualizer.buildLineGraph();
|
||||
}*/
|
||||
|
||||
}
|
||||
|
||||
@@ -17,6 +17,16 @@ public class DataSetEntry implements Iterable<Input> {
|
||||
}
|
||||
|
||||
|
||||
public float[] getDataAsFloat() {
|
||||
float[] res = new float[data.size()];
|
||||
for(int i = 0; i < data.size(); i++){
|
||||
res[i] = data.get(i).getValue();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(this.data);
|
||||
|
||||
Reference in New Issue
Block a user