Minor changes
This commit is contained in:
@@ -1,14 +1,14 @@
|
|||||||
{
|
{
|
||||||
"model": {
|
"model": {
|
||||||
"new": true,
|
"new": true,
|
||||||
"parameters": [2, 4, 2, 1],
|
"parameters": [5, 2, 3],
|
||||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/best-snapshot.json"
|
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||||
},
|
},
|
||||||
"training" : {
|
"training" : {
|
||||||
"learning_rate" : 0.0003,
|
"learning_rate" : 0.0003,
|
||||||
"max_epoch" : 5000
|
"max_epoch" : 5000
|
||||||
},
|
},
|
||||||
"dataset" : {
|
"dataset" : {
|
||||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted.csv"
|
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -8,6 +8,7 @@ import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
|||||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||||
@@ -20,6 +21,8 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
|
|
||||||
|
Scanner sc = new Scanner(System.in);
|
||||||
|
|
||||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||||
|
|
||||||
boolean newModel = config.getModelProperty("new", Boolean.class);
|
boolean newModel = config.getModelProperty("new", Boolean.class);
|
||||||
@@ -29,7 +32,12 @@ public class Main {
|
|||||||
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
||||||
String datasetPath = config.getDatasetProperty("path", String.class);
|
String datasetPath = config.getDatasetProperty("path", String.class);
|
||||||
|
|
||||||
int nbrClass = 5;
|
System.out.printf("How many classes ? [%d] : ", modelParameters[modelParameters.length-1]);
|
||||||
|
String input = sc.nextLine().trim();
|
||||||
|
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
|
||||||
|
System.out.println();
|
||||||
|
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
@@ -48,7 +56,8 @@ public class Main {
|
|||||||
snapshot.loadFromFile(modelPath);
|
snapshot.loadFromFile(modelPath);
|
||||||
network = snapshot.getModel();
|
network = snapshot.getModel();
|
||||||
}
|
}
|
||||||
//plotGraph(dataset, network);
|
|
||||||
|
plotGraph(dataset, network);
|
||||||
|
|
||||||
new ModelVisualizer(network)
|
new ModelVisualizer(network)
|
||||||
.withWeights(true)
|
.withWeights(true)
|
||||||
@@ -84,23 +93,35 @@ public class Main {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static void plotGraph(DataSet dataset, Model network){
|
private static void plotGraph(DataSet dataset, Model network){
|
||||||
|
|
||||||
|
if(dataset.getNbrInputs() != 2) return;
|
||||||
|
|
||||||
GraphVisualizer visualizer = new GraphVisualizer();
|
GraphVisualizer visualizer = new GraphVisualizer();
|
||||||
|
float min = -10F;
|
||||||
|
float max = 10F;
|
||||||
|
float step = Math.abs(max - min)/300;
|
||||||
|
|
||||||
/*for (DataSetEntry entry : dataset) {
|
//plot labels
|
||||||
|
for (DataSetEntry entry : dataset) {
|
||||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||||
label.forEach(l -> {
|
int expectedClass = 0;
|
||||||
visualizer.addPoint("Label " + l,
|
for (int i = 1; i < label.size(); i++) {
|
||||||
|
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
||||||
|
}
|
||||||
|
visualizer.addPoint("Label " + expectedClass,
|
||||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||||
});
|
}
|
||||||
}*/
|
|
||||||
|
|
||||||
float min = -50F;
|
//plot predictions
|
||||||
float max = 50F;
|
|
||||||
float step = 0.03F;
|
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
float[] predictions = network.predict(new float[]{x, y});
|
List<Float> predictions = new ArrayList<>();
|
||||||
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
|
|
||||||
|
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||||
|
float highest = Collections.max(predictions);
|
||||||
|
int predictedClass = predictions.indexOf(highest);
|
||||||
|
|
||||||
|
visualizer.addPoint("Predict " + predictedClass, x, y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ public class TrainingPipeline {
|
|||||||
System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0);
|
System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0);
|
||||||
}
|
}
|
||||||
System.out.printf("[Final global error] : %f\n", ctx.globalLoss);
|
System.out.printf("[Final global error] : %f\n", ctx.globalLoss);
|
||||||
//if(this.visualization) this.visualize(ctx);
|
System.out.println("Accuracy: " + accuracyCheck(ctx) + "%");
|
||||||
}
|
}
|
||||||
|
|
||||||
private void executeSteps(TrainingContext ctx){
|
private void executeSteps(TrainingContext ctx){
|
||||||
@@ -107,26 +107,28 @@ public class TrainingPipeline {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*private void visualize(TrainingContext ctx){
|
private float accuracyCheck(TrainingContext ctx){
|
||||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
int correct = 0;
|
||||||
ctx.model.forEachNeuron(n -> {
|
for (DataSetEntry entry : ctx.dataset) {
|
||||||
List<Float> weights = new ArrayList<>();
|
float[] predictions = ctx.model.predict(entry.getDataAsFloat());
|
||||||
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
List<Float> labels = ctx.dataset.getLabelsAsFloat(entry);
|
||||||
|
|
||||||
float b = weights.get(0);
|
// argmax of predictions
|
||||||
float w1 = weights.get(1);
|
int predictedClass = 0;
|
||||||
float w2 = weights.get(2);
|
for (int i = 1; i < predictions.length; i++) {
|
||||||
|
if (predictions[i] > predictions[predictedClass]) predictedClass = i;
|
||||||
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
}
|
||||||
});
|
|
||||||
int i = 0;
|
// argmax of labels
|
||||||
for(DataSetEntry entry : ctx.dataset){
|
int expectedClass = 0;
|
||||||
List<Input> inputs = entry.getData();
|
for (int i = 1; i < labels.size(); i++) {
|
||||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
if (labels.get(i) > labels.get(expectedClass)) expectedClass = i;
|
||||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
}
|
||||||
i++;
|
|
||||||
|
if (predictedClass == expectedClass) correct++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 100f * correct / ctx.dataset.size();
|
||||||
}
|
}
|
||||||
this.visualizer.buildLineGraph();
|
|
||||||
}*/
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,16 @@ public class DataSetEntry implements Iterable<Input> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public float[] getDataAsFloat() {
|
||||||
|
float[] res = new float[data.size()];
|
||||||
|
for(int i = 0; i < data.size(); i++){
|
||||||
|
res[i] = data.get(i).getValue();
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(this.data);
|
return Objects.hash(this.data);
|
||||||
|
|||||||
Reference in New Issue
Block a user