From 4d80c925dce750691c7df4bc9d83de80ddb8a926 Mon Sep 17 00:00:00 2001
From: Adversing
Date: Fri, 29 Nov 2024 20:33:32 +0100
Subject: [PATCH] Feat: Add WIP NLP implementation (to be fixed)

---
 .../echo/brain4j/activation/Activation.java        |   6 +
 .../echo/brain4j/activation/Activations.java       |   9 +-
 .../activation/impl/LeakyReLUActivation.java       |  17 ++
 .../activation/impl/LinearActivation.java          |  17 ++
 .../activation/impl/ReLUActivation.java            |  17 ++
 .../activation/impl/SigmoidActivation.java         |  21 +++
 .../activation/impl/SoftmaxActivation.java         |  55 +++++++
 .../activation/impl/TanhActivation.java            |  36 ++++
 src/main/java/net/echo/brain4j/layer/Layer.java    |   6 +-
 .../net/echo/brain4j/loss/LossFunctions.java       |  10 +-
 .../loss/impl/CategoricalCrossEntropy.java         |  14 ++
 .../brain4j/model/impl/FeedForwardModel.java       |  17 +-
 .../echo/brain4j/nlp/LabelTransformer.java         |   1 +
 .../net/echo/brain4j/nlp/agents/Agent.java         |  11 ++
 .../agents/attention/AttentionMechanism.java       |   6 +
 .../attention/impl/MultiHeadAttention.java         |  85 ++++++++++
 .../agents/attention/impl/SelfAttention.java       |  68 ++++++++
 .../attention/score/AttentionScorer.java           |  55 +++++++
 .../agents/encoding/PositionalEncoding.java        |  31 ++++
 .../brain4j/nlp/agents/impl/ChatAgent.java         | 155 ++++++++++++++++++
 .../nlp/agents/model/TransformerModel.java         |  21 +++
 .../nlp/token/weight/TokenWeighter.java            |  24 +++
 .../brain4j/training/BackPropagation.java          |  30 ++--
 .../data/nlp/ConversationDataSet.java              |  38 +++++
 src/test/java/ChatAgentTest.java                   |  36 ++++
 25 files changed, 762 insertions(+), 24 deletions(-)
 create mode 100644 src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java
 create mode 100644 src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java
 create mode 100644 src/main/java/net/echo/brain4j/loss/impl/CategoricalCrossEntropy.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/agents/Agent.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/agents/attention/AttentionMechanism.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/agents/attention/impl/MultiHeadAttention.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/agents/attention/impl/SelfAttention.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/agents/attention/score/AttentionScorer.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/agents/encoding/PositionalEncoding.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/agents/impl/ChatAgent.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java
 create mode 100644 src/main/java/net/echo/brain4j/nlp/token/weight/TokenWeighter.java
 create mode 100644 src/main/java/net/echo/brain4j/training/data/nlp/ConversationDataSet.java
 create mode 100644 src/test/java/ChatAgentTest.java

diff --git a/src/main/java/net/echo/brain4j/activation/Activation.java b/src/main/java/net/echo/brain4j/activation/Activation.java
index 74d3978..65bcf08 100644
--- a/src/main/java/net/echo/brain4j/activation/Activation.java
+++ b/src/main/java/net/echo/brain4j/activation/Activation.java
@@ -1,8 +1,14 @@
 package net.echo.brain4j.activation;
 
+import net.echo.brain4j.structure.Neuron;
+
+import java.util.List;
+
 public interface Activation {
 
     double activate(double input);
+    double[] activate(double[] input);
 
     double getDerivative(double input);
+    void apply(List<Neuron> neurons);
 }
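Note: the Activation contract now carries a scalar form, a vector form, and a layer-level apply(). A minimal sketch of a conforming implementation follows; the softsign class is hypothetical (not part of this patch) and uses only the Neuron accessors the patched implementations already rely on:

    package net.echo.brain4j.activation.impl;

    import net.echo.brain4j.activation.Activation;
    import net.echo.brain4j.structure.Neuron;

    import java.util.List;

    // Hypothetical example: softsign f(x) = x / (1 + |x|).
    public class SoftsignActivation implements Activation {

        @Override
        public double activate(double input) {
            return input / (1 + Math.abs(input));
        }

        @Override
        public double[] activate(double[] inputs) {
            double[] result = new double[inputs.length];
            for (int i = 0; i < inputs.length; i++) {
                result[i] = activate(inputs[i]); // element-wise, like Sigmoid/Tanh in this patch
            }
            return result;
        }

        @Override
        public double getDerivative(double input) {
            double d = 1 + Math.abs(input);
            return 1 / (d * d);
        }

        @Override
        public void apply(List<Neuron> neurons) {
            for (Neuron neuron : neurons) {
                neuron.setValue(activate(neuron.getValue() + neuron.getBias()));
            }
        }
    }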
diff --git a/src/main/java/net/echo/brain4j/activation/Activations.java b/src/main/java/net/echo/brain4j/activation/Activations.java
index 2c1a49b..10022d4 100644
--- a/src/main/java/net/echo/brain4j/activation/Activations.java
+++ b/src/main/java/net/echo/brain4j/activation/Activations.java
@@ -1,16 +1,15 @@
 package net.echo.brain4j.activation;
 
-import net.echo.brain4j.activation.impl.LeakyReLUActivation;
-import net.echo.brain4j.activation.impl.LinearActivation;
-import net.echo.brain4j.activation.impl.ReLUActivation;
-import net.echo.brain4j.activation.impl.SigmoidActivation;
+import net.echo.brain4j.activation.impl.*;
 
 public enum Activations {
 
     LINEAR(new LinearActivation()),
     RELU(new ReLUActivation()),
     LEAKY_RELU(new LeakyReLUActivation()),
-    SIGMOID(new SigmoidActivation());
+    SIGMOID(new SigmoidActivation()),
+    SOFTMAX(new SoftmaxActivation()),
+    TANH(new TanhActivation());
 
     private final Activation function;

diff --git a/src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java b/src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java
index f83e47c..e3c2689 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java
@@ -1,6 +1,9 @@
 package net.echo.brain4j.activation.impl;
 
 import net.echo.brain4j.activation.Activation;
+import net.echo.brain4j.structure.Neuron;
+
+import java.util.List;
 
 public class LeakyReLUActivation implements Activation {
 
@@ -9,8 +12,22 @@ public double activate(double input) {
         return Math.max(0.01 * input, input);
     }
 
+    @Override
+    public double[] activate(double[] input) {
+        throw new UnsupportedOperationException("The Leaky ReLU activation function does not support vector inputs");
+    }
+
     @Override
     public double getDerivative(double input) {
         return input > 0 ? 1 : 0.01;
     }
+
+    @Override
+    public void apply(List<Neuron> neurons) {
+        for (Neuron neuron : neurons) {
+            double output = activate(neuron.getValue() + neuron.getBias());
+
+            neuron.setValue(output);
+        }
+    }
 }

diff --git a/src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java b/src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java
index d201705..9bfcc49 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java
@@ -1,6 +1,9 @@
 package net.echo.brain4j.activation.impl;
 
 import net.echo.brain4j.activation.Activation;
+import net.echo.brain4j.structure.Neuron;
+
+import java.util.List;
 
 public class LinearActivation implements Activation {
 
@@ -9,8 +12,22 @@ public double activate(double input) {
         return input;
     }
 
+    @Override
+    public double[] activate(double[] input) {
+        throw new UnsupportedOperationException("The linear activation function does not support vector inputs");
+    }
+
     @Override
     public double getDerivative(double input) {
         return 1;
     }
+
+    @Override
+    public void apply(List<Neuron> neurons) {
+        for (Neuron neuron : neurons) {
+            double output = activate(neuron.getValue() + neuron.getBias());
+
+            neuron.setValue(output);
+        }
+    }
 }
diff --git a/src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java b/src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java
index f735ee7..6ff33c5 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java
@@ -1,6 +1,9 @@
 package net.echo.brain4j.activation.impl;
 
 import net.echo.brain4j.activation.Activation;
+import net.echo.brain4j.structure.Neuron;
+
+import java.util.List;
 
 public class ReLUActivation implements Activation {
 
@@ -9,8 +12,22 @@ public double activate(double input) {
         return Math.max(0, input);
     }
 
+    @Override
+    public double[] activate(double[] input) {
+        throw new UnsupportedOperationException("The ReLU activation function does not support vector inputs");
+    }
+
     @Override
     public double getDerivative(double input) {
         return input > 0 ? 1 : 0;
     }
+
+    @Override
+    public void apply(List<Neuron> neurons) {
+        for (Neuron neuron : neurons) {
+            double output = activate(neuron.getValue() + neuron.getBias());
+
+            neuron.setValue(output);
+        }
+    }
 }

diff --git a/src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java b/src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java
index b54e3f7..d4b483a 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java
@@ -1,6 +1,9 @@
 package net.echo.brain4j.activation.impl;
 
 import net.echo.brain4j.activation.Activation;
+import net.echo.brain4j.structure.Neuron;
+
+import java.util.List;
 
 public class SigmoidActivation implements Activation {
 
@@ -9,8 +12,26 @@ public double activate(double input) {
         return 1 / (1 + Math.exp(-input));
     }
 
+    @Override
+    public double[] activate(double[] input) {
+        double[] result = new double[input.length];
+        for (int i = 0; i < input.length; i++) {
+            result[i] = activate(input[i]);
+        }
+        return result;
+    }
+
     @Override
     public double getDerivative(double input) {
         return activate(input) * (1 - activate(input));
     }
+
+    @Override
+    public void apply(List<Neuron> neurons) {
+        for (Neuron neuron : neurons) {
+            double output = activate(neuron.getValue() + neuron.getBias());
+
+            neuron.setValue(output);
+        }
+    }
 }

diff --git a/src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java b/src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java
new file mode 100644
index 0000000..f8028a2
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java
@@ -0,0 +1,55 @@
+package net.echo.brain4j.activation.impl;
+
+import net.echo.brain4j.activation.Activation;
+import net.echo.brain4j.structure.Neuron;
+
+import java.util.List;
+
+public class SoftmaxActivation implements Activation {
+    @Override
+    public double activate(double input) {
+        throw new UnsupportedOperationException("The softmax activation function requires a whole vector, not a single value");
+    }
+
+    @Override
+    public double[] activate(double[] inputs) {
+        double maxInput = Double.NEGATIVE_INFINITY;
+        for (double input : inputs) {
+            if (input > maxInput) {
+                maxInput = input;
+            }
+        }
+
+        double[] expValues = new double[inputs.length];
+        double sum = 0.0;
+        for (int i = 0; i < inputs.length; i++) {
+            expValues[i] = Math.exp(inputs[i] - maxInput);
+            sum += expValues[i];
+        }
+
+        for (int i = 0; i < expValues.length; i++) {
+            expValues[i] /= sum;
+        }
+
+        return expValues;
+    }
+
+    @Override
+    public double getDerivative(double input) {
+        return input * (1.0 - input);
+    }
+
+    @Override
+    public void apply(List<Neuron> neurons) {
+        double[] values = new double[neurons.size()];
+        for (int i = 0; i < neurons.size(); i++) {
+            values[i] = neurons.get(i).getValue() + neurons.get(i).getBias();
+        }
+
+        double[] activatedValues = activate(values);
+
+        for (int i = 0; i < neurons.size(); i++) {
+            neurons.get(i).setValue(activatedValues[i]);
+        }
+    }
+}
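Note: a worked sketch of the stable softmax above. Subtracting the maximum keeps every exponent at or below zero, so large logits cannot overflow Math.exp. The demo class name is illustrative; only SoftmaxActivation from this patch is assumed:

    import net.echo.brain4j.activation.impl.SoftmaxActivation;

    public class SoftmaxDemo {
        public static void main(String[] args) {
            double[] logits = {2.0, 1.0, 0.1};
            double[] probs = new SoftmaxActivation().activate(logits);
            // probs ≈ {0.659, 0.242, 0.099}; the three values sum to 1.0
            for (double p : probs) {
                System.out.printf("%.3f%n", p);
            }
        }
    }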
diff --git a/src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java b/src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java
new file mode 100644
index 0000000..5502939
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java
@@ -0,0 +1,36 @@
+package net.echo.brain4j.activation.impl;
+
+import net.echo.brain4j.activation.Activation;
+import net.echo.brain4j.structure.Neuron;
+
+import java.util.List;
+
+public class TanhActivation implements Activation {
+
+    @Override
+    public double activate(double input) {
+        return Math.tanh(input);
+    }
+
+    @Override
+    public double[] activate(double[] inputs) {
+        double[] result = new double[inputs.length];
+        for (int i = 0; i < inputs.length; i++) {
+            result[i] = activate(inputs[i]);
+        }
+        return result;
+    }
+
+    @Override
+    public double getDerivative(double input) {
+        return 1.0 - Math.pow(Math.tanh(input), 2);
+    }
+
+    @Override
+    public void apply(List<Neuron> neurons) {
+        for (Neuron neuron : neurons) {
+            double value = neuron.getValue() + neuron.getBias();
+            neuron.setValue(activate(value));
+        }
+    }
+}

diff --git a/src/main/java/net/echo/brain4j/layer/Layer.java b/src/main/java/net/echo/brain4j/layer/Layer.java
index 534273a..fc05057 100644
--- a/src/main/java/net/echo/brain4j/layer/Layer.java
+++ b/src/main/java/net/echo/brain4j/layer/Layer.java
@@ -56,11 +56,7 @@ public Neuron getNeuronAt(int i) {
     public void applyFunction() {
         Activation function = activation.getFunction();
 
-        for (Neuron neuron : neurons) {
-            double output = function.activate(neuron.getValue() + neuron.getBias());
-
-            neuron.setValue(output);
-        }
+        function.apply(neurons);
     }
 
     public int getTotalParams() {
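Note: delegating applyFunction() to the activation matters because softmax is not element-wise; it needs every pre-activation in the layer to compute its normalizing sum, while the other activations keep the old per-neuron loop inside their own apply(). A small illustration of the contrast, assuming only the two activation classes from this patch:

    import net.echo.brain4j.activation.impl.SoftmaxActivation;
    import net.echo.brain4j.activation.impl.TanhActivation;

    import java.util.Arrays;

    public class ApplyContrast {
        public static void main(String[] args) {
            double[] preActivations = {0.3, 1.2, -0.5};
            // Element-wise: each output depends only on its own input.
            double[] tanhOut = new TanhActivation().activate(preActivations);
            // Whole-vector: each output depends on all inputs via the shared sum.
            double[] softmaxOut = new SoftmaxActivation().activate(preActivations);
            System.out.println(Arrays.toString(tanhOut));
            System.out.println(Arrays.toString(softmaxOut));
        }
    }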
diff --git a/src/main/java/net/echo/brain4j/loss/LossFunctions.java b/src/main/java/net/echo/brain4j/loss/LossFunctions.java
index fa83a52..70f492e 100644
--- a/src/main/java/net/echo/brain4j/loss/LossFunctions.java
+++ b/src/main/java/net/echo/brain4j/loss/LossFunctions.java
@@ -1,6 +1,7 @@
 package net.echo.brain4j.loss;
 
 import net.echo.brain4j.loss.impl.BinaryCrossEntropy;
+import net.echo.brain4j.loss.impl.CategoricalCrossEntropy;
 import net.echo.brain4j.loss.impl.CrossEntropy;
 import net.echo.brain4j.loss.impl.MeanSquaredError;
 
@@ -26,7 +27,14 @@ public enum LossFunctions {
      * and the actual distribution of classes. Typically used with models having multiple output neurons
      * and a softmax activation function.
      */
-    CROSS_ENTROPY(new CrossEntropy());
+    CROSS_ENTROPY(new CrossEntropy()),
+
+    /**
+     * Categorical Cross Entropy: a variant of Cross Entropy designed for multi-class classification
+     * with one-hot encoded target labels. It measures the divergence between the predicted
+     * probability distribution and the true class distribution.
+     */
+    CATEGORICAL_CROSS_ENTROPY(new CategoricalCrossEntropy());
 
     private final LossFunction function;

diff --git a/src/main/java/net/echo/brain4j/loss/impl/CategoricalCrossEntropy.java b/src/main/java/net/echo/brain4j/loss/impl/CategoricalCrossEntropy.java
new file mode 100644
index 0000000..98d7a4a
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/loss/impl/CategoricalCrossEntropy.java
@@ -0,0 +1,14 @@
+package net.echo.brain4j.loss.impl;
+
+import net.echo.brain4j.loss.LossFunction;
+
+public class CategoricalCrossEntropy implements LossFunction {
+    @Override
+    public double calculate(double[] expected, double[] actual) {
+        double sum = 0.0;
+        for (int i = 0; i < expected.length; i++) {
+            sum += -expected[i] * Math.log(actual[i] + 1e-15);
+        }
+        return sum;
+    }
+}
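Note: a worked example of the loss above on a one-hot target; the 1e-15 epsilon only matters when a predicted probability is exactly zero. The demo class name is illustrative:

    import net.echo.brain4j.loss.impl.CategoricalCrossEntropy;

    public class CceDemo {
        public static void main(String[] args) {
            double[] expected = {0.0, 1.0, 0.0}; // one-hot label: the true class is index 1
            double[] actual   = {0.2, 0.7, 0.1}; // predicted distribution (e.g. softmax output)
            double loss = new CategoricalCrossEntropy().calculate(expected, actual);
            // Only the true class contributes: loss = -ln(0.7) ≈ 0.357
            System.out.println(loss);
        }
    }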
diff --git a/src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java b/src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java
index 4f12f89..cbadc18 100644
--- a/src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java
+++ b/src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java
@@ -4,6 +4,7 @@
 import com.google.gson.JsonObject;
 import com.google.gson.JsonParser;
 import com.google.gson.reflect.TypeToken;
+import net.echo.brain4j.activation.Activation;
 import net.echo.brain4j.adapters.LayerAdapter;
 import net.echo.brain4j.adapters.OptimizerAdapter;
 import net.echo.brain4j.layer.Layer;
@@ -86,7 +87,17 @@ public void compile(InitializationType type, LossFunctions function, Optimizer o
 
     @Override
     public void fit(DataSet set) {
-        propagation.iterate(set, optimizer.getLearningRate());
+        System.out.println("Processing batch of " + set.getDataRows().size() + " samples");
+        for (DataRow row : set.getDataRows()) {
+            double[] inputs = row.inputs();
+            System.out.println("Input size: " + inputs.length);
+
+            double[] outputs = predict(inputs);
+            System.out.println("Output size: " + outputs.length);
+
+            propagation.iterate(new DataSet(row), optimizer.getLearningRate());
+        }
+        System.out.println("Batch processing complete");
     }
 
     @Override
@@ -139,11 +150,11 @@ public double[] predict(double ... input) {
             Neuron inputNeuron = synapse.getInputNeuron();
             Neuron outputNeuron = synapse.getOutputNeuron();
 
-            // Weighted sum
             outputNeuron.setValue(outputNeuron.getValue() + inputNeuron.getValue() * synapse.getWeight());
         }
 
-        // Apply the activation function
+        Activation activation = layer.getActivation().getFunction();
+        System.out.println("Activation function: " + activation.getClass().getSimpleName());
         nextLayer.applyFunction();
     }

diff --git a/src/main/java/net/echo/brain4j/nlp/LabelTransformer.java b/src/main/java/net/echo/brain4j/nlp/LabelTransformer.java
index 9735c12..5d0bb25 100644
--- a/src/main/java/net/echo/brain4j/nlp/LabelTransformer.java
+++ b/src/main/java/net/echo/brain4j/nlp/LabelTransformer.java
@@ -45,6 +45,7 @@ public double[] encode(String text, int length) {
         return encoded;
     }
 
+    // to be implemented
     public String decode(double[] encoded) {
         StringBuilder decoded = new StringBuilder();
 
diff --git a/src/main/java/net/echo/brain4j/nlp/agents/Agent.java b/src/main/java/net/echo/brain4j/nlp/agents/Agent.java
new file mode 100644
index 0000000..17619e3
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/nlp/agents/Agent.java
@@ -0,0 +1,11 @@
+package net.echo.brain4j.nlp.agents;
+
+import net.echo.brain4j.training.data.DataSet;
+
+public interface Agent {
+    String process(String input);
+    void train(DataSet conversationData);
+    double evaluate(DataSet testData);
+    void save(String path);
+    void load(String path);
+}

diff --git a/src/main/java/net/echo/brain4j/nlp/agents/attention/AttentionMechanism.java b/src/main/java/net/echo/brain4j/nlp/agents/attention/AttentionMechanism.java
new file mode 100644
index 0000000..97e8e66
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/nlp/agents/attention/AttentionMechanism.java
@@ -0,0 +1,6 @@
+package net.echo.brain4j.nlp.agents.attention;
+
+public interface AttentionMechanism {
+    double[] attend(double[] input, String contextKey);
+}
+
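Note: the fit() override above trains row by row, wrapping each DataRow in a single-row DataSet before backpropagating. A sketch of driving it directly, assuming the varargs FeedForwardModel and DataSet constructors this patch already relies on (see TransformerModel and the new DataSet(row) call):

    import net.echo.brain4j.activation.Activations;
    import net.echo.brain4j.layer.impl.DenseLayer;
    import net.echo.brain4j.loss.LossFunctions;
    import net.echo.brain4j.model.impl.FeedForwardModel;
    import net.echo.brain4j.model.initialization.InitializationType;
    import net.echo.brain4j.training.data.DataRow;
    import net.echo.brain4j.training.data.DataSet;
    import net.echo.brain4j.training.optimizers.impl.Adam;

    public class FitDemo {
        public static void main(String[] args) {
            FeedForwardModel model = new FeedForwardModel(
                    new DenseLayer(2, Activations.SIGMOID),
                    new DenseLayer(4, Activations.SIGMOID),
                    new DenseLayer(1, Activations.SIGMOID)
            );
            model.compile(InitializationType.XAVIER, LossFunctions.MEAN_SQUARED_ERROR, new Adam(0.1));

            // XOR as a tiny stand-in data set.
            DataSet xor = new DataSet(
                    new DataRow(new double[]{0, 0}, new double[]{0}),
                    new DataRow(new double[]{0, 1}, new double[]{1}),
                    new DataRow(new double[]{1, 0}, new double[]{1}),
                    new DataRow(new double[]{1, 1}, new double[]{0})
            );
            model.fit(xor); // logs sizes per row, then backpropagates one row at a time
        }
    }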
diff --git a/src/main/java/net/echo/brain4j/nlp/agents/attention/impl/MultiHeadAttention.java b/src/main/java/net/echo/brain4j/nlp/agents/attention/impl/MultiHeadAttention.java
new file mode 100644
index 0000000..6e28271
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/nlp/agents/attention/impl/MultiHeadAttention.java
@@ -0,0 +1,85 @@
+package net.echo.brain4j.nlp.agents.attention.impl;
+
+import net.echo.brain4j.activation.Activations;
+import net.echo.brain4j.layer.Layer;
+import net.echo.brain4j.nlp.agents.attention.AttentionMechanism;
+import net.echo.brain4j.nlp.agents.attention.score.AttentionScorer;
+
+import java.util.Random;
+
+public class MultiHeadAttention extends Layer implements AttentionMechanism {
+    private final int numHeads;
+    private final int headDim;
+    private final AttentionScorer scorer;
+    private final double[][] projectionWeights;
+    private final double[][] outputWeights;
+
+    public MultiHeadAttention(int numHeads, int embeddingDim, double temperature, double topK) {
+        super(embeddingDim, Activations.LINEAR);
+        this.numHeads = numHeads;
+        this.headDim = embeddingDim / numHeads;
+        this.scorer = new AttentionScorer(temperature, topK);
+        this.projectionWeights = new double[numHeads][embeddingDim];
+        this.outputWeights = new double[embeddingDim][embeddingDim];
+
+        initializeWeights();
+    }
+
+    private void initializeWeights() {
+        Random random = new Random();
+        for (int i = 0; i < numHeads; i++) {
+            for (int j = 0; j < projectionWeights[i].length; j++) {
+                projectionWeights[i][j] = random.nextGaussian() * 0.02;
+            }
+        }
+
+        for (int i = 0; i < outputWeights.length; i++) {
+            for (int j = 0; j < outputWeights[0].length; j++) {
+                outputWeights[i][j] = random.nextGaussian() * 0.02;
+            }
+        }
+    }
+
+    @Override
+    public double[] attend(double[] input, String contextKey) {
+        double[][] headOutputs = new double[numHeads][];
+
+        for (int head = 0; head < numHeads; head++) {
+            double[] projectedInput = projectToHead(input, head);
+            headOutputs[head] = scorer.score(projectedInput, projectedInput, contextKey + "_head_" + head);
+        }
+
+        return concatenateAndProject(headOutputs);
+    }
+
+    private double[] projectToHead(double[] input, int head) {
+        double[] projected = new double[headDim];
+        for (int i = 0; i < headDim; i++) {
+            for (int j = 0; j < input.length; j++) {
+                projected[i] += input[j] * projectionWeights[head][j];
+            }
+        }
+        return projected;
+    }
+
+    private double[] concatenateAndProject(double[][] headOutputs) {
+        double[] concatenated = new double[headDim * numHeads];
+        int offset = 0;
+
+        for (double[] headOutput : headOutputs) {
+            System.arraycopy(headOutput, 0, concatenated, offset, headDim);
+            offset += headDim;
+        }
+
+        double[] output = new double[getNeurons().size()];
+        for (int i = 0; i < output.length; i++) {
+            for (int j = 0; j < concatenated.length; j++) {
+                output[i] += concatenated[j] * outputWeights[i][j];
+            }
+        }
+
+        return output;
+    }
+}
+
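Note: each head projects the input into a headDim-sized subspace, scores it independently, and the concatenated head outputs are projected back through outputWeights. A usage sketch; embeddingDim must be divisible by numHeads, and the output length assumes Layer(embeddingDim, ...) creates embeddingDim neurons, as the rest of this patch implies:

    import net.echo.brain4j.nlp.agents.attention.impl.MultiHeadAttention;

    import java.util.Arrays;

    public class MultiHeadDemo {
        public static void main(String[] args) {
            // 4 heads over a 16-dim embedding: each head works in a 4-dim subspace.
            MultiHeadAttention attention = new MultiHeadAttention(4, 16, 0.8, 0.01);

            double[] embedded = new double[16]; // one token's (position-encoded) embedding
            Arrays.fill(embedded, 0.1);

            double[] attended = attention.attend(embedded, "turn-0");
            System.out.println(attended.length); // 16: headDim * numHeads, output-projected
        }
    }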
diff --git a/src/main/java/net/echo/brain4j/nlp/agents/attention/impl/SelfAttention.java b/src/main/java/net/echo/brain4j/nlp/agents/attention/impl/SelfAttention.java
new file mode 100644
index 0000000..0840564
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/nlp/agents/attention/impl/SelfAttention.java
@@ -0,0 +1,68 @@
+package net.echo.brain4j.nlp.agents.attention.impl;
+
+import net.echo.brain4j.activation.Activations;
+import net.echo.brain4j.layer.Layer;
+import net.echo.brain4j.nlp.agents.attention.AttentionMechanism;
+import net.echo.brain4j.nlp.agents.attention.score.AttentionScorer;
+
+import java.util.Random;
+
+public class SelfAttention extends Layer implements AttentionMechanism {
+    private final AttentionScorer scorer;
+    private final int headDim;
+    private final double[][] queryWeights;
+    private final double[][] keyWeights;
+    private final double[][] valueWeights;
+
+    public SelfAttention(int embeddingDim, double temperature, double topK) {
+        super(embeddingDim, Activations.LINEAR);
+        this.scorer = new AttentionScorer(temperature, topK);
+        this.headDim = embeddingDim;
+
+        this.queryWeights = new double[embeddingDim][headDim];
+        this.keyWeights = new double[embeddingDim][headDim];
+        this.valueWeights = new double[embeddingDim][headDim];
+
+        initializeWeights();
+    }
+
+    private void initializeWeights() {
+        Random random = new Random();
+        for (int i = 0; i < headDim; i++) {
+            for (int j = 0; j < headDim; j++) {
+                queryWeights[i][j] = random.nextGaussian() * 0.02;
+                keyWeights[i][j] = random.nextGaussian() * 0.02;
+                valueWeights[i][j] = random.nextGaussian() * 0.02;
+            }
+        }
+    }
+
+    @Override
+    public double[] attend(double[] input, String contextKey) {
+        double[] query = projectVector(input, queryWeights);
+        double[] key = projectVector(input, keyWeights);
+        double[] value = projectVector(input, valueWeights);
+
+        double[] attentionScores = scorer.score(query, key, contextKey);
+        return computeWeightedSum(attentionScores, value);
+    }
+
+    private double[] projectVector(double[] input, double[][] weights) {
+        double[] output = new double[headDim];
+        for (int i = 0; i < headDim; i++) {
+            for (int j = 0; j < input.length; j++) {
+                output[i] += input[j] * weights[j][i];
+            }
+        }
+        return output;
+    }
+
+    private double[] computeWeightedSum(double[] scores, double[] values) {
+        double[] output = new double[values.length];
+        for (int i = 0; i < values.length; i++) {
+            output[i] = scores[i] * values[i];
+        }
+        return output;
+    }
+}
+
diff --git a/src/main/java/net/echo/brain4j/nlp/agents/attention/score/AttentionScorer.java b/src/main/java/net/echo/brain4j/nlp/agents/attention/score/AttentionScorer.java
new file mode 100644
index 0000000..c8871a5
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/nlp/agents/attention/score/AttentionScorer.java
@@ -0,0 +1,55 @@
+package net.echo.brain4j.nlp.agents.attention.score;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class AttentionScorer {
+    private final double temperature;
+    private final double topK;
+    private final Map<String, double[]> attentionCache;
+
+    public AttentionScorer(double temperature, double topK) {
+        this.temperature = temperature;
+        this.topK = topK;
+        this.attentionCache = new HashMap<>();
+    }
+
+    public double[] score(double[] query, double[] key, String contextKey) {
+        double[] cachedScore = attentionCache.get(contextKey);
+        if (cachedScore != null) return cachedScore;
+
+        double[] scores = computeAttentionScores(query, key);
+        attentionCache.put(contextKey, scores);
+        return scores;
+    }
+
+    private double[] computeAttentionScores(double[] query, double[] key) {
+        double[] scores = new double[query.length];
+        double maxScore = Double.NEGATIVE_INFINITY;
+
+        for (int i = 0; i < query.length; i++) {
+            scores[i] = (query[i] * key[i]) / Math.sqrt(query.length);
+            maxScore = Math.max(maxScore, scores[i]);
+        }
+
+        double sum = 0.0;
+        for (int i = 0; i < scores.length; i++) {
+            scores[i] = Math.exp((scores[i] - maxScore) / temperature);
+            sum += scores[i];
+        }
+
+        for (int i = 0; i < scores.length; i++) {
+            scores[i] = scores[i] / sum;
+            if (scores[i] < topK) {
+                scores[i] = 0;
+            }
+        }
+
+        return scores;
+    }
+
+    public void clearCache() {
+        attentionCache.clear();
+    }
+}
+
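Note: the scorer computes a per-dimension scaled dot product, applies a temperature softmax, then zeroes entries below the topK threshold; results are memoized per contextKey until clearCache(). A short sketch (values illustrative):

    import net.echo.brain4j.nlp.agents.attention.score.AttentionScorer;

    import java.util.Arrays;

    public class ScorerDemo {
        public static void main(String[] args) {
            // Temperature 0.6 sharpens the softmax; entries below 0.05 are zeroed.
            AttentionScorer scorer = new AttentionScorer(0.6, 0.05);

            double[] query = {1.0, 0.2, -0.3};
            double[] key   = {0.9, 0.1, -0.2};

            double[] scores = scorer.score(query, key, "msg-0");
            System.out.println(Arrays.toString(scores)); // softmax((q_i * k_i) / sqrt(3) / T)

            // Same contextKey returns the cached result, even for different vectors.
            double[] cached = scorer.score(new double[]{0, 0, 0}, key, "msg-0");
            System.out.println(Arrays.toString(cached));
        }
    }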
diff --git a/src/main/java/net/echo/brain4j/nlp/agents/encoding/PositionalEncoding.java b/src/main/java/net/echo/brain4j/nlp/agents/encoding/PositionalEncoding.java
new file mode 100644
index 0000000..48bade2
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/nlp/agents/encoding/PositionalEncoding.java
@@ -0,0 +1,31 @@
+package net.echo.brain4j.nlp.agents.encoding;
+
+public class PositionalEncoding {
+    private final int maxLength;
+    private final int embeddingDim;
+    private final double[][] encodings;
+
+    public PositionalEncoding(int maxLength, int embeddingDim) {
+        this.maxLength = maxLength;
+        this.embeddingDim = embeddingDim;
+        this.encodings = new double[maxLength][embeddingDim];
+        initializeEncodings();
+    }
+
+    private void initializeEncodings() {
+        for (int pos = 0; pos < maxLength; pos++) {
+            for (int i = 0; i < embeddingDim; i++) {
+                double angle = pos / Math.pow(10000, (2.0 * i) / embeddingDim);
+                encodings[pos][i] = i % 2 == 0 ? Math.sin(angle) : Math.cos(angle);
+            }
+        }
+    }
+
+    public double[] encode(double[] input, int position) {
+        double[] encoded = new double[input.length];
+        for (int i = 0; i < input.length; i++) {
+            encoded[i] = input[i] + encodings[position][i];
+        }
+        return encoded;
+    }
+}
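Note: the table uses the sinusoidal scheme, sin(pos / 10000^(2i/d)) on even dimensions and cos on odd ones. At position 0 that is 0 on even dimensions and 1 on odd ones, which makes the additive effect easy to see:

    import net.echo.brain4j.nlp.agents.encoding.PositionalEncoding;

    import java.util.Arrays;

    public class EncodingDemo {
        public static void main(String[] args) {
            PositionalEncoding encoder = new PositionalEncoding(512, 4);

            double[] token = {0.5, 0.5, 0.5, 0.5};
            // Position 0: sin(0) = 0 on even dims, cos(0) = 1 on odd dims.
            System.out.println(Arrays.toString(encoder.encode(token, 0))); // [0.5, 1.5, 0.5, 1.5]
            // Same token at position 3 gets a different additive signature.
            System.out.println(Arrays.toString(encoder.encode(token, 3)));
        }
    }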
"" : "."); + } + + @Override + public String process(String input) { + updateContext(input); + double[] weightedInput = processInput(input); + double[] response = model.predict(weightedInput); + String output = decodeResponse(response); + updateContext(output); + return output; + } + + @Override + public void train(DataSet conversationData) { + int maxEpochs = 150; // Further reduced for testing + double errorThreshold = 0.001; + + System.out.println("Starting training loop"); + for(int epoch = 0; epoch < maxEpochs; epoch++) { + System.out.printf("Epoch %d/%d%n", epoch + 1, maxEpochs); + model.fit(conversationData); + double error = model.evaluate(conversationData); + if (Double.isNaN(error)) { + throw new RuntimeException("Error is NaN"); + } + + System.out.printf("Error: %.4f%n", error); + if(error < errorThreshold) break; + } + } + + @Override + public double evaluate(DataSet testData) { + return model.evaluate(testData); + } + + @Override + public void save(String path) { + model.save(path); + } + + @Override + public void load(String path) { + model.load(path); + } + + private void updateContext(String text) { + conversationHistory.add(text); + if (conversationHistory.size() > contextWindow) { + conversationHistory.remove(0); + } + } + + private void updateContext(String userInput, String response) { + updateContext("User: " + userInput); + updateContext("Assistant: " + response); + } + + private double[] processInput(String input) { + String[] tokens = input.split("\\s+"); + double[] weighted = new double[contextWindow]; + + for (int i = 0; i < tokens.length && i < contextWindow; i++) { + double weight = weighter.getWeight(tokens[i]); + double[] posEncoded = encoder.encode(new double[]{weight}, i); + weighted[i] = posEncoded[0]; + } + + return weighted; + } + + private String decodeResponse(double[] response) { + StringBuilder output = new StringBuilder(); + for (double value : response) { + int index = (int) Math.round(value); + if (index >= 0 && index < 26) { + output.append((char) (index + 'a')); + } + } + return output.toString(); + } + + public TransformerModel getModel() { + return model; + } +} diff --git a/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java b/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java new file mode 100644 index 0000000..6f63819 --- /dev/null +++ b/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java @@ -0,0 +1,21 @@ +package net.echo.brain4j.nlp.agents.model; + +import net.echo.brain4j.activation.Activations; +import net.echo.brain4j.layer.impl.DenseLayer; +import net.echo.brain4j.layer.impl.DropoutLayer; +import net.echo.brain4j.model.impl.FeedForwardModel; +import net.echo.brain4j.nlp.agents.attention.impl.MultiHeadAttention; + +public class TransformerModel extends FeedForwardModel { + public TransformerModel(int maxSequenceLength, int numHeads, int embeddingDim, double temperature, double topK) { + super( + new DenseLayer(maxSequenceLength, Activations.SIGMOID), + new MultiHeadAttention(numHeads, embeddingDim, temperature, topK), + new DenseLayer(4098, Activations.RELU), + new DropoutLayer(0.1), + new DenseLayer(embeddingDim, Activations.RELU) + ); + } +} + + diff --git a/src/main/java/net/echo/brain4j/nlp/token/weight/TokenWeighter.java b/src/main/java/net/echo/brain4j/nlp/token/weight/TokenWeighter.java new file mode 100644 index 0000000..4e80f4d --- /dev/null +++ b/src/main/java/net/echo/brain4j/nlp/token/weight/TokenWeighter.java @@ -0,0 +1,24 @@ +package net.echo.brain4j.nlp.token.weight; 
diff --git a/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java b/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java
new file mode 100644
index 0000000..6f63819
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/nlp/agents/model/TransformerModel.java
@@ -0,0 +1,21 @@
+package net.echo.brain4j.nlp.agents.model;
+
+import net.echo.brain4j.activation.Activations;
+import net.echo.brain4j.layer.impl.DenseLayer;
+import net.echo.brain4j.layer.impl.DropoutLayer;
+import net.echo.brain4j.model.impl.FeedForwardModel;
+import net.echo.brain4j.nlp.agents.attention.impl.MultiHeadAttention;
+
+public class TransformerModel extends FeedForwardModel {
+    public TransformerModel(int maxSequenceLength, int numHeads, int embeddingDim, double temperature, double topK) {
+        super(
+                new DenseLayer(maxSequenceLength, Activations.SIGMOID),
+                new MultiHeadAttention(numHeads, embeddingDim, temperature, topK),
+                new DenseLayer(4098, Activations.RELU),
+                new DropoutLayer(0.1),
+                new DenseLayer(embeddingDim, Activations.RELU)
+        );
+    }
+}
+
diff --git a/src/main/java/net/echo/brain4j/nlp/token/weight/TokenWeighter.java b/src/main/java/net/echo/brain4j/nlp/token/weight/TokenWeighter.java
new file mode 100644
index 0000000..4e80f4d
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/nlp/token/weight/TokenWeighter.java
@@ -0,0 +1,24 @@
+package net.echo.brain4j.nlp.token.weight;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class TokenWeighter {
+    private final Map<String, Double> vocabulary;
+    private final double smoothingFactor;
+
+    public TokenWeighter(double smoothingFactor) {
+        this.vocabulary = new HashMap<>();
+        this.smoothingFactor = smoothingFactor;
+    }
+
+    // TODO: nothing calls this yet; wire it into training so the vocabulary gets populated
+    public void updateWeights(String token, double frequency) {
+        double idf = Math.log(1 + (1.0 / (frequency + smoothingFactor)));
+        vocabulary.put(token, idf);
+    }
+
+    public double getWeight(String token) {
+        return vocabulary.getOrDefault(token, 1.0);
+    }
+}
\ No newline at end of file
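Note: updateWeights() computes an IDF-style weight, log(1 + 1/(frequency + smoothing)), so rare tokens weigh more than frequent ones and unseen tokens default to 1.0. A worked sketch with illustrative frequencies:

    import net.echo.brain4j.nlp.token.weight.TokenWeighter;

    public class WeighterDemo {
        public static void main(String[] args) {
            TokenWeighter weighter = new TokenWeighter(0.1);

            weighter.updateWeights("the", 0.9);    // frequent: log(1 + 1/1.0)  ≈ 0.693
            weighter.updateWeights("qubit", 0.01); // rare:     log(1 + 1/0.11) ≈ 2.312

            System.out.println(weighter.getWeight("the"));
            System.out.println(weighter.getWeight("qubit"));
            System.out.println(weighter.getWeight("unseen")); // 1.0 default
        }
    }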
diff --git a/src/main/java/net/echo/brain4j/training/BackPropagation.java b/src/main/java/net/echo/brain4j/training/BackPropagation.java
index 21b60dc..0b35d55 100644
--- a/src/main/java/net/echo/brain4j/training/BackPropagation.java
+++ b/src/main/java/net/echo/brain4j/training/BackPropagation.java
@@ -15,6 +15,7 @@ public class BackPropagation {
 
     private final Model model;
     private final Optimizer optimizer;
+    private static final double GRADIENT_CLIP = 5.0;
 
     private int timestep = 0;
 
@@ -23,14 +24,24 @@ public BackPropagation(Model model, Optimizer optimizer) {
         this.optimizer = optimizer;
     }
 
+    private double clipGradient(double gradient) {
+        return Math.max(Math.min(gradient, GRADIENT_CLIP), -GRADIENT_CLIP);
+    }
+
     public void iterate(DataSet dataSet, double learningRate) {
+        System.out.println("Starting backpropagation iteration");
+
         for (DataRow row : dataSet.getDataRows()) {
             double[] inputs = row.inputs();
             double[] targets = row.outputs();
 
+            System.out.println("Forward pass");
             double[] outputs = model.predict(inputs);
 
+            System.out.println("Backward pass");
             backpropagate(targets, outputs, learningRate);
+
+            System.out.println("Iteration complete");
         }
     }
 
@@ -38,7 +49,6 @@ private void backpropagate(double[] targets, double[] outputs, double learningRa
         List<Layer> layers = model.getLayers();
         initialDelta(layers, targets, outputs);
 
-        // Hidden layers error and delta
         for (int l = layers.size() - 2; l > 0; l--) {
             Layer layer = layers.get(l);
 
@@ -50,14 +60,11 @@ private void backpropagate(double[] targets, double[] outputs, double learningRa
 
         for (Neuron neuron : layer.getNeurons()) {
             double output = neuron.getValue();
-
             for (Synapse synapse : neuron.getSynapses()) {
-                double error = synapse.getWeight() * synapse.getOutputNeuron().getDelta();
-
-                double delta = error * layer.getActivation().getFunction().getDerivative(output);
+                double error = clipGradient(synapse.getWeight() * synapse.getOutputNeuron().getDelta());
+                double delta = clipGradient(error * layer.getActivation().getFunction().getDerivative(output));
 
                 neuron.setDelta(delta);
-
-                synapse.setWeight(synapse.getWeight() + delta * synapse.getInputNeuron().getValue());
+                synapse.setWeight(synapse.getWeight() + clipGradient(delta * synapse.getInputNeuron().getValue()));
             }
         }
     }
@@ -70,7 +77,6 @@ private void initialDelta(List<Layer> layers, double[] targets, double[] outputs
 
         for (int i = 0; i < outputLayer.getNeurons().size(); i++) {
             Neuron neuron = outputLayer.getNeuronAt(i);
-
             double output = outputs[i];
 
             double error = targets[i] - output;
@@ -80,17 +86,21 @@ private void initialDelta(List<Layer> layers, double[] targets, double[] outputs
     }
 
     private void updateWeightsAndBiases(List<Layer> layers, double learningRate) {
+        System.out.println("Starting weight updates");
         timestep++;
 
-        layers.parallelStream().forEach(nextLayer -> {
+        for (Layer nextLayer : layers) {
+            System.out.println("Updating synapses for layer");
             for (Synapse synapse : nextLayer.getSynapses()) {
                 optimizer.update(synapse, timestep);
             }
 
+            System.out.println("Updating biases for layer");
             for (Neuron neuron : nextLayer.getNeurons()) {
                 double deltaBias = learningRate * neuron.getDelta();
                 neuron.setBias(neuron.getBias() + deltaBias);
             }
-        });
+        }
+        System.out.println("Weight updates complete");
     }
 }
\ No newline at end of file

diff --git a/src/main/java/net/echo/brain4j/training/data/nlp/ConversationDataSet.java b/src/main/java/net/echo/brain4j/training/data/nlp/ConversationDataSet.java
new file mode 100644
index 0000000..77af888
--- /dev/null
+++ b/src/main/java/net/echo/brain4j/training/data/nlp/ConversationDataSet.java
@@ -0,0 +1,38 @@
+package net.echo.brain4j.training.data.nlp;
+
+import net.echo.brain4j.nlp.LabelTransformer;
+import net.echo.brain4j.training.data.DataRow;
+import net.echo.brain4j.training.data.DataSet;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class ConversationDataSet extends DataSet {
+    private final LabelTransformer transformer;
+    private final int maxLength;
+
+    public ConversationDataSet(int maxLength, LabelTransformer transformer, String... conversations) {
+        super(processConversations(conversations, maxLength, transformer));
+        this.maxLength = maxLength;
+        this.transformer = transformer;
+    }
+
+    private static DataRow[] processConversations(String[] conversations, int maxLength, LabelTransformer transformer) {
+        List<DataRow> rows = new ArrayList<>();
+
+        for (int i = 0; i < conversations.length - 1; i++) {
+            double[] input = transformer.encode(conversations[i], maxLength);
+            double[] output = transformer.encode(conversations[i + 1], maxLength);
+            rows.add(new DataRow(input, output));
+        }
+
+        return rows.toArray(new DataRow[0]);
+    }
+
+    public void addConversation(String input, String output) {
+        double[] encodedInput = transformer.encode(input, maxLength);
+        double[] encodedOutput = transformer.encode(output, maxLength);
+        getDataRows().add(new DataRow(encodedInput, encodedOutput));
+    }
+}
+
diff --git a/src/test/java/ChatAgentTest.java b/src/test/java/ChatAgentTest.java
new file mode 100644
index 0000000..2aa8c0a
--- /dev/null
+++ b/src/test/java/ChatAgentTest.java
@@ -0,0 +1,36 @@
+import net.echo.brain4j.nlp.AlphabetInitialization;
+import net.echo.brain4j.nlp.LabelTransformer;
+import net.echo.brain4j.nlp.agents.attention.AttentionMechanism;
+import net.echo.brain4j.nlp.agents.attention.impl.SelfAttention;
+import net.echo.brain4j.nlp.agents.impl.ChatAgent;
+import net.echo.brain4j.training.data.nlp.ConversationDataSet;
+
+public class ChatAgentTest {
+    public static void main(String[] args) {
+        AttentionMechanism selfAttention = new SelfAttention(512, 0.6, 0.95);
+        ChatAgent agent = new ChatAgent(selfAttention, 512, 512, 0.6, 0.95);
+        LabelTransformer transformer = new LabelTransformer(AlphabetInitialization.NORMAL);
+        ConversationDataSet trainingData = new ConversationDataSet(512, transformer,
+                "Hello, how are you?",
+                "I'm doing great, thanks for asking!",
+                "What's the weather like?",
+                "It's sunny and warm today."
+        );
+
+        System.out.println("Starting training with max 150 epochs...");
+        try {
+            agent.train(trainingData);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+        System.out.println("Training completed");
+
+        String userInput = "Hello, how are you?";
+        System.out.println("\nUser: " + userInput);
+        String response = agent.generateResponse(userInput);
+        System.out.println("ChatBot: " + response);
+        System.out.println(agent.getModel().getStats());
+    }
+}
+
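Note: for reference, the clipGradient() guard added to BackPropagation clamps every error, delta, and weight update into [-GRADIENT_CLIP, GRADIENT_CLIP]. A standalone sketch of the clamp (the class name is illustrative):

    public class ClipSketch {
        public static void main(String[] args) {
            double clip = 5.0; // mirrors GRADIENT_CLIP in BackPropagation
            double[] gradients = {-12.3, -5.0, 0.7, 5.0, 40.0};
            for (double g : gradients) {
                double clipped = Math.max(Math.min(g, clip), -clip);
                System.out.printf("%.1f -> %.1f%n", g, clipped); // -12.3 -> -5.0, 40.0 -> 5.0
            }
        }
    }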