Merge pull request #1 from Adversing/feature/nlp-adversing
Feat: Added WIP NLP actual implementation (to-be-fixed)
xEcho1337 authored Nov 29, 2024
2 parents 7ebc187 + 4d80c92 commit e77e194
Showing 25 changed files with 762 additions and 24 deletions.
6 changes: 6 additions & 0 deletions src/main/java/net/echo/brain4j/activation/Activation.java
@@ -1,8 +1,14 @@
package net.echo.brain4j.activation;

import net.echo.brain4j.structure.Neuron;

import java.util.List;

public interface Activation {

double activate(double input);
double[] activate(double[] input);

double getDerivative(double input);
void apply(List<Neuron> neurons);
}
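
For reference, a minimal sketch of the two call paths the widened interface now exposes, using the SigmoidActivation and SoftmaxActivation implementations added later in this commit (input values are arbitrary, illustration only):

import java.util.Arrays;

import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.activation.impl.SigmoidActivation;
import net.echo.brain4j.activation.impl.SoftmaxActivation;

// Illustration only, not part of the commit.
public class ActivationPathsExample {

    public static void main(String[] args) {
        Activation sigmoid = new SigmoidActivation();
        double value = sigmoid.activate(0.3);       // element-wise path
        double slope = sigmoid.getDerivative(0.3);  // slope used during backpropagation

        Activation softmax = new SoftmaxActivation();
        double[] probs = softmax.activate(new double[]{1.0, 2.0, 0.5}); // whole-vector path

        System.out.println(value + " " + slope + " " + Arrays.toString(probs));
        // softmax.activate(0.3) would throw UnsupportedOperationException,
        // since softmax is only defined over a full vector of pre-activations.
    }
}
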
9 changes: 4 additions & 5 deletions src/main/java/net/echo/brain4j/activation/Activations.java
@@ -1,16 +1,15 @@
package net.echo.brain4j.activation;

import net.echo.brain4j.activation.impl.LeakyReLUActivation;
import net.echo.brain4j.activation.impl.LinearActivation;
import net.echo.brain4j.activation.impl.ReLUActivation;
import net.echo.brain4j.activation.impl.SigmoidActivation;
import net.echo.brain4j.activation.impl.*;

public enum Activations {

LINEAR(new LinearActivation()),
RELU(new ReLUActivation()),
LEAKY_RELU(new LeakyReLUActivation()),
SIGMOID(new SigmoidActivation());
SIGMOID(new SigmoidActivation()),
SOFTMAX(new SoftmaxActivation()),
TANH(new TanhActivation());

private final Activation function;

src/main/java/net/echo/brain4j/activation/impl/LeakyReLUActivation.java
@@ -1,6 +1,9 @@
package net.echo.brain4j.activation.impl;

import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.structure.Neuron;

import java.util.List;

public class LeakyReLUActivation implements Activation {

@@ -9,8 +12,22 @@ public double activate(double input) {
return Math.max(0.01 * input, input);
}

@Override
public double[] activate(double[] input) {
throw new UnsupportedOperationException("Leaky ReLU activation function is not supported for multiple inputs");
}

@Override
public double getDerivative(double input) {
return input > 0 ? 1 : 0.01;
}

@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double output = activate(neuron.getValue() + neuron.getBias());

neuron.setValue(output);
}
}
}
src/main/java/net/echo/brain4j/activation/impl/LinearActivation.java
@@ -1,6 +1,9 @@
package net.echo.brain4j.activation.impl;

import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.structure.Neuron;

import java.util.List;

public class LinearActivation implements Activation {

@@ -9,8 +12,22 @@ public double activate(double input) {
return input;
}

@Override
public double[] activate(double[] input) {
throw new UnsupportedOperationException("Linear activation function is not supported for multiple inputs");
}

@Override
public double getDerivative(double input) {
return 1;
}

@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double output = activate(neuron.getValue() + neuron.getBias());

neuron.setValue(output);
}
}
}
17 changes: 17 additions & 0 deletions src/main/java/net/echo/brain4j/activation/impl/ReLUActivation.java
@@ -1,6 +1,9 @@
package net.echo.brain4j.activation.impl;

import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.structure.Neuron;

import java.util.List;

public class ReLUActivation implements Activation {

@@ -9,8 +12,22 @@ public double activate(double input) {
return Math.max(0, input);
}

@Override
public double[] activate(double[] input) {
throw new UnsupportedOperationException("ReLU activation function is not supported for multiple inputs");
}

@Override
public double getDerivative(double input) {
return input > 0 ? 1 : 0;
}

@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double output = activate(neuron.getValue() + neuron.getBias());

neuron.setValue(output);
}
}
}
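
Reviewer note on the three element-wise activations above (Leaky ReLU, Linear, ReLU): unlike SigmoidActivation below, their activate(double[]) overloads throw instead of mapping the scalar form over the array. Should the vector path ever be needed for them, the obvious default would be (sketch, not part of the commit):

    @Override
    public double[] activate(double[] input) {
        double[] result = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            result[i] = activate(input[i]); // reuse the element-wise definition
        }
        return result;
    }
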
src/main/java/net/echo/brain4j/activation/impl/SigmoidActivation.java
@@ -1,6 +1,9 @@
package net.echo.brain4j.activation.impl;

import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.structure.Neuron;

import java.util.List;

public class SigmoidActivation implements Activation {

@@ -9,8 +12,26 @@ public double activate(double input) {
return 1 / (1 + Math.exp(-input));
}

@Override
public double[] activate(double[] input) {
double[] result = new double[input.length];
for (int i = 0; i < input.length; i++) {
result[i] = activate(input[i]);
}
return result;
}

@Override
public double getDerivative(double input) {
return activate(input) * (1 - activate(input));
}

@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double output = activate(neuron.getValue() + neuron.getBias());

neuron.setValue(output);
}
}
}
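
Minor note on getDerivative above: it evaluates the sigmoid twice per call. The usual single-evaluation form is mathematically identical (sketch, not part of the commit):

    @Override
    public double getDerivative(double input) {
        double s = activate(input); // sigma(x)
        return s * (1 - s);         // sigma'(x) = sigma(x) * (1 - sigma(x))
    }
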
src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java
@@ -0,0 +1,55 @@
package net.echo.brain4j.activation.impl;

import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.structure.Neuron;

import java.util.List;

public class SoftmaxActivation implements Activation {
@Override
public double activate(double input) {
throw new UnsupportedOperationException("Softmax activation function is not supported for single value");
}

@Override
public double[] activate(double[] inputs) {
double maxInput = Double.NEGATIVE_INFINITY;
for (double input : inputs) {
if (input > maxInput) {
maxInput = input;
}
}

double[] expValues = new double[inputs.length];
double sum = 0.0;
for (int i = 0; i < inputs.length; i++) {
expValues[i] = Math.exp(inputs[i] - maxInput);
sum += expValues[i];
}

for (int i = 0; i < expValues.length; i++) {
expValues[i] /= sum;
}

return expValues;
}

@Override
public double getDerivative(double input) {
return input * (1.0 - input);
}

@Override
public void apply(List<Neuron> neurons) {
double[] values = new double[neurons.size()];
for (int i = 0; i < neurons.size(); i++) {
values[i] = neurons.get(i).getValue() + neurons.get(i).getBias();
}

double[] activatedValues = activate(values);

for (int i = 0; i < neurons.size(); i++) {
neurons.get(i).setValue(activatedValues[i]);
}
}
}
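
A few notes on the new SoftmaxActivation. The max-subtraction in activate(double[]) is the standard numerical-stability trick: softmax(z)_i = exp(z_i - max(z)) / sum_k exp(z_k - max(z)), which is mathematically identical to the unshifted form but keeps Math.exp from overflowing on large inputs. getDerivative returns input * (1 - input), which matches the diagonal of the softmax Jacobian only when input is an already-activated output probability rather than a raw pre-activation, so callers need to pass the activated value. Finally, the layer-wide apply is the reason Layer.applyFunction below now delegates to the activation instead of looping over neurons itself. A small worked check (illustration only):

import java.util.Arrays;

import net.echo.brain4j.activation.impl.SoftmaxActivation;

// Worked check, illustration only.
public class SoftmaxCheck {

    public static void main(String[] args) {
        double[] probs = new SoftmaxActivation().activate(new double[]{1.0, 2.0, 3.0});
        System.out.println(Arrays.toString(probs));
        // ≈ [0.0900, 0.2447, 0.6652]; the entries sum to 1.0, and shifting every
        // input by the same constant leaves the result unchanged.
    }
}
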
36 changes: 36 additions & 0 deletions src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java
@@ -0,0 +1,36 @@
package net.echo.brain4j.activation.impl;

import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.structure.Neuron;

import java.util.List;

public class TanhActivation implements Activation {

@Override
public double activate(double input) {
return Math.tanh(input);
}

@Override
public double[] activate(double[] inputs) {
double[] result = new double[inputs.length];
for (int i = 0; i < inputs.length; i++) {
result[i] = activate(inputs[i]);
}
return result;
}

@Override
public double getDerivative(double input) {
return 1.0 - Math.pow(Math.tanh(input), 2);
}

@Override
public void apply(List<Neuron> neurons) {
for (Neuron neuron : neurons) {
double value = neuron.getValue() + neuron.getBias();
neuron.setValue(activate(value));
}
}
}
6 changes: 1 addition & 5 deletions src/main/java/net/echo/brain4j/layer/Layer.java
@@ -56,11 +56,7 @@ public Neuron getNeuronAt(int i) {
public void applyFunction() {
Activation function = activation.getFunction();

for (Neuron neuron : neurons) {
double output = function.activate(neuron.getValue() + neuron.getBias());

neuron.setValue(output);
}
function.apply(neurons);
}

public int getTotalParams() {
10 changes: 9 additions & 1 deletion src/main/java/net/echo/brain4j/loss/LossFunctions.java
@@ -1,6 +1,7 @@
package net.echo.brain4j.loss;

import net.echo.brain4j.loss.impl.BinaryCrossEntropy;
import net.echo.brain4j.loss.impl.CategoricalCrossEntropy;
import net.echo.brain4j.loss.impl.CrossEntropy;
import net.echo.brain4j.loss.impl.MeanSquaredError;

@@ -26,7 +27,14 @@ public enum LossFunctions {
* and the actual distribution of classes. Typically used with models having multiple output neurons
* and a softmax activation function.
*/
CROSS_ENTROPY(new CrossEntropy());
CROSS_ENTROPY(new CrossEntropy()),

/**
* Categorical Cross Entropy: A variant of Cross Entropy specifically designed for multi-class classification
* tasks. It is used when the target labels are one-hot encoded. It calculates the divergence between
* the predicted probability distribution and the actual distribution of classes.
*/
CATEGORICAL_CROSS_ENTROPY(new CategoricalCrossEntropy());

private final LossFunction function;

src/main/java/net/echo/brain4j/loss/impl/CategoricalCrossEntropy.java
@@ -0,0 +1,14 @@
package net.echo.brain4j.loss.impl;

import net.echo.brain4j.loss.LossFunction;

public class CategoricalCrossEntropy implements LossFunction {
@Override
public double calculate(double[] expected, double[] actual) {
double sum = 0.0;
for (int i = 0; i < expected.length; i++) {
sum += -expected[i] * Math.log(actual[i] + 1e-15);
}
return sum;
}
}
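
The new loss computes L = -sum_i expected[i] * ln(actual[i] + 1e-15); the 1e-15 guard keeps Math.log finite when a predicted probability is exactly zero. With one-hot targets this reduces to the negative log of the probability assigned to the true class. A small worked example (illustration only):

import net.echo.brain4j.loss.impl.CategoricalCrossEntropy;

// Worked example, illustration only: one-hot target on class 1.
public class CategoricalCrossEntropyCheck {

    public static void main(String[] args) {
        double loss = new CategoricalCrossEntropy().calculate(
                new double[]{0.0, 1.0, 0.0},  // expected, one-hot
                new double[]{0.2, 0.7, 0.1}); // predicted probabilities
        System.out.println(loss); // -ln(0.7 + 1e-15) ≈ 0.3567
    }
}
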
17 changes: 14 additions & 3 deletions src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java
@@ -4,6 +4,7 @@
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.reflect.TypeToken;
import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.adapters.LayerAdapter;
import net.echo.brain4j.adapters.OptimizerAdapter;
import net.echo.brain4j.layer.Layer;
@@ -86,7 +87,17 @@ public void compile(InitializationType type, LossFunctions function, Optimizer o

@Override
public void fit(DataSet set) {
propagation.iterate(set, optimizer.getLearningRate());
System.out.println("Processing batch of " + set.getDataRows().size() + " samples");
for (DataRow row : set.getDataRows()) {
double[] inputs = row.inputs();
System.out.println("Input size: " + inputs.length);

double[] outputs = predict(inputs);
System.out.println("Output size: " + outputs.length);

propagation.iterate(new DataSet(row), optimizer.getLearningRate());
}
System.out.println("Batch processing complete");
}

@Override
@@ -139,11 +150,11 @@ public double[] predict(double ... input) {
Neuron inputNeuron = synapse.getInputNeuron();
Neuron outputNeuron = synapse.getOutputNeuron();

// Weighted sum
outputNeuron.setValue(outputNeuron.getValue() + inputNeuron.getValue() * synapse.getWeight());
}

// Apply the activation function
Activation activation = layer.getActivation().getFunction();
System.out.println("Activation function: " + activation.getClass().getSimpleName());
nextLayer.applyFunction();
}

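
Two remarks on the hunks above. The rewritten fit wraps each row in its own single-row DataSet and iterates per sample rather than handing the whole batch to propagation.iterate at once, and the System.out.println calls read as temporary debug instrumentation, consistent with the "to-be-fixed" commit message. A sketch of the same loop with the prints stripped, using only identifiers that already appear in the diff (not part of the commit):

    @Override
    public void fit(DataSet set) {
        for (DataRow row : set.getDataRows()) {
            // one optimisation step per sample, via the single-row DataSet constructor
            propagation.iterate(new DataSet(row), optimizer.getLearningRate());
        }
    }
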
1 change: 1 addition & 0 deletions src/main/java/net/echo/brain4j/nlp/LabelTransformer.java
@@ -45,6 +45,7 @@ public double[] encode(String text, int length) {
return encoded;
}

// to be implemented
public String decode(double[] encoded) {
StringBuilder decoded = new StringBuilder();

11 changes: 11 additions & 0 deletions src/main/java/net/echo/brain4j/nlp/agents/Agent.java
@@ -0,0 +1,11 @@
package net.echo.brain4j.nlp.agents;

import net.echo.brain4j.training.data.DataSet;

public interface Agent {
String process(String input);
void train(DataSet conversationData);
double evaluate(DataSet testData);
void save(String path);
void load(String path);
}
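
The commit only defines the contract here; an entirely hypothetical stub like the one below (every name and method body is a placeholder, nothing is implied by the commit) shows the life cycle an implementation is expected to cover:

import net.echo.brain4j.nlp.agents.Agent;
import net.echo.brain4j.training.data.DataSet;

// Hypothetical stub, illustration only.
public class EchoAgent implements Agent {

    @Override
    public String process(String input) {
        return input; // a real agent would run the input through a model
    }

    @Override
    public void train(DataSet conversationData) {
        // fit an underlying model on the conversation pairs
    }

    @Override
    public double evaluate(DataSet testData) {
        return 0.0; // e.g. average loss or accuracy on held-out data
    }

    @Override
    public void save(String path) {
        // serialise the underlying model to the given path
    }

    @Override
    public void load(String path) {
        // restore a previously saved model from the given path
    }
}
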
src/main/java/net/echo/brain4j/nlp/agents/attention/AttentionMechanism.java
@@ -0,0 +1,6 @@
package net.echo.brain4j.nlp.agents.attention;

public interface AttentionMechanism {
double[] attend(double[] input, String contextKey);
}
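
The interface deliberately leaves the mechanism open (a single attend(double[] input, String contextKey) call). Below is a purely hypothetical sketch of one way it could be satisfied, softmax-weighting the input against a per-key running context vector, just to illustrate the shape of the contract; neither the map nor the update rule is implied by the commit:

import java.util.HashMap;
import java.util.Map;

import net.echo.brain4j.activation.impl.SoftmaxActivation;
import net.echo.brain4j.nlp.agents.attention.AttentionMechanism;

// Hypothetical sketch, illustration only.
public class SoftmaxAttention implements AttentionMechanism {

    private final Map<String, double[]> contexts = new HashMap<>();

    @Override
    public double[] attend(double[] input, String contextKey) {
        double[] context = contexts.computeIfAbsent(contextKey, k -> new double[input.length]);

        // Score each position by how strongly it matches the stored context,
        // normalise the scores with softmax, then reweight the input.
        double[] scores = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            scores[i] = input[i] * context[i];
        }
        double[] weights = new SoftmaxActivation().activate(scores);

        double[] attended = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            attended[i] = input[i] * weights[i];
            context[i] = 0.9 * context[i] + 0.1 * input[i]; // remember recent inputs
        }
        return attended;
    }
}
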
