77 lines
3.9 KiB
Java
77 lines
3.9 KiB
Java
package com.naaturel.ANN;
|
|
|
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
|
import com.naaturel.ANN.domain.model.neuron.Bias;
|
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
|
import com.naaturel.ANN.implementation.activationFunction.Heaviside;
|
|
import com.naaturel.ANN.implementation.activationFunction.Linear;
|
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
|
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
|
|
|
import java.util.*;
|
|
|
|
public class Main {
|
|
|
|
public static void main(String[] args){
|
|
|
|
DataSet orDataSet = new DataSet(Map.ofEntries(
|
|
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
|
|
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
|
));
|
|
|
|
DataSet andDataSet = new DataSet(Map.ofEntries(
|
|
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
|
));
|
|
|
|
DataSet dataSet = new DataSet(Map.ofEntries(
|
|
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
|
|
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
|
|
));
|
|
|
|
List<Synapse> syns = new ArrayList<>();
|
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
|
|
Bias bias = new Bias(new Weight(0));
|
|
|
|
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
|
GradientDescentTraining st = new GradientDescentTraining();
|
|
|
|
long start = System.currentTimeMillis();
|
|
|
|
st.train(n, 0.2F, andDataSet);
|
|
|
|
long end = System.currentTimeMillis();
|
|
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
|
}
|
|
}
|