Optimize prediction
This commit is contained in:
@@ -34,9 +34,9 @@ public class Main {
|
||||
System.out.println(network.synCount());
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.001F, 2000, network, dataset);
|
||||
trainer.train(0.01F, 2000, network, dataset);
|
||||
|
||||
plotGraph(dataset, network);
|
||||
//plotGraph(dataset, network);
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
@@ -83,8 +83,8 @@ public class Main {
|
||||
float step = 0.03F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
|
||||
visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y);
|
||||
float[] predictions = network.predict(new float[]{x, y});
|
||||
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,5 +15,5 @@ public interface Model {
|
||||
//void forEachSynapse(Consumer<Synapse> consumer);
|
||||
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
||||
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
|
||||
List<Float> predict(List<Input> inputs);
|
||||
float[] predict(float[] inputs);
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ public abstract class TrainingContext {
|
||||
public DataSetEntry currentEntry;
|
||||
|
||||
public List<Float> expectations;
|
||||
public List<Float> predictions;
|
||||
public float[] predictions;
|
||||
public float[] deltas;
|
||||
|
||||
public float globalLoss;
|
||||
|
||||
@@ -24,13 +24,12 @@ public class FullyConnectedNetwork implements Model {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
List<Input> previousLayerOutputs = new ArrayList<>(inputs);
|
||||
for(Layer layer : this.layers){
|
||||
List<Float> currentLayerOutputs = layer.predict(previousLayerOutputs);
|
||||
previousLayerOutputs = currentLayerOutputs.stream().map(Input::new).toList();
|
||||
public float[] predict(float[] inputs) {
|
||||
float[] previousLayerOutputs = inputs;
|
||||
for (Layer layer : layers) {
|
||||
previousLayerOutputs = layer.predict(previousLayerOutputs);
|
||||
}
|
||||
return previousLayerOutputs.stream().map(Input::getValue).toList();
|
||||
return previousLayerOutputs;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -19,11 +19,10 @@ public class Layer implements Model {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
List<Float> result = new ArrayList<>();
|
||||
for(Neuron neuron : this.neurons){
|
||||
List<Float> res = neuron.predict(inputs);
|
||||
result.addAll(res);
|
||||
public float[] predict(float[] inputs) {
|
||||
float[] result = new float[neurons.length];
|
||||
for (int i = 0; i < neurons.length; i++) {
|
||||
result[i] = neurons[i].predict(inputs)[0];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ public class Neuron implements Model {
|
||||
this.id = id;
|
||||
this.activationFunction = func;
|
||||
|
||||
output = 0;
|
||||
weights = new float[synapses.length+1]; //takes the bias into account
|
||||
inputs = new float[synapses.length+1]; //takes the bias into account
|
||||
|
||||
@@ -44,10 +45,6 @@ public class Neuron implements Model {
|
||||
return this.activationFunction;
|
||||
}
|
||||
|
||||
public float getOutput(){
|
||||
return this.output;
|
||||
}
|
||||
|
||||
public float calculateWeightedSum() {
|
||||
int count = synCount();
|
||||
float weightedSum = 0F;
|
||||
@@ -61,6 +58,10 @@ public class Neuron implements Model {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
public float getOutput() {
|
||||
return this.output;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
return this.weights.length;
|
||||
@@ -77,10 +78,10 @@ public class Neuron implements Model {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
public float[] predict(float[] inputs) {
|
||||
this.setInputs(inputs);
|
||||
this.output = activationFunction.accept(this);
|
||||
return List.of(output);
|
||||
output = activationFunction.accept(this);
|
||||
return new float[] {output};
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -98,10 +99,8 @@ public class Neuron implements Model {
|
||||
throw new UnsupportedOperationException("Neurons have no connection with themselves");
|
||||
}
|
||||
|
||||
private void setInputs(List<Input> values){
|
||||
for(int i = 0; i < values.size(); i++){
|
||||
inputs[i+1] = values.get(i).getValue();
|
||||
}
|
||||
private void setInputs(float[] values){
|
||||
System.arraycopy(values, 0, inputs, 1, values.length);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -101,7 +101,7 @@ public class TrainingPipeline {
|
||||
|
||||
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
|
||||
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions));
|
||||
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
||||
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
|
||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||
|
||||
@@ -21,11 +21,11 @@ public class SimpleDeltaStep implements AlgorithmStep {
|
||||
public void run() {
|
||||
DataSet dataSet = context.dataset;
|
||||
DataSetEntry entry = context.currentEntry;
|
||||
List<Float> predicted = context.predictions;
|
||||
float[] predicted = context.predictions;
|
||||
List<Float> expected = dataSet.getLabelsAsFloat(entry);
|
||||
|
||||
for (int i = 0; i < predicted.size(); i++) {
|
||||
context.deltas[i] = expected.get(i) - predicted.get(i);
|
||||
for (int i = 0; i < predicted.length; i++) {
|
||||
context.deltas[i] = expected.get(i) - predicted[0];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,9 @@ package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SimplePredictionStep implements AlgorithmStep {
|
||||
|
||||
@@ -13,6 +16,11 @@ public class SimplePredictionStep implements AlgorithmStep {
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
context.predictions = context.model.predict(context.currentEntry.getData());
|
||||
List<Input> data = context.currentEntry.getData();
|
||||
float[] flatData = new float[data.size()];
|
||||
for (int i = 0; i < data.size(); i++) {
|
||||
flatData[i] = data.get(i).getValue();
|
||||
}
|
||||
context.predictions = context.model.predict(flatData);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user