From 5fe709d71db27956b0c5573690f3c4b8e5136cf5 Mon Sep 17 00:00:00 2001 From: xEcho1337 Date: Sat, 30 Nov 2024 01:14:48 +0100 Subject: [PATCH] Removed useless abstraction --- .../java/net/echo/brain4j/model/Model.java | 272 +++++++++++++++- .../brain4j/model/impl/NeuralNetwork.java | 292 ------------------ .../nlp/agents/model/TransformerModel.java | 4 +- .../brain4j/training/BackPropagation.java | 1 - src/test/java/XorTest.java | 5 +- 5 files changed, 263 insertions(+), 311 deletions(-) delete mode 100644 src/main/java/net/echo/brain4j/model/impl/NeuralNetwork.java diff --git a/src/main/java/net/echo/brain4j/model/Model.java b/src/main/java/net/echo/brain4j/model/Model.java index c26190f..453ee81 100644 --- a/src/main/java/net/echo/brain4j/model/Model.java +++ b/src/main/java/net/echo/brain4j/model/Model.java @@ -1,18 +1,81 @@ package net.echo.brain4j.model; +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.google.gson.reflect.TypeToken; +import net.echo.brain4j.adapters.LayerAdapter; +import net.echo.brain4j.adapters.OptimizerAdapter; import net.echo.brain4j.layer.Layer; +import net.echo.brain4j.layer.impl.DenseLayer; +import net.echo.brain4j.layer.impl.DropoutLayer; import net.echo.brain4j.loss.LossFunction; import net.echo.brain4j.loss.LossFunctions; import net.echo.brain4j.model.initialization.InitializationType; +import net.echo.brain4j.structure.Neuron; +import net.echo.brain4j.structure.Synapse; +import net.echo.brain4j.training.BackPropagation; +import net.echo.brain4j.training.data.DataRow; import net.echo.brain4j.training.data.DataSet; import net.echo.brain4j.training.optimizers.Optimizer; +import net.echo.brain4j.training.optimizers.impl.Adam; +import net.echo.brain4j.training.optimizers.impl.SGD; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; /** - * Interface for a neural network model, defining methods for compilation, training, and prediction. + * Represents a generic neural network model. */ -public interface Model { +public class Model { + + private static final OptimizerAdapter OPTIMIZER_ADAPTER = new OptimizerAdapter(); + private static final LayerAdapter LAYER_ADAPTER = new LayerAdapter(); + private static final Gson GSON = new Gson() + .newBuilder() + .setPrettyPrinting() + .excludeFieldsWithoutExposeAnnotation() + .registerTypeAdapter(DenseLayer.class, LAYER_ADAPTER) + .registerTypeAdapter(DropoutLayer.class, LAYER_ADAPTER) + .registerTypeAdapter(Adam.class, OPTIMIZER_ADAPTER) + .registerTypeAdapter(SGD.class, OPTIMIZER_ADAPTER) + .create(); + + private List layers; + private LossFunctions function; + private Optimizer optimizer; + private BackPropagation propagation; + + public Model(Layer... layers) { + this.layers = new ArrayList<>(Arrays.asList(layers)); + } + + private void connect(InitializationType type) { + for (int i = 0; i < layers.size() - 1; i++) { + Layer layer = layers.get(i); + + if (layer instanceof DropoutLayer) continue; + + Layer nextLayer = layers.get(i + 1); + + if (nextLayer instanceof DropoutLayer) { + nextLayer = layers.get(i + 2); + } + + int nIn = layer.getNeurons().size(); + int nOut = nextLayer.getNeurons().size(); + + double bound = type.getInitializer().getBound(nIn, nOut); + + layer.connectAll(nextLayer, bound); + } + } /** * Initializes the model and layers. @@ -21,14 +84,27 @@ public interface Model { * @param function loss function for error assessment * @param optimizer optimization algorithm */ - void compile(InitializationType type, LossFunctions function, Optimizer optimizer); + public void compile(InitializationType type, LossFunctions function, Optimizer optimizer) { + this.function = function; + this.optimizer = optimizer; + this.propagation = new BackPropagation(this, optimizer); + + connect(type); + } /** * Trains the model for one epoch. * * @param set dataset for training */ - void fit(DataSet set); + public void fit(DataSet set) { + for (DataRow row : set.getDataRows()) { + double[] inputs = row.inputs(); + double[] outputs = predict(inputs); + + propagation.iterate(new DataSet(row), optimizer.getLearningRate()); + } + } /** * Evaluates the model on the given dataset. @@ -36,7 +112,20 @@ public interface Model { * @param set dataset for testing * @return the error of the model */ - double evaluate(DataSet set); + public double evaluate(DataSet set) { + double totalError = 0.0; + + for (DataRow row : set.getDataRows()) { + double[] inputs = row.inputs(); + double[] targets = row.outputs(); + + double[] outputs = predict(inputs); + + totalError += function.getFunction().calculate(targets, outputs); + } + + return totalError; + } /** * Predicts output for given input. @@ -44,47 +133,204 @@ public interface Model { * @param input input data * @return predicted outputs */ - double[] predict(double... input); + public double[] predict(double ... input) { + Layer inputLayer = layers.get(0); + + if (input.length != inputLayer.getNeurons().size()) { + throw new IllegalArgumentException("Input size does not match model's input dimension! (Input != Expected) " + + input.length + " != " + inputLayer.getNeurons().size()); + } + + for (Layer layer : layers) { + for (Neuron neuron : layer.getNeurons()) { + neuron.setValue(0); + } + } + + for (int i = 0; i < input.length; i++) { + inputLayer.getNeuronAt(i).setValue(input[i]); + } + + for (int l = 0; l < layers.size() - 1; l++) { + Layer layer = layers.get(l); + + if (layer instanceof DropoutLayer) continue; + + Layer nextLayer = layers.get(l + 1); + + if (nextLayer instanceof DropoutLayer) { + nextLayer = layers.get(l + 2); + } + + for (Synapse synapse : layer.getSynapses()) { + Neuron inputNeuron = synapse.getInputNeuron(); + Neuron outputNeuron = synapse.getOutputNeuron(); + + outputNeuron.setValue(outputNeuron.getValue() + inputNeuron.getValue() * synapse.getWeight()); + } + + nextLayer.applyFunction(); + } + + Layer outputLayer = layers.get(layers.size() - 1); + double[] output = new double[outputLayer.getNeurons().size()]; + + for (int i = 0; i < output.length; i++) { + output[i] = outputLayer.getNeuronAt(i).getValue(); + } + + return output; + } /** * Gets the model's loss function. * * @return loss function */ - LossFunction getLossFunction(); + public LossFunction getLossFunction() { + return function.getFunction(); + } /** * Retrieves layers of the network. * * @return list of layers */ - List getLayers(); + public List getLayers() { + return layers; + } /** * Generates model statistics. * * @return model stats */ - String getStats(); + public String getStats() { + StringBuilder stats = new StringBuilder(); + stats.append(String.format("%-7s %-15s %-10s %-12s\n", "Index", "Layer name", "nIn, nOut", "TotalParams")); + stats.append("================================================\n"); + + int params = 0; + + for (int i = 0; i < layers.size(); i++) { + Layer layer = layers.get(i); + Layer next = layers.get(Math.min(i, layers.size() - 1)); + + if (next instanceof DropoutLayer) { + next = layers.get(Math.min(i + 1, layers.size() - 1)); + } + + String layerType = layer.getClass().getSimpleName(); + + int nIn = layer.getNeurons().size(); + int nOut = i == layers.size() - 1 ? 0 : next.getNeurons().size(); + + int totalParams = layer.getTotalParams(); + + String formatNin = layer instanceof DropoutLayer ? "-" : String.valueOf(nIn); + String formatNout = layer instanceof DropoutLayer ? "-" : String.valueOf(nOut); + + stats.append(String.format("%-7d %-15s %-10s %-12d\n", + i, layerType, formatNin + ", " + formatNout, totalParams)); + + params += totalParams; + } + + stats.append("================================================\n"); + stats.append("Total parameters: ").append(params).append("\n"); + stats.append("================================================\n"); + return stats.toString(); + } /** * Loads a model from a file. * * @param path path to model file */ - void load(String path); + public void load(String path) { + File file = new File(path); + + if (!file.exists()) { + throw new IllegalArgumentException("File does not exist: " + path); + } + + try { + JsonObject parent = JsonParser.parseReader(new FileReader(file)).getAsJsonObject(); + + this.optimizer = GSON.fromJson(parent.get("optimizer"), Optimizer.class); + this.function = LossFunctions.valueOf(parent.get("lossFunction").getAsString()); + + Type listType = new TypeToken>(){}.getType(); + + this.layers = GSON.fromJson(parent.get("layers"), listType); + + connect(InitializationType.NORMAL); + + double[][] weights = GSON.fromJson(parent.get("weights"), double[][].class); + + for (int i = 0; i < weights.length; i++) { + double[] layerWeights = weights[i]; + Layer layer = layers.get(i); + + for (int j = 0; j < layerWeights.length; j++) { + layer.getSynapses().get(j).setWeight(layerWeights[j]); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } /** * Saves the model to a file. * * @param path path to save model */ - void save(String path); + public void save(String path) { + File file = new File(path); + + JsonObject parent = new JsonObject(); + JsonObject optimizerObject = GSON.toJsonTree(optimizer).getAsJsonObject(); + + parent.addProperty("lossFunction", function.name()); + parent.add("optimizer", optimizerObject); + + List layerObjects = new ArrayList<>(); + + for (Layer layer : layers) { + layerObjects.add(GSON.toJsonTree(layer).getAsJsonObject()); + } + + parent.add("layers", GSON.toJsonTree(layerObjects).getAsJsonArray()); + + double[][] weights = new double[layers.size()][]; + + for (int i = 0; i < layers.size(); i++) { + Layer layer = layers.get(i); + weights[i] = new double[layer.getSynapses().size()]; + + for (int j = 0; j < layer.getSynapses().size(); j++) { + Synapse synapse = layer.getSynapses().get(j); + + weights[i][j] = synapse.getWeight(); + } + } + + parent.add("weights", GSON.toJsonTree(weights)); + + try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) { + GSON.toJson(parent, writer); + } catch (Exception e) { + throw new RuntimeException(e); + } + } /** * Adds a layers to the network. * * @param layers the layers to add */ - void add(Layer... layers); -} \ No newline at end of file + public void add(Layer... layers) { + this.layers.addAll(Arrays.asList(layers)); + } +} diff --git a/src/main/java/net/echo/brain4j/model/impl/NeuralNetwork.java b/src/main/java/net/echo/brain4j/model/impl/NeuralNetwork.java deleted file mode 100644 index c1a3e34..0000000 --- a/src/main/java/net/echo/brain4j/model/impl/NeuralNetwork.java +++ /dev/null @@ -1,292 +0,0 @@ -package net.echo.brain4j.model.impl; - -import com.google.gson.Gson; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; -import com.google.gson.reflect.TypeToken; -import net.echo.brain4j.adapters.LayerAdapter; -import net.echo.brain4j.adapters.OptimizerAdapter; -import net.echo.brain4j.layer.Layer; -import net.echo.brain4j.layer.impl.DenseLayer; -import net.echo.brain4j.layer.impl.DropoutLayer; -import net.echo.brain4j.loss.LossFunction; -import net.echo.brain4j.loss.LossFunctions; -import net.echo.brain4j.model.Model; -import net.echo.brain4j.model.initialization.InitializationType; -import net.echo.brain4j.structure.Neuron; -import net.echo.brain4j.structure.Synapse; -import net.echo.brain4j.training.BackPropagation; -import net.echo.brain4j.training.data.DataRow; -import net.echo.brain4j.training.data.DataSet; -import net.echo.brain4j.training.optimizers.Optimizer; -import net.echo.brain4j.training.optimizers.impl.Adam; -import net.echo.brain4j.training.optimizers.impl.SGD; - -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileReader; -import java.io.FileWriter; -import java.lang.reflect.Type; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class NeuralNetwork implements Model { - - private static final OptimizerAdapter OPTIMIZER_ADAPTER = new OptimizerAdapter(); - private static final LayerAdapter LAYER_ADAPTER = new LayerAdapter(); - private static final Gson GSON = new Gson() - .newBuilder() - .setPrettyPrinting() - .excludeFieldsWithoutExposeAnnotation() - .registerTypeAdapter(DenseLayer.class, LAYER_ADAPTER) - .registerTypeAdapter(DropoutLayer.class, LAYER_ADAPTER) - .registerTypeAdapter(Adam.class, OPTIMIZER_ADAPTER) - .registerTypeAdapter(SGD.class, OPTIMIZER_ADAPTER) - .create(); - - private List layers; - private LossFunctions function; - private Optimizer optimizer; - private BackPropagation propagation; - - public NeuralNetwork(Layer... layers) { - this.layers = new ArrayList<>(Arrays.asList(layers)); - } - - private void connect(InitializationType type) { - for (int i = 0; i < layers.size() - 1; i++) { - Layer layer = layers.get(i); - - if (layer instanceof DropoutLayer) continue; - - Layer nextLayer = layers.get(i + 1); - - if (nextLayer instanceof DropoutLayer) { - nextLayer = layers.get(i + 2); - } - - int nIn = layer.getNeurons().size(); - int nOut = nextLayer.getNeurons().size(); - - double bound = type.getInitializer().getBound(nIn, nOut); - - layer.connectAll(nextLayer, bound); - } - } - - @Override - public void compile(InitializationType type, LossFunctions function, Optimizer optimizer) { - this.function = function; - this.optimizer = optimizer; - this.propagation = new BackPropagation(this, optimizer); - - connect(type); - } - - @Override - public void fit(DataSet set) { - for (DataRow row : set.getDataRows()) { - double[] inputs = row.inputs(); - double[] outputs = predict(inputs); - - propagation.iterate(new DataSet(row), optimizer.getLearningRate()); - } - } - - @Override - public double evaluate(DataSet set) { - double totalError = 0.0; - - for (DataRow row : set.getDataRows()) { - double[] inputs = row.inputs(); - double[] targets = row.outputs(); - - double[] outputs = predict(inputs); - - double currentError = function.getFunction().calculate(targets, outputs); - // System.out.println("Currerr: " + currentError); - totalError += currentError; - } - - return totalError; - } - - @Override - public double[] predict(double ... input) { - Layer inputLayer = layers.get(0); - - if (input.length != inputLayer.getNeurons().size()) { - throw new IllegalArgumentException("Input size does not match model's input dimension! (Input != Expected) " + - input.length + " != " + inputLayer.getNeurons().size()); - } - - for (Layer layer : layers) { - for (Neuron neuron : layer.getNeurons()) { - neuron.setValue(0); - } - } - - for (int i = 0; i < input.length; i++) { - inputLayer.getNeuronAt(i).setValue(input[i]); - } - - for (int l = 0; l < layers.size() - 1; l++) { - Layer layer = layers.get(l); - - if (layer instanceof DropoutLayer) continue; - - Layer nextLayer = layers.get(l + 1); - - if (nextLayer instanceof DropoutLayer) { - nextLayer = layers.get(l + 2); - } - - for (Synapse synapse : layer.getSynapses()) { - Neuron inputNeuron = synapse.getInputNeuron(); - Neuron outputNeuron = synapse.getOutputNeuron(); - - outputNeuron.setValue(outputNeuron.getValue() + inputNeuron.getValue() * synapse.getWeight()); - } - - nextLayer.applyFunction(); - } - - Layer outputLayer = layers.get(layers.size() - 1); - double[] output = new double[outputLayer.getNeurons().size()]; - - for (int i = 0; i < output.length; i++) { - output[i] = outputLayer.getNeuronAt(i).getValue(); - } - - return output; - } - - @Override - public LossFunction getLossFunction() { - return function.getFunction(); - } - - @Override - public List getLayers() { - return layers; - } - - @Override - public String getStats() { - StringBuilder stats = new StringBuilder(); - stats.append(String.format("%-7s %-15s %-10s %-12s\n", "Index", "Layer name", "nIn, nOut", "TotalParams")); - stats.append("================================================\n"); - - int params = 0; - - for (int i = 0; i < layers.size(); i++) { - Layer layer = layers.get(i); - Layer next = layers.get(Math.min(i, layers.size() - 1)); - - if (next instanceof DropoutLayer) { - next = layers.get(Math.min(i + 1, layers.size() - 1)); - } - - String layerType = layer.getClass().getSimpleName(); - - int nIn = layer.getNeurons().size(); - int nOut = i == layers.size() - 1 ? 0 : next.getNeurons().size(); - - int totalParams = layer.getTotalParams(); - - String formatNin = layer instanceof DropoutLayer ? "-" : String.valueOf(nIn); - String formatNout = layer instanceof DropoutLayer ? "-" : String.valueOf(nOut); - - stats.append(String.format("%-7d %-15s %-10s %-12d\n", - i, layerType, formatNin + ", " + formatNout, totalParams)); - - params += totalParams; - } - - stats.append("================================================\n"); - stats.append("Total parameters: ").append(params).append("\n"); - stats.append("================================================\n"); - return stats.toString(); - } - - @Override - public void load(String path) { - File file = new File(path); - - if (!file.exists()) { - throw new IllegalArgumentException("File does not exist: " + path); - } - - try { - JsonObject parent = JsonParser.parseReader(new FileReader(file)).getAsJsonObject(); - - this.optimizer = GSON.fromJson(parent.get("optimizer"), Optimizer.class); - this.function = LossFunctions.valueOf(parent.get("lossFunction").getAsString()); - - Type listType = new TypeToken>(){}.getType(); - - this.layers = GSON.fromJson(parent.get("layers"), listType); - - connect(InitializationType.NORMAL); - - double[][] weights = GSON.fromJson(parent.get("weights"), double[][].class); - - for (int i = 0; i < weights.length; i++) { - double[] layerWeights = weights[i]; - Layer layer = layers.get(i); - - for (int j = 0; j < layerWeights.length; j++) { - layer.getSynapses().get(j).setWeight(layerWeights[j]); - } - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void save(String path) { - File file = new File(path); - - JsonObject parent = new JsonObject(); - JsonObject optimizerObject = GSON.toJsonTree(optimizer).getAsJsonObject(); - - parent.addProperty("lossFunction", function.name()); - parent.add("optimizer", optimizerObject); - - List layerObjects = new ArrayList<>(); - - for (Layer layer : layers) { - layerObjects.add(GSON.toJsonTree(layer).getAsJsonObject()); - } - - parent.add("layers", GSON.toJsonTree(layerObjects).getAsJsonArray()); - - double[][] weights = new double[layers.size()][]; - - for (int i = 0; i < layers.size(); i++) { - Layer layer = layers.get(i); - weights[i] = new double[layer.getSynapses().size()]; - - for (int j = 0; j < layer.getSynapses().size(); j++) { - Synapse synapse = layer.getSynapses().get(j); - - weights[i][j] = synapse.getWeight(); - } - } - - parent.add("weights", GSON.toJsonTree(weights)); - - try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) { - GSON.toJson(parent, writer); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void add(Layer... layers) { - this.layers.addAll(Arrays.asList(layers)); - } -} diff --git a/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java b/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java index 1c949f2..0e7718d 100644 --- a/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java +++ b/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java @@ -3,10 +3,10 @@ import net.echo.brain4j.activation.Activations; import net.echo.brain4j.layer.impl.DenseLayer; import net.echo.brain4j.layer.impl.DropoutLayer; -import net.echo.brain4j.model.impl.NeuralNetwork; +import net.echo.brain4j.model.Model; import net.echo.brain4j.nlp.agents.attention.impl.MultiHeadAttention; -public class TransformerModel extends NeuralNetwork { +public class TransformerModel extends Model { public TransformerModel(int maxSequenceLength, int numHeads, int embeddingDim, double temperature, double topK) { super(); for (int i = 0; i < 3; i++) { diff --git a/src/main/java/net/echo/brain4j/training/BackPropagation.java b/src/main/java/net/echo/brain4j/training/BackPropagation.java index de4bf4c..1d11c7d 100644 --- a/src/main/java/net/echo/brain4j/training/BackPropagation.java +++ b/src/main/java/net/echo/brain4j/training/BackPropagation.java @@ -2,7 +2,6 @@ import net.echo.brain4j.layer.Layer; import net.echo.brain4j.layer.impl.DropoutLayer; -import net.echo.brain4j.model.Model; import net.echo.brain4j.structure.Neuron; import net.echo.brain4j.structure.Synapse; import net.echo.brain4j.training.data.DataRow; diff --git a/src/test/java/XorTest.java b/src/test/java/XorTest.java index 1d60f16..0128758 100644 --- a/src/test/java/XorTest.java +++ b/src/test/java/XorTest.java @@ -9,7 +9,6 @@ import net.echo.brain4j.layer.impl.DenseLayer; import net.echo.brain4j.loss.LossFunctions; import net.echo.brain4j.model.Model; -import net.echo.brain4j.model.impl.NeuralNetwork; import net.echo.brain4j.training.optimizers.impl.Adam; import java.util.Arrays; @@ -17,7 +16,7 @@ public class XorTest { public static void main(String[] args) { - Model convolutionalModel = new NeuralNetwork( + Model convolutionalModel = new Model( new InputLayer(24, 24), new ConvolutionalLayer(24, 24, 32, Activations.RELU), new PoolingLayer(PoolingType.MAX, 2, 2, 2, 0), @@ -27,7 +26,7 @@ public static void main(String[] args) { new DenseLayer(1, Activations.SIGMOID) ); - Model network = new NeuralNetwork( + Model network = new Model( new DenseLayer(2, Activations.LINEAR), new DenseLayer(32, Activations.RELU), new DenseLayer(32, Activations.RELU),