
Commit

Small performance improvement
xEcho1337 committed Dec 9, 2024
1 parent dfe2de3 commit b6eb524
Showing 2 changed files with 11 additions and 9 deletions.
@@ -67,15 +67,16 @@ public void postIteration(List<Layer> layers) {
         this.beta1Timestep = Math.pow(beta1, timestep);
         this.beta2Timestep = Math.pow(beta2, timestep);
 
-        for (Layer layer : layers) {
-            // 30% improvement using parallel stream. TODO: Implement GPU support for better parallelization
-            layer.getSynapses().parallelStream().forEach(this::update);
+        layers.parallelStream().forEach(layer -> {
+            for (Synapse synapse : layer.getSynapses()) {
+                update(synapse);
+            }
 
             for (Neuron neuron : layer.getNeurons()) {
                 double deltaBias = learningRate * neuron.getDelta();
                 neuron.setBias(neuron.getBias() + deltaBias);
             }
-        }
+        });
     }
 
     public double getBeta1() {
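The change above moves the parallel stream from each layer's synapse list to the list of layers itself, so a layer's synapse updates and bias updates now run together inside one parallel task instead of spawning a separate parallel pass per layer. The sketch below is a minimal, self-contained before/after comparison of that pattern; Layer, Synapse, Neuron, and update() here are simplified stand-ins invented for illustration, not the repository's actual classes, and the layer sizes in main() are arbitrary.

import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Simplified stand-ins for the repository's Layer/Synapse/Neuron types (illustration only).
class Synapse {
    double weight = ThreadLocalRandom.current().nextDouble(-1, 1);
    double gradient = ThreadLocalRandom.current().nextDouble(-1, 1);
}

class Neuron {
    private double bias;
    private final double delta = ThreadLocalRandom.current().nextDouble(-1, 1);
    double getDelta() { return delta; }
    double getBias() { return bias; }
    void setBias(double bias) { this.bias = bias; }
}

class Layer {
    private final List<Synapse> synapses;
    private final List<Neuron> neurons;

    Layer(int synapseCount, int neuronCount) {
        synapses = IntStream.range(0, synapseCount).mapToObj(i -> new Synapse()).collect(Collectors.toList());
        neurons = IntStream.range(0, neuronCount).mapToObj(i -> new Neuron()).collect(Collectors.toList());
    }

    List<Synapse> getSynapses() { return synapses; }
    List<Neuron> getNeurons() { return neurons; }
}

public class ParallelUpdateSketch {
    static final double LEARNING_RATE = 0.01;

    // Placeholder for the optimizer's real per-synapse step (e.g. an Adam-style update).
    static void update(Synapse synapse) {
        synapse.weight -= LEARNING_RATE * synapse.gradient;
    }

    // Before the commit: one parallel stream per layer, over that layer's synapses only.
    static void postIterationPerSynapse(List<Layer> layers) {
        for (Layer layer : layers) {
            layer.getSynapses().parallelStream().forEach(ParallelUpdateSketch::update);

            for (Neuron neuron : layer.getNeurons()) {
                neuron.setBias(neuron.getBias() + LEARNING_RATE * neuron.getDelta());
            }
        }
    }

    // After the commit: a single parallel stream over the layers; each layer's synapse
    // loop and bias updates run inside the same parallel task.
    static void postIterationPerLayer(List<Layer> layers) {
        layers.parallelStream().forEach(layer -> {
            for (Synapse synapse : layer.getSynapses()) {
                update(synapse);
            }

            for (Neuron neuron : layer.getNeurons()) {
                neuron.setBias(neuron.getBias() + LEARNING_RATE * neuron.getDelta());
            }
        });
    }

    public static void main(String[] args) {
        List<Layer> layers = IntStream.range(0, 8)
                .mapToObj(i -> new Layer(256 * 256, 256))
                .collect(Collectors.toList());

        long start = System.nanoTime();
        postIterationPerLayer(layers);
        System.out.printf("Per-layer parallel update took %.2f ms%n", (System.nanoTime() - start) / 1e6);

        start = System.nanoTime();
        postIterationPerSynapse(layers);
        System.out.printf("Per-synapse parallel update took %.2f ms%n", (System.nanoTime() - start) / 1e6);
    }
}

One plausible reason a layer-level stream helps is that it submits fewer, larger tasks to the common fork/join pool; the actual gain will depend on the number and size of the layers.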
11 changes: 6 additions & 5 deletions src/test/java/XorTest.java
@@ -13,9 +13,9 @@ public class XorTest {
     public static void main(String[] args) {
         Model network = new Model(
                 new DenseLayer(2, Activations.LINEAR),
-                new DenseLayer(32, Activations.RELU),
-                new DenseLayer(32, Activations.RELU),
-                new DenseLayer(32, Activations.RELU),
+                new DenseLayer(256, Activations.RELU),
+                new DenseLayer(256, Activations.RELU),
+                new DenseLayer(256, Activations.RELU),
                 new DenseLayer(1, Activations.SIGMOID)
         );
 
@@ -32,14 +32,15 @@ public static void main(String[] args) {
 
         long start = System.nanoTime();
 
-        for (int i = 0; i < 1000; i++) {
+        int epoches = 500;
+        for (int i = 0; i < epoches; i++) {
             network.fit(training, 1);
         }
 
         double error = network.evaluate(training);
         double took = (System.nanoTime() - start) / 1e6;
 
-        System.out.println("Took " + took + " ms with an average of " + (took / 1000) + " ms per epoch and error " + error);
+        System.out.println("Took " + took + " ms with an average of " + (took / epoches) + " ms per epoch and error " + error);
 
         for (DataRow row : training.getDataRows()) {
             Vector output = network.predict(row.inputs());
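Pulling the epoch count into a variable keeps the reported average honest: with the loop now running 500 iterations, the old hard-coded division by 1000 would have halved the apparent per-epoch time. The snippet below is a generic sketch of the same measure-then-average pattern, written without the library's API; EpochTimingSketch, averageMsPerEpoch, and the dummy workload are illustrative names only (in the test, the workload slot is network.fit(training, 1)).

import java.util.function.IntConsumer;

public class EpochTimingSketch {
    static volatile double sink; // keeps the stand-in workload from being optimized away

    // Generic illustration of the timing pattern used in XorTest: measure total wall
    // time around the training loop, then divide by the epoch count for the average.
    static double averageMsPerEpoch(int epoches, IntConsumer trainOneEpoch) {
        long start = System.nanoTime();
        for (int i = 0; i < epoches; i++) {
            trainOneEpoch.accept(i);
        }
        double took = (System.nanoTime() - start) / 1e6;
        return took / epoches;
    }

    public static void main(String[] args) {
        // Stand-in workload; in the test this slot is network.fit(training, 1).
        double avg = averageMsPerEpoch(500, epoch -> {
            double x = 0;
            for (int i = 0; i < 1_000_000; i++) {
                x += Math.sqrt(i);
            }
            sink = x;
        });
        System.out.println("Average per epoch: " + avg + " ms");
    }
}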
