
Commit be564a2: Working on batches

xEcho1337 committed Dec 14, 2024
1 parent 01d389f commit be564a2
Showing 7 changed files with 51 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/main/java/net/echo/brain4j/model/Model.java
@@ -82,8 +82,8 @@ private void connect(WeightInitialization type) {
      *
      * @param type initialization method
      * @param function loss function for error assessment
-     * @param optimizer optimization algorithm
-     * @param updater updater for weights
+     * @param optimizer optimization algorithm for training
+     * @param updater weights updating algorithm for training
      */
     public void compile(WeightInitialization type, LossFunctions function, Optimizer optimizer, Updater updater) {
         this.function = function;
@@ -103,7 +103,7 @@ public void compile(WeightInitialization type, LossFunctions function, Optimizer
      * @param set dataset for training
      */
     public void fit(DataSet set, int batchSize) {
-        propagation.iterate(set);
+        propagation.iterate(set, batchSize);
     }
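Taken together with the DataSet change below, the new entry point can be exercised as follows. This is a minimal usage sketch: Model construction, the loss function, and the optimizer/updater setup are not shown in this diff, so they appear as placeholders; only compile, fit, shuffle, and WeightInitialization.HE are confirmed by this commit.

    // Sketch only: lossFunction, optimizer, and updater are assumed to be
    // set up elsewhere; they are placeholders, not APIs from this diff.
    model.compile(WeightInitialization.HE, lossFunction, optimizer, updater);

    DataSet set = new DataSet(/* DataRow instances */);
    set.shuffle();       // added in this commit (see DataSet.java below)
    model.fit(set, 32);  // batchSize now reaches BackPropagation.iterate(set, 32)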
@@ -18,7 +18,7 @@ public enum WeightInitialization {
     NORMAL(new NormalInit()),

     /**
-     * He initialization (also known as He et al. initialization) is designed for layers with ReLU activations.
+     * He initialization (also known as Kaiming initialization) is designed for layers with ReLU activations.
      * It initializes weights with a normal distribution, scaled by the square root of 2 divided by the number of input neurons.
      */
     HE(new HeInit()),
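As a standalone illustration of the rule that Javadoc describes (this is not the library's HeInit implementation, which is not part of this diff):

    import java.util.Random;

    final class HeSketch {
        // He/Kaiming initialization as described above: a standard normal
        // sample scaled by sqrt(2 / inputNeurons).
        static double sampleWeight(Random random, int inputNeurons) {
            return random.nextGaussian() * Math.sqrt(2.0 / inputNeurons);
        }
    }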
37 changes: 35 additions & 2 deletions src/main/java/net/echo/brain4j/training/BackPropagation.java
@@ -4,14 +4,15 @@
 import net.echo.brain4j.layer.impl.DropoutLayer;
 import net.echo.brain4j.model.Model;
 import net.echo.brain4j.structure.Neuron;
 import net.echo.brain4j.structure.Synapse;
 import net.echo.brain4j.training.data.DataRow;
 import net.echo.brain4j.training.data.DataSet;
 import net.echo.brain4j.training.optimizers.Optimizer;
 import net.echo.brain4j.training.updater.Updater;
 import net.echo.brain4j.utils.Vector;

 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;

 public class BackPropagation {

@@ -25,7 +26,39 @@ public BackPropagation(Model model, Optimizer optimizer, Updater updater) {
         this.updater = updater;
     }

-    public void iterate(DataSet dataSet) {
+    // Note: the second argument is the number of rows per batch, not the batch count.
+    private List<DataRow> partition(List<DataRow> rows, int rowsPerBatch, int offset) {
+        // Math.min clamps the end index so the last batch never runs past the list.
+        return rows.subList(offset * rowsPerBatch, Math.min((offset + 1) * rowsPerBatch, rows.size()));
+    }
+
+    // Called from Model.fit with its batchSize argument; treated here as a batch count.
+    public void iterate(DataSet dataSet, int batches) {
+        /*List<DataRow> rows = dataSet.getDataRows();
+        int rowsPerBatch = rows.size() / batches;
+
+        List<Thread> threads = new ArrayList<>();
+
+        for (int i = 0; i < batches; i++) {
+            List<DataRow> batch = partition(rows, rowsPerBatch, i);
+
+            Thread thread = Thread.startVirtualThread(() -> {
+                for (DataRow row : batch) {
+                    Vector output = model.predict(row.inputs());
+                    Vector target = row.outputs();
+
+                    backpropagate(target.toArray(), output.toArray());
+                }
+            });
+
+            threads.add(thread);
+        }
+
+        for (Thread thread : threads) {
+            try {
+                thread.join();
+            } catch (InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+        }*/
         for (DataRow row : dataSet.getDataRows()) {
             Vector output = model.predict(row.inputs());
             Vector target = row.outputs();
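To make the partition arithmetic concrete, here is a self-contained sketch of the same scheme with generic names (not code from this repository):

    import java.util.List;

    final class BatchSketch {
        // With 10 rows and rowsPerBatch = 3, batchIndex 0..3 yields the ranges
        // [0,3), [3,6), [6,9), [9,10): Math.min clamps the final, shorter batch
        // so subList never reads past the end of the list.
        static <T> List<T> partition(List<T> rows, int rowsPerBatch, int batchIndex) {
            int from = batchIndex * rowsPerBatch;
            int to = Math.min(from + rowsPerBatch, rows.size());
            return rows.subList(from, to);
        }
    }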
8 changes: 8 additions & 0 deletions src/main/java/net/echo/brain4j/training/data/DataSet.java
@@ -2,6 +2,7 @@

 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;

 public class DataSet {
@@ -19,4 +20,11 @@ public DataSet(DataRow... rows) {
     public List<DataRow> getDataRows() {
         return dataRows;
     }
+
+    /**
+     * Randomly shuffles the dataset, making the training more efficient.
+     */
+    public void shuffle() {
+        Collections.shuffle(dataRows);
+    }
 }
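A typical way to use the new method is to reshuffle between epochs so each pass sees the rows in a new order. A sketch assuming an epoch loop that this commit does not itself include, reusing the model and set from the earlier sketch:

    for (int epoch = 0; epoch < 100; epoch++) {
        set.shuffle();       // new ordering each epoch, so batches differ
        model.fit(set, 32);  // batch size of 32 is an arbitrary example
    }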
@@ -68,10 +68,10 @@ public void postIteration(Updater updater, List<Layer> layers) {
         this.beta2Timestep = Math.pow(beta2, timestep);

         for (Layer layer : layers) {
-            layer.getSynapses().parallelStream().forEach(synapse -> {
+            for (Synapse synapse : layer.getSynapses()) {
                 double change = update(synapse);
                 updater.acknowledgeChange(synapse, change, learningRate);
-            });
+            }
         }
     }
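Replacing parallelStream() with a plain loop matters if update or acknowledgeChange mutates shared state, which the timestep-scaled moment estimates above suggest. A generic illustration of the hazard, not code from this repository:

    import java.util.List;

    class Accumulator {
        double total; // plain field, not thread-safe

        // parallelStream: concurrent read-modify-write of 'total' races,
        // so some increments can be silently lost.
        void unsafeSum(List<Double> changes) {
            changes.parallelStream().forEach(c -> total += c);
        }

        // sequential loop: the form this commit switches to; deterministic.
        void safeSum(List<Double> changes) {
            for (double c : changes) total += c;
        }
    }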
@@ -21,10 +21,10 @@ public double update(Synapse synapse) {
     @Override
     public void postIteration(Updater updater, List<Layer> layers) {
         for (Layer layer : layers) {
-            layer.getSynapses().parallelStream().forEach(synapse -> {
+            for (Synapse synapse : layer.getSynapses()) {
                 double change = update(synapse);
                 updater.acknowledgeChange(synapse, change, learningRate);
-            });
+            }
         }
     }
 }
@@ -12,7 +12,6 @@ public class NormalUpdater extends Updater {

     private Synapse[] synapses;
     private Double[] gradients;
-    private int minSynapseIndex;

     @Override
     public void postInitialize() {
