Skip to content

Commit

Permalink
Added missing files
Browse files Browse the repository at this point in the history
  • Loading branch information
xEcho1337 committed Dec 13, 2024
1 parent 918dd6c commit 0f1180d
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/main/java/net/echo/brain4j/training/updater/Updater.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package net.echo.brain4j.training.updater;

import net.echo.brain4j.layer.Layer;
import net.echo.brain4j.structure.Synapse;

import java.util.List;

public abstract class Updater {

public void postInitialize() {
}

public void postIteration(List<Layer> layers, double learningRate) {
}

public void postFit(List<Layer> layers, double learningRate) {
}

public abstract void acknowledgeChange(Synapse synapse, double change, double learningRate);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package net.echo.brain4j.training.updater.impl;

import net.echo.brain4j.layer.Layer;
import net.echo.brain4j.structure.Neuron;
import net.echo.brain4j.structure.Synapse;
import net.echo.brain4j.training.updater.Updater;

import java.util.Arrays;
import java.util.List;

public class NormalUpdater extends Updater {

private Synapse[] synapses;
private Double[] gradients;
private int minSynapseIndex;

@Override
public void postInitialize() {
this.synapses = new Synapse[Synapse.ID_COUNTER];
this.gradients = new Double[Synapse.ID_COUNTER];
Arrays.fill(gradients, 0.0);
}

@Override
public void postFit(List<Layer> layers, double learningRate) {
for (int i = 0; i < gradients.length; i++) {
Synapse synapse = synapses[i];
double gradient = gradients[i];

synapse.setWeight(synapse.getWeight() + learningRate * gradient);
}

for (Layer layer : layers) {
for (Neuron neuron : layer.getNeurons()) {
double deltaBias = learningRate * neuron.getDelta();
neuron.setBias(neuron.getBias() + deltaBias);
}
}

Arrays.fill(gradients, 0.0);
}

@Override
public void acknowledgeChange(Synapse synapse, double change, double learningRate) {
int id = synapse.getSynapseId();

synapses[id] = synapse;
gradients[id] += change;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package net.echo.brain4j.training.updater.impl;

import net.echo.brain4j.layer.Layer;
import net.echo.brain4j.structure.Neuron;
import net.echo.brain4j.structure.Synapse;
import net.echo.brain4j.training.updater.Updater;

import java.util.List;

public class StochasticUpdater extends Updater {

@Override
public void postIteration(List<Layer> layers, double learningRate) {
for (Layer layer : layers) {
for (Neuron neuron : layer.getNeurons()) {
double deltaBias = learningRate * neuron.getDelta();
neuron.setBias(neuron.getBias() + deltaBias);
}
}
}

@Override
public void acknowledgeChange(Synapse synapse, double change, double learningRate) {
synapse.setWeight(synapse.getWeight() + change);
}
}

0 comments on commit 0f1180d

Please sign in to comment.