From c541d38aad70afc0df65a7947132aac1be6be557 Mon Sep 17 00:00:00 2001
From: echo
Date: Wed, 18 Dec 2024 13:16:14 +0100
Subject: [PATCH] Properly implement batching and multi-threaded training

---
 build.gradle                                   |  1 +
 .../activation/impl/ELUActivation.java         |  2 +-
 .../activation/impl/GELUActivation.java        |  2 +-
 .../activation/impl/LeakyReLUActivation.java   |  2 +-
 .../activation/impl/LinearActivation.java      |  2 +-
 .../activation/impl/ReLUActivation.java        |  2 +-
 .../activation/impl/SigmoidActivation.java     |  2 +-
 .../activation/impl/SoftmaxActivation.java     |  2 +-
 .../activation/impl/TanhActivation.java        |  2 +-
 .../java/net/echo/brain4j/layer/Layer.java     |  6 +--
 .../echo/brain4j/layer/impl/DropoutLayer.java  |  2 +-
 .../echo/brain4j/layer/impl/LayerNorm.java     |  6 +--
 .../net/echo/brain4j/loss/LossFunctions.java   |  7 ++--
 .../brain4j/loss/impl/MeanAbsoluteError.java   | 17 ++++++++
 .../java/net/echo/brain4j/model/Model.java     |  4 +-
 .../net/echo/brain4j/structure/Neuron.java     |  6 +++
 .../brain4j/training/BackPropagation.java      | 40 +++++++++++++++----
 .../training/optimizers/Optimizer.java         |  4 +-
 .../training/optimizers/impl/Adam.java         |  4 +-
 .../optimizers/impl/GradientDescent.java       |  2 +-
 20 files changed, 81 insertions(+), 34 deletions(-)
 create mode 100644 src/main/java/net/echo/brain4j/loss/impl/MeanAbsoluteError.java

diff --git a/build.gradle b/build.gradle
index a79db57..4cf6e34 100644
--- a/build.gradle
+++ b/build.gradle
@@ -13,6 +13,7 @@ repositories {
 dependencies {
     implementation 'com.google.code.gson:gson:2.10.1'
     implementation 'commons-io:commons-io:2.18.0'
+    implementation 'org.jfree:jfreechart:1.5.3'
 
     // implementation 'org.apache.commons:commons-io:2.18.0'
 }
diff --git a/src/main/java/net/echo/brain4j/activation/impl/ELUActivation.java b/src/main/java/net/echo/brain4j/activation/impl/ELUActivation.java
index d637870..313591f 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/ELUActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/ELUActivation.java
@@ -35,7 +35,7 @@ public double getDerivative(double input) {
     @Override
     public void apply(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            neuron.setValue(activate(neuron.getValue()));
+            neuron.setValue(activate(neuron.getLocalValue()));
         }
     }
 }
diff --git a/src/main/java/net/echo/brain4j/activation/impl/GELUActivation.java b/src/main/java/net/echo/brain4j/activation/impl/GELUActivation.java
index fe5cffa..b7b250a 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/GELUActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/GELUActivation.java
@@ -26,7 +26,7 @@ public double getDerivative(double input) {
     @Override
     public void apply(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            neuron.setValue(activate(neuron.getValue()));
+            neuron.setValue(activate(neuron.getLocalValue()));
         }
     }
 }
diff --git a/src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java b/src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java
index e3c2689..99ce5f9 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java
@@ -25,7 +25,7 @@ public double getDerivative(double input) {
     @Override
     public void apply(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            double output = activate(neuron.getValue() + neuron.getBias());
+            double output = activate(neuron.getLocalValue() + neuron.getBias());
 
             neuron.setValue(output);
         }
diff --git a/src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java b/src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java
index 9bfcc49..1749054 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java
@@ -25,7 +25,7 @@ public double getDerivative(double input) {
     @Override
     public void apply(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            double output = activate(neuron.getValue() + neuron.getBias());
+            double output = activate(neuron.getLocalValue() + neuron.getBias());
 
             neuron.setValue(output);
         }
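
Note on the getValue() -> getLocalValue() switch that runs through all of the activation classes in this patch: once several training threads share one network, each thread needs its own scratch copy of a neuron's pre-activation value (see the Neuron and BackPropagation hunks further down). A minimal, self-contained sketch of the per-thread isolation this relies on, using plain java.lang.ThreadLocal rather than the brain4j classes:

    public class ThreadLocalDemo {

        // Same pattern as the new Neuron#localValue field: one copy per thread.
        private static final ThreadLocal<Double> VALUE = ThreadLocal.withInitial(() -> 0.0);

        public static void main(String[] args) throws InterruptedException {
            Thread worker = Thread.startVirtualThread(() -> {
                VALUE.set(42.0); // write is confined to this virtual thread
                System.out.println("worker sees " + VALUE.get()); // 42.0
            });

            worker.join();
            System.out.println("main sees " + VALUE.get()); // 0.0, untouched
        }
    }
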
diff --git a/src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java b/src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java
index 6ff33c5..5070eb3 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java
@@ -25,7 +25,7 @@ public double getDerivative(double input) {
     @Override
     public void apply(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            double output = activate(neuron.getValue() + neuron.getBias());
+            double output = activate(neuron.getLocalValue() + neuron.getBias());
 
             neuron.setValue(output);
         }
diff --git a/src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java b/src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java
index 8986f41..3b0b3d2 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java
@@ -31,7 +31,7 @@ public double getDerivative(double input) {
     @Override
     public void apply(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            double output = activate(neuron.getValue() + neuron.getBias());
+            double output = activate(neuron.getLocalValue() + neuron.getBias());
 
             neuron.setValue(output);
         }
diff --git a/src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java b/src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java
index c21b8fb..12b9479 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java
@@ -46,7 +46,7 @@ public void apply(List<Neuron> neurons) {
         double[] values = new double[neurons.size()];
 
         for (int i = 0; i < neurons.size(); i++) {
-            values[i] = neurons.get(i).getValue() + neurons.get(i).getBias();
+            values[i] = neurons.get(i).getLocalValue() + neurons.get(i).getBias();
         }
 
         double[] activatedValues = activate(values);
diff --git a/src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java b/src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java
index 5502939..6441c1e 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java
@@ -29,7 +29,7 @@ public double getDerivative(double input) {
     @Override
     public void apply(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            double value = neuron.getValue() + neuron.getBias();
+            double value = neuron.getLocalValue() + neuron.getBias();
             neuron.setValue(activate(value));
         }
     }
diff --git a/src/main/java/net/echo/brain4j/layer/Layer.java b/src/main/java/net/echo/brain4j/layer/Layer.java
index 9d46abc..90fbb54 100644
--- a/src/main/java/net/echo/brain4j/layer/Layer.java
+++ b/src/main/java/net/echo/brain4j/layer/Layer.java
@@ -62,7 +62,7 @@ public void activate() {
             Neuron inputNeuron = synapse.getInputNeuron();
             Neuron outputNeuron = synapse.getOutputNeuron();
 
-            outputNeuron.setValue(outputNeuron.getValue() + inputNeuron.getValue() * synapse.getWeight());
+            outputNeuron.setValue(outputNeuron.getLocalValue() + inputNeuron.getLocalValue() * synapse.getWeight());
         }
     }
 
@@ -73,7 +73,7 @@ public void activate(Vector input) {
             for (Synapse synapse : inputNeuron.getSynapses()) {
                 Neuron outputNeuron = synapse.getOutputNeuron();
 
-                outputNeuron.setValue(outputNeuron.getValue() + input.get(i) * synapse.getWeight());
+                outputNeuron.setValue(outputNeuron.getLocalValue() + input.get(i) * synapse.getWeight());
             }
         }
     }
@@ -82,7 +82,7 @@ public Vector getVector() {
         Vector values = new Vector(neurons.size());
 
         for (int i = 0; i < neurons.size(); i++) {
-            values.set(i, neurons.get(i).getValue());
+            values.set(i, neurons.get(i).getLocalValue());
         }
 
         return values;
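
The three Layer hunks keep the usual forward-pass accumulation, v_out += v_in * w, only routed through the thread-local value. For reference, a condensed standalone sketch of the sum each output neuron receives before its activation fires (the bias is added later, inside Activation#apply, as the activation hunks above show):

    // Pre-activation of one output neuron over its incoming synapses.
    static double preActivation(double[] inputValues, double[] weights) {
        double sum = 0.0;

        for (int i = 0; i < inputValues.length; i++) {
            sum += inputValues[i] * weights[i];
        }

        return sum; // the activation then computes activate(sum + bias)
    }
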
diff --git a/src/main/java/net/echo/brain4j/layer/impl/DropoutLayer.java b/src/main/java/net/echo/brain4j/layer/impl/DropoutLayer.java
index d2a8fe1..99b3e16 100644
--- a/src/main/java/net/echo/brain4j/layer/impl/DropoutLayer.java
+++ b/src/main/java/net/echo/brain4j/layer/impl/DropoutLayer.java
@@ -43,7 +43,7 @@ public void process(List<Neuron> neurons) {
 
     public void backward(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            neuron.setValue(neuron.getValue() * (1.0 - dropout));
+            neuron.setValue(neuron.getLocalValue() * (1.0 - dropout));
         }
     }
 
diff --git a/src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java b/src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java
index 580d9e3..e061df2 100644
--- a/src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java
+++ b/src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java
@@ -28,7 +28,7 @@ public void applyFunction(Layer previous) {
         double variance = calculateVariance(inputs, mean);
 
         for (Neuron input : inputs) {
-            double value = input.getValue();
+            double value = input.getLocalValue();
 
             double normalized = (value - mean) / Math.sqrt(variance + epsilon);
             input.setValue(normalized);
@@ -55,7 +55,7 @@ private double calculateMean(List<Neuron> inputs) {
         double sum = 0.0;
 
         for (Neuron value : inputs) {
-            sum += value.getValue();
+            sum += value.getLocalValue();
         }
 
         return sum / inputs.size();
@@ -65,7 +65,7 @@ private double calculateVariance(List<Neuron> inputs, double mean) {
         double sum = 0.0;
 
         for (Neuron value : inputs) {
-            sum += Math.pow(value.getValue() - mean, 2);
+            sum += Math.pow(value.getLocalValue() - mean, 2);
         }
 
         return sum / inputs.size();
diff --git a/src/main/java/net/echo/brain4j/loss/LossFunctions.java b/src/main/java/net/echo/brain4j/loss/LossFunctions.java
index 70f492e..13a7697 100644
--- a/src/main/java/net/echo/brain4j/loss/LossFunctions.java
+++ b/src/main/java/net/echo/brain4j/loss/LossFunctions.java
@@ -1,9 +1,6 @@
 package net.echo.brain4j.loss;
 
-import net.echo.brain4j.loss.impl.BinaryCrossEntropy;
-import net.echo.brain4j.loss.impl.CategoricalCrossEntropy;
-import net.echo.brain4j.loss.impl.CrossEntropy;
-import net.echo.brain4j.loss.impl.MeanSquaredError;
+import net.echo.brain4j.loss.impl.*;
 
 public enum LossFunctions {
 
@@ -14,6 +11,8 @@ public enum LossFunctions {
      */
     MEAN_SQUARED_ERROR(new MeanSquaredError()),
 
+    MEAN_ABSOLUTE_ERROR(new MeanAbsoluteError()),
+
     /**
      * Binary Cross Entropy: Used to evaluate the error in binary classification tasks
      * by measuring the divergence between the predicted probabilities and the actual binary labels.
diff --git a/src/main/java/net/echo/brain4j/loss/impl/MeanAbsoluteError.java b/src/main/java/net/echo/brain4j/loss/impl/MeanAbsoluteError.java
new file mode 100644
index 0000000..89f18fe
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/loss/impl/MeanAbsoluteError.java
@@ -0,0 +1,17 @@
+package net.echo.brain4j.loss.impl;
+
+import net.echo.brain4j.loss.LossFunction;
+
+public class MeanAbsoluteError implements LossFunction {
+
+    @Override
+    public double calculate(double[] actual, double[] predicted) {
+        double error = 0.0;
+
+        for (int i = 0; i < actual.length; i++) {
+            error += Math.abs(actual[i] - predicted[i]);
+        }
+
+        return error / actual.length;
+    }
+}
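
The new loss is easy to sanity-check in isolation; a quick, hypothetical test of MeanAbsoluteError#calculate with made-up values:

    double[] actual = {1.0, 2.0, 3.0};
    double[] predicted = {1.5, 2.0, 2.0};

    // (|1.0 - 1.5| + |2.0 - 2.0| + |3.0 - 2.0|) / 3 = 1.5 / 3 = 0.5
    double loss = new MeanAbsoluteError().calculate(actual, predicted);
    System.out.println(loss); // 0.5
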
diff --git a/src/main/java/net/echo/brain4j/model/Model.java b/src/main/java/net/echo/brain4j/model/Model.java
index 738b49e..cbc9359 100644
--- a/src/main/java/net/echo/brain4j/model/Model.java
+++ b/src/main/java/net/echo/brain4j/model/Model.java
@@ -166,7 +166,7 @@ public Vector predict(Vector input) {
                 Neuron inputNeuron = synapse.getInputNeuron();
                 Neuron outputNeuron = synapse.getOutputNeuron();
 
-                outputNeuron.setValue(outputNeuron.getValue() + inputNeuron.getValue() * synapse.getWeight());
+                outputNeuron.setValue(outputNeuron.getLocalValue() + inputNeuron.getLocalValue() * synapse.getWeight());
             }
 
             nextLayer.applyFunction(layer);
@@ -177,7 +177,7 @@ public Vector predict(Vector input) {
         double[] output = new double[outputLayer.getNeurons().size()];
 
         for (int i = 0; i < output.length; i++) {
-            output[i] = outputLayer.getNeuronAt(i).getValue();
+            output[i] = outputLayer.getNeuronAt(i).getLocalValue();
         }
 
         return Vector.of(output);
diff --git a/src/main/java/net/echo/brain4j/structure/Neuron.java b/src/main/java/net/echo/brain4j/structure/Neuron.java
index 1ec9d39..65370f5 100644
--- a/src/main/java/net/echo/brain4j/structure/Neuron.java
+++ b/src/main/java/net/echo/brain4j/structure/Neuron.java
@@ -8,6 +8,7 @@ public class Neuron {
 
     private final List<Synapse> synapses = new ArrayList<>();
+    private final ThreadLocal<Double> localValue = ThreadLocal.withInitial(() -> 0.0);
     private double delta;
     private double value;
     @Expose private double bias = 2 * Math.random() - 1;
@@ -28,11 +29,16 @@ public void setDelta(double delta) {
         this.delta = delta;
     }
 
+    public double getLocalValue() {
+        return localValue.get();
+    }
+
     public double getValue() {
         return value;
     }
 
     public void setValue(double value) {
+        localValue.set(value);
         this.value = value;
     }
 
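
Two remarks on the Neuron hunk. First, the thread-local is declared with ThreadLocal.withInitial(() -> 0.0) rather than a bare new ThreadLocal<>(), since Layer#activate accumulates into getLocalValue() before any setValue() has run on that thread, and an uninitialized ThreadLocal<Double> would unbox null. Second, setValue(double) writes both the thread-local copy and the shared field, so existing single-threaded callers keep working while concurrent callers each read back their own forward pass. A hypothetical usage (assumes an already-built Model instance named model):

    Vector input = Vector.of(new double[]{0.2, 0.8});
    Vector output = model.predict(input);

    // Reads the calling thread's local values, so concurrent predict() calls
    // on the same model no longer stomp on each other's activations.
    double first = output.get(0);
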
diff --git a/src/main/java/net/echo/brain4j/training/BackPropagation.java b/src/main/java/net/echo/brain4j/training/BackPropagation.java
index fa97c71..dcb7ae5 100644
--- a/src/main/java/net/echo/brain4j/training/BackPropagation.java
+++ b/src/main/java/net/echo/brain4j/training/BackPropagation.java
@@ -26,20 +26,44 @@ public BackPropagation(Model model, Optimizer optimizer, Updater updater) {
         this.updater = updater;
     }
 
-    private List<DataRow> partition(List<DataRow> rows, int batches, int offset) {
-        return rows.subList(offset * batches, Math.max((offset + 1) * batches, rows.size()));
+    private List<DataRow> partition(List<DataRow> rows, double rowsPerBatch, int offset) {
+        int start = (int) Math.min(offset * rowsPerBatch, rows.size());
+        int stop = (int) Math.min((offset + 1) * rowsPerBatch, rows.size());
+
+        return rows.subList(start, stop);
     }
 
     public void iterate(DataSet dataSet, int batches) {
-        for (DataRow row : dataSet.getDataRows()) {
-            Vector output = model.predict(row.inputs());
-            Vector target = row.outputs();
+        List<DataRow> rows = dataSet.getDataRows();
+        double rowsPerBatch = (double) rows.size() / batches;
 
-            backpropagate(target.toArray(), output.toArray());
+        for (int i = 0; i < batches; i++) {
+            List<DataRow> batch = partition(rows, rowsPerBatch, i);
+            List<Thread> threads = new ArrayList<>();
+
+            for (DataRow row : batch) {
+                threads.add(Thread.startVirtualThread(() -> {
+                    Vector output = model.predict(row.inputs());
+                    Vector target = row.outputs();
+
+                    backpropagate(target.toArray(), output.toArray());
+                }));
+            }
+
+            waitAll(threads);
+
+            List<Layer> layers = model.getLayers();
+            updater.postFit(layers, optimizer.getLearningRate());
         }
-
-        List<Layer> layers = model.getLayers();
-        updater.postFit(layers, optimizer.getLearningRate());
+    }
+
+    private void waitAll(List<Thread> threads) {
+        for (Thread thread : threads) {
+            try {
+                thread.join();
+            } catch (InterruptedException e) {
+                e.printStackTrace(System.err);
+            }
+        }
     }
 
     public void backpropagate(double[] targets, double[] outputs) {
diff --git a/src/main/java/net/echo/brain4j/training/optimizers/Optimizer.java b/src/main/java/net/echo/brain4j/training/optimizers/Optimizer.java
index e6fa847..b4b3695 100644
--- a/src/main/java/net/echo/brain4j/training/optimizers/Optimizer.java
+++ b/src/main/java/net/echo/brain4j/training/optimizers/Optimizer.java
@@ -87,7 +87,7 @@ public void applyGradientStep(Updater updater, Layer layer, Neuron neuron, Synap
      * @return the calculated gradient
      */
    public double calculateGradient(Layer layer, Neuron neuron, Synapse synapse) {
-        double output = neuron.getValue();
+        double output = neuron.getLocalValue();
 
         double derivative = layer.getActivation().getFunction().getDerivative(output);
 
@@ -96,7 +96,7 @@ public double calculateGradient(Layer layer, Neuron neuron, Synapse synapse) {
 
         neuron.setDelta(delta);
 
-        return clipGradient(delta * synapse.getInputNeuron().getValue());
+        return clipGradient(delta * synapse.getInputNeuron().getLocalValue());
     }
 
     /**
diff --git a/src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java b/src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java
index 7e6c435..346a047 100644
--- a/src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java
+++ b/src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java
@@ -20,7 +20,7 @@ public class Adam extends Optimizer {
     private double beta1;
     private double beta2;
     private double epsilon;
-    private int timestep;
+    private int timestep = 1;
 
     public Adam(double learningRate) {
         this(learningRate, 0.9, 0.999, 1e-8);
@@ -41,7 +41,7 @@ public void postInitialize() {
 
     @Override
     public double update(Synapse synapse) {
-        double gradient = synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getValue();
+        double gradient = synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getLocalValue();
 
         int synapseId = synapse.getSynapseId();
 
diff --git a/src/main/java/net/echo/brain4j/training/optimizers/impl/GradientDescent.java b/src/main/java/net/echo/brain4j/training/optimizers/impl/GradientDescent.java
index c73c51c..c99a907 100644
--- a/src/main/java/net/echo/brain4j/training/optimizers/impl/GradientDescent.java
+++ b/src/main/java/net/echo/brain4j/training/optimizers/impl/GradientDescent.java
@@ -15,7 +15,7 @@ public GradientDescent(double learningRate) {
 
     @Override
     public double update(Synapse synapse) {
-        return learningRate * synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getValue();
+        return learningRate * synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getLocalValue();
     }
 
     @Override
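
Worked example of the new partitioning (the partition parameter, named batches in an earlier draft, is really the rows-per-batch count, hence the rename above): with 10 rows and batches = 3, rowsPerBatch = 10 / 3 ≈ 3.33, giving row ranges [0, 3), [3, 6) and [6, 10), so every row lands in exactly one batch. The replaced Math.max version either returned everything from the start offset to the end of the list or threw IndexOutOfBoundsException. A standalone sketch of the same clamped index arithmetic:

    int size = 10;
    int batches = 3;
    double rowsPerBatch = (double) size / batches;

    for (int i = 0; i < batches; i++) {
        int start = (int) Math.min(i * rowsPerBatch, size);
        int stop = (int) Math.min((i + 1) * rowsPerBatch, size);

        System.out.println("batch " + i + ": rows [" + start + ", " + stop + ")");
    }
    // batch 0: rows [0, 3)
    // batch 1: rows [3, 6)
    // batch 2: rows [6, 10)
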