Skip to content

Commit

Permalink
Properly implement batching and multi-threaded training
Browse files Browse the repository at this point in the history
  • Loading branch information
xEcho1337 committed Dec 18, 2024
1 parent 20d5005 commit c541d38
Show file tree
Hide file tree
Showing 20 changed files with 81 additions and 34 deletions.
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ repositories {
dependencies {
implementation 'com.google.code.gson:gson:2.10.1'
implementation 'commons-io:commons-io:2.18.0'
implementation 'org.jfree:jfreechart:1.5.3'
// implementation 'org.apache.commons:commons-io:2.18.0'
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public double getDerivative(double input) {
@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
neuron.setValue(activate(neuron.getValue()));
neuron.setValue(activate(neuron.getLocalValue()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public double getDerivative(double input) {
@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
neuron.setValue(activate(neuron.getValue()));
neuron.setValue(activate(neuron.getLocalValue()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public double getDerivative(double input) {
@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double output = activate(neuron.getValue() + neuron.getBias());
double output = activate(neuron.getLocalValue() + neuron.getBias());

neuron.setValue(output);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public double getDerivative(double input) {
@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double output = activate(neuron.getValue() + neuron.getBias());
double output = activate(neuron.getLocalValue() + neuron.getBias());

neuron.setValue(output);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public double getDerivative(double input) {
@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double output = activate(neuron.getValue() + neuron.getBias());
double output = activate(neuron.getLocalValue() + neuron.getBias());

neuron.setValue(output);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public double getDerivative(double input) {
@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double output = activate(neuron.getValue() + neuron.getBias());
double output = activate(neuron.getLocalValue() + neuron.getBias());

neuron.setValue(output);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void apply(List<Neuron> neurons) {
double[] values = new double[neurons.size()];

for (int i = 0; i < neurons.size(); i++) {
values[i] = neurons.get(i).getValue() + neurons.get(i).getBias();
values[i] = neurons.get(i).getLocalValue() + neurons.get(i).getBias();
}

double[] activatedValues = activate(values);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public double getDerivative(double input) {
@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double value = neuron.getValue() + neuron.getBias();
double value = neuron.getLocalValue() + neuron.getBias();
neuron.setValue(activate(value));
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/net/echo/brain4j/layer/Layer.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public void activate() {
Neuron inputNeuron = synapse.getInputNeuron();
Neuron outputNeuron = synapse.getOutputNeuron();

outputNeuron.setValue(outputNeuron.getValue() + inputNeuron.getValue() * synapse.getWeight());
outputNeuron.setValue(outputNeuron.getLocalValue() + inputNeuron.getLocalValue() * synapse.getWeight());
}
}

Expand All @@ -73,7 +73,7 @@ public void activate(Vector input) {
for (Synapse synapse : inputNeuron.getSynapses()) {
Neuron outputNeuron = synapse.getOutputNeuron();

outputNeuron.setValue(outputNeuron.getValue() + input.get(i) * synapse.getWeight());
outputNeuron.setValue(outputNeuron.getLocalValue() + input.get(i) * synapse.getWeight());
}
}
}
Expand All @@ -82,7 +82,7 @@ public Vector getVector() {
Vector values = new Vector(neurons.size());

for (int i = 0; i < neurons.size(); i++) {
values.set(i, neurons.get(i).getValue());
values.set(i, neurons.get(i).getLocalValue());
}

return values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void process(List<Neuron> neurons) {

public void backward(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
neuron.setValue(neuron.getValue() * (1.0 - dropout));
neuron.setValue(neuron.getLocalValue() * (1.0 - dropout));
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void applyFunction(Layer previous) {
double variance = calculateVariance(inputs, mean);

for (Neuron input : inputs) {
double value = input.getValue();
double value = input.getLocalValue();
double normalized = (value - mean) / Math.sqrt(variance + epsilon);

input.setValue(normalized);
Expand All @@ -55,7 +55,7 @@ private double calculateMean(List<Neuron> inputs) {
double sum = 0.0;

for (Neuron value : inputs) {
sum += value.getValue();
sum += value.getLocalValue();
}

return sum / inputs.size();
Expand All @@ -65,7 +65,7 @@ private double calculateVariance(List<Neuron> inputs, double mean) {
double sum = 0.0;

for (Neuron value : inputs) {
sum += Math.pow(value.getValue() - mean, 2);
sum += Math.pow(value.getLocalValue() - mean, 2);
}

return sum / inputs.size();
Expand Down
7 changes: 3 additions & 4 deletions src/main/java/net/echo/brain4j/loss/LossFunctions.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package net.echo.brain4j.loss;

import net.echo.brain4j.loss.impl.BinaryCrossEntropy;
import net.echo.brain4j.loss.impl.CategoricalCrossEntropy;
import net.echo.brain4j.loss.impl.CrossEntropy;
import net.echo.brain4j.loss.impl.MeanSquaredError;
import net.echo.brain4j.loss.impl.*;

public enum LossFunctions {

Expand All @@ -14,6 +11,8 @@ public enum LossFunctions {
*/
MEAN_SQUARED_ERROR(new MeanSquaredError()),

MEAN_ABSOLUTE_ERROR(new MeanAbsoluteError()),

/**
* Binary Cross Entropy: Used to evaluate the error in binary classification tasks
* by measuring the divergence between the predicted probabilities and the actual binary labels.
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/net/echo/brain4j/loss/impl/MeanAbsoluteError.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package net.echo.brain4j.loss.impl;

import net.echo.brain4j.loss.LossFunction;

/**
 * Mean Absolute Error (MAE): the average of the absolute differences between
 * the expected and predicted values. Compared to MSE it penalizes outliers
 * linearly rather than quadratically.
 */
public class MeanAbsoluteError implements LossFunction {

    /**
     * Computes the mean absolute error between the two arrays.
     *
     * @param actual the expected values
     * @param predicted the values produced by the model
     * @return the average absolute difference, or 0.0 for empty input
     * @throws IllegalArgumentException if the arrays differ in length
     */
    @Override
    public double calculate(double[] actual, double[] predicted) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException(
                    "Array length mismatch: " + actual.length + " != " + predicted.length);
        }

        // Guard the empty case explicitly; otherwise 0/0 below yields NaN.
        if (actual.length == 0) {
            return 0.0;
        }

        double error = 0.0;

        for (int i = 0; i < actual.length; i++) {
            error += Math.abs(actual[i] - predicted[i]);
        }

        return error / actual.length;
    }
}
4 changes: 2 additions & 2 deletions src/main/java/net/echo/brain4j/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public Vector predict(Vector input) {
Neuron inputNeuron = synapse.getInputNeuron();
Neuron outputNeuron = synapse.getOutputNeuron();

outputNeuron.setValue(outputNeuron.getValue() + inputNeuron.getValue() * synapse.getWeight());
outputNeuron.setValue(outputNeuron.getLocalValue() + inputNeuron.getLocalValue() * synapse.getWeight());
}

nextLayer.applyFunction(layer);
Expand All @@ -177,7 +177,7 @@ public Vector predict(Vector input) {
double[] output = new double[outputLayer.getNeurons().size()];

for (int i = 0; i < output.length; i++) {
output[i] = outputLayer.getNeuronAt(i).getValue();
output[i] = outputLayer.getNeuronAt(i).getLocalValue();
}

return Vector.of(output);
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/net/echo/brain4j/structure/Neuron.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
public class Neuron {

private final List<Synapse> synapses = new ArrayList<>();
private ThreadLocal<Double> localValue = new ThreadLocal<>();
private double delta;
private double value;
@Expose private double bias = 2 * Math.random() - 1;
Expand All @@ -28,11 +29,16 @@ public void setDelta(double delta) {
this.delta = delta;
}

/**
 * Returns this neuron's value as seen by the calling thread.
 *
 * <p>{@code localValue} is a {@code ThreadLocal<Double>} with no initial value,
 * so on a thread that has never called {@link #setValue(double)} the previous
 * code auto-unboxed {@code null} and threw a NullPointerException. Fall back
 * to the shared {@code value} field in that case.
 *
 * @return the thread-local value, or the shared value if none was set on this thread
 */
public double getLocalValue() {
    Double local = localValue.get();
    return local != null ? local : value;
}

public double getValue() {
return value;
}

public void setValue(double value) {
localValue.set(value);
this.value = value;
}

Expand Down
40 changes: 32 additions & 8 deletions src/main/java/net/echo/brain4j/training/BackPropagation.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,44 @@ public BackPropagation(Model model, Optimizer optimizer, Updater updater) {
this.updater = updater;
}

private List<DataRow> partition(List<DataRow> rows, int batches, int offset) {
return rows.subList(offset * batches, Math.max((offset + 1) * batches, rows.size()));
/**
 * Returns the view of {@code rows} belonging to batch number {@code offset}.
 * Both bounds are clamped to the list size so the final, possibly smaller,
 * batch is handled correctly.
 *
 * @param rows the full training set
 * @param batches the (possibly fractional) number of rows per batch
 * @param offset the zero-based batch index
 * @return a sub-list view covering this batch's rows
 */
private List<DataRow> partition(List<DataRow> rows, double batches, int offset) {
    int total = rows.size();
    int from = (int) Math.min(offset * batches, total);
    int to = (int) Math.min((offset + 1) * batches, total);
    return rows.subList(from, to);
}

public void iterate(DataSet dataSet, int batches) {
for (DataRow row : dataSet.getDataRows()) {
Vector output = model.predict(row.inputs());
Vector target = row.outputs();
List<DataRow> rows = dataSet.getDataRows();
double rowsPerBatch = (double) rows.size() / batches;

backpropagate(target.toArray(), output.toArray());
for (int i = 0; i < batches; i++) {
List<DataRow> batch = partition(dataSet.getDataRows(), rowsPerBatch, i);
List<Thread> threads = new ArrayList<>();

for (DataRow row : batch) {
threads.add(Thread.startVirtualThread(() -> {
Vector output = model.predict(row.inputs());
Vector target = row.outputs();

backpropagate(target.toArray(), output.toArray());
}));
}

waitAll(threads);

List<Layer> layers = model.getLayers();
updater.postFit(layers, optimizer.getLearningRate());
}
}

List<Layer> layers = model.getLayers();
updater.postFit(layers, optimizer.getLearningRate());
/**
 * Blocks until every thread in the list has terminated.
 *
 * @param threads the worker threads to join
 */
private void waitAll(List<Thread> threads) {
    for (Thread thread : threads) {
        try {
            thread.join();
        } catch (InterruptedException e) {
            // Restore the interrupt status instead of swallowing it, so callers
            // further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            e.printStackTrace(System.err);
        }
    }
}

public void backpropagate(double[] targets, double[] outputs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void applyGradientStep(Updater updater, Layer layer, Neuron neuron, Synap
* @return the calculated gradient
*/
public double calculateGradient(Layer layer, Neuron neuron, Synapse synapse) {
double output = neuron.getValue();
double output = neuron.getLocalValue();

double derivative = layer.getActivation().getFunction().getDerivative(output);

Expand All @@ -96,7 +96,7 @@ public double calculateGradient(Layer layer, Neuron neuron, Synapse synapse) {

neuron.setDelta(delta);

return clipGradient(delta * synapse.getInputNeuron().getValue());
return clipGradient(delta * synapse.getInputNeuron().getLocalValue());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class Adam extends Optimizer {
private double beta1;
private double beta2;
private double epsilon;
private int timestep;
private int timestep = 1;

public Adam(double learningRate) {
this(learningRate, 0.9, 0.999, 1e-8);
Expand All @@ -41,7 +41,7 @@ public void postInitialize() {

@Override
public double update(Synapse synapse) {
double gradient = synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getValue();
double gradient = synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getLocalValue();

int synapseId = synapse.getSynapseId();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public GradientDescent(double learningRate) {

@Override
// Plain SGD step: output delta * input activation, scaled by the learning rate.
// Reads the thread-local activation so concurrent batch workers each use their
// own forward-pass value rather than a value overwritten by another thread.
public double update(Synapse synapse) {
return learningRate * synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getLocalValue();
}
}

@Override
Expand Down

0 comments on commit c541d38

Please sign in to comment.