Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenGradient plugin #451

Merged
merged 6 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions typescript/.changeset/tall-pants-turn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@goat-sdk/plugin-opengradient": patch
---

Release package
44 changes: 44 additions & 0 deletions typescript/packages/plugins/opengradient/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
<div align="center">
<a href="https://github.com/goat-sdk/goat">
<img src="https://github.com/user-attachments/assets/5fc7f121-259c-492c-8bca-f15fe7eb830c" alt="GOAT" width="100px" height="auto" style="object-fit: contain;">
</a>
</div>

# OpenGradient GOAT Plugin

This plugin integrates the OpenGradient service with the GOAT SDK, providing on-chain ML model inference and LLM interactions.

## Installation
```bash
npm install @goat-sdk/plugin-opengradient
yarn add @goat-sdk/plugin-opengradient
pnpm add @goat-sdk/plugin-opengradient
```

## Usage
```typescript
import { opengradient } from "@goat-sdk/plugin-opengradient";

const tools = await getOnChainTools({
wallet: viem(wallet),
plugins: [
opengradient(),
],
});
```

## Tools
* `opengradient_model_inference` - Run inference on machine learning models using OpenGradient
* `opengradient_llm_completion` - Generate text completions using LLMs through OpenGradient
* `opengradient_llm_chat` - Interact with LLMs using a chat interface through OpenGradient


<footer>
<br/>
<br/>
<div>
<a href="https://github.com/goat-sdk/goat">
<img src="https://github.com/user-attachments/assets/59fa5ddc-9d47-4d41-a51a-64f6798f94bd" alt="GOAT" width="100%" height="auto" style="object-fit: contain; max-width: 800px;">
</a>
</div>
</footer>
33 changes: 33 additions & 0 deletions typescript/packages/plugins/opengradient/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"name": "@goat-sdk/plugin-opengradient",
"version": "0.1.0",
"files": ["dist/**/*", "README.md", "package.json"],
"scripts": {
"build": "tsup",
"clean": "rm -rf dist",
"test": "vitest run --passWithNoTests"
},
"main": "./dist/index.js",
"module": "./dist/index.mjs",
"types": "./dist/index.d.ts",
"sideEffects": false,
"homepage": "https://ohmygoat.dev",
"repository": {
"type": "git",
"url": "git+https://github.com/goat-sdk/goat.git"
},
"license": "MIT",
"bugs": {
"url": "https://github.com/goat-sdk/goat/issues"
},
"keywords": ["ai", "agents", "web3"],
"dependencies": {
"@goat-sdk/core": "workspace:*",
"opengradient-sdk": "1.0.0",
"web3": "4.16.0",
"zod": "catalog:"
},
"peerDependencies": {
"@goat-sdk/core": "workspace:*"
}
}
3 changes: 3 additions & 0 deletions typescript/packages/plugins/opengradient/src/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
export * from "./opengradient.plugin";
export * from "./parameters";
export * from "./types";
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { type Chain, PluginBase } from "@goat-sdk/core";
import { OpengradientService } from "./opengradient.service";
import { OpenGradientConfig } from "./types";

export type OpengradientPluginCtorParams = {
config: OpenGradientConfig;
};

export class OpengradientPlugin extends PluginBase {
constructor({ config }: OpengradientPluginCtorParams) {
super("opengradient", [new OpengradientService(config)]);
}

supportsChain = (_chain: Chain) => true;
}

export function opengradient(params: OpengradientPluginCtorParams) {
return new OpengradientPlugin(params);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import { Tool } from "@goat-sdk/core";
import { WalletClientBase } from "@goat-sdk/core";
import { Client } from "opengradient-sdk";
import { LLMChatParameters, LLMCompletionParameters, ModelInferenceParameters } from "./parameters";
import { OpenGradientConfig } from "./types";

export class OpengradientService {
private client: Client | null = null;

constructor(private config: OpenGradientConfig) {}

private getClient(clientConfig: OpenGradientConfig): Client {
if (!this.client) {
this.client = new Client({
privateKey: clientConfig.privateKey,
});
}
return this.client;
}

@Tool({
name: "opengradient_model_inference",
description: "Run inference on a machine learning model using OpenGradient",
})
async runModelInference(walletClient: WalletClientBase, parameters: ModelInferenceParameters) {
const client = this.getClient(this.config);
const [txHash, modelOutput] = await client.infer(
parameters.modelCid,
parameters.inferenceMode,
parameters.modelInput,
);

return {
transactionHash: txHash,
output: modelOutput,
};
}

@Tool({
name: "opengradient_llm_completion",
description: "Generate text completions using an LLM through OpenGradient",
})
async runLLMCompletion(walletClient: WalletClientBase, parameters: LLMCompletionParameters) {
const client = this.getClient(this.config);
const [txHash, completion] = await client.llmCompletion(
parameters.modelCid,
parameters.inferenceMode,
parameters.prompt,
parameters.maxTokens,
parameters.stopSequence,
parameters.temperature,
);

return {
transactionHash: txHash,
completion,
};
}

@Tool({
name: "opengradient_llm_chat",
description: "Interact with an LLM using a chat interface through OpenGradient",
})
async runLLMChat(walletClient: WalletClientBase, parameters: LLMChatParameters) {
const client = this.getClient(this.config);
const [txHash, finishReason, message] = await client.llmChat(
parameters.modelCid,
parameters.inferenceMode,
parameters.messages,
parameters.maxTokens,
parameters.stopSequence,
parameters.temperature,
parameters.tools,
parameters.toolChoice,
);

return {
transactionHash: txHash,
finishReason,
message,
};
}
}
50 changes: 50 additions & 0 deletions typescript/packages/plugins/opengradient/src/parameters.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { createToolParameters } from "@goat-sdk/core";
import { z } from "zod";

export class ModelInferenceParameters extends createToolParameters(
z.object({
modelCid: z.string().describe("CID of the model to run inference with"),
inferenceMode: z.number().describe("Inference mode to use (0 for VANILLA, 1 for TEE)"),
modelInput: z.record(z.any()).describe("Input to the model in the form of a JSON object"),
}),
) {}

export class LLMCompletionParameters extends createToolParameters(
z.object({
modelCid: z.string().describe("CID of the LLM model to use"),
inferenceMode: z.number().describe("Inference mode to use (0 for VANILLA, 1 for TEE)"),
prompt: z.string().describe("Text prompt for the LLM"),
maxTokens: z.number().default(100).describe("Maximum number of tokens to generate"),
stopSequence: z.array(z.string()).default([]).describe("Sequences that will stop generation if encountered"),
temperature: z.number().default(0).describe("Temperature for sampling (0.0 to 1.0)"),
}),
) {}

export class LLMChatParameters extends createToolParameters(
z.object({
modelCid: z.string().describe("CID of the LLM model to use"),
inferenceMode: z.number().describe("Inference mode to use (0 for VANILLA, 1 for TEE)"),
messages: z
.array(
z.object({
role: z.enum(["system", "user", "assistant", "tool"]).describe("Role of the message sender"),
content: z.string().describe("Content of the message"),
toolCalls: z.array(z.any()).optional().describe("Tool calls made in this message"),
toolCallId: z.string().optional().describe("ID of the tool call this message is responding to"),
name: z.string().optional().describe("Name of the entity sending the message"),
}),
)
.describe("Array of conversation messages"),
maxTokens: z.number().default(100).describe("Maximum number of tokens to generate"),
stopSequence: z.array(z.string()).default([]).describe("Sequences that will stop generation if encountered"),
temperature: z.number().default(0).describe("Temperature for sampling (0.0 to 1.0)"),
tools: z.array(z.any()).default([]).describe("Tools available to the model"),
toolChoice: z.string().optional().describe("Tool choice strategy ('auto', 'none', or specific tool name)"),
}),
) {}

export class ClientConfigParameters extends createToolParameters(
z.object({
privateKey: z.string().describe("Private key for authentication with OpenGradient"),
}),
) {}
3 changes: 3 additions & 0 deletions typescript/packages/plugins/opengradient/src/types/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
export type OpenGradientConfig = {
privateKey: string;
};
6 changes: 6 additions & 0 deletions typescript/packages/plugins/opengradient/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"$schema": "https://json.schemastore.org/tsconfig",
"extends": "../../../tsconfig.base.json",
"include": ["src/**/*"],
"exclude": ["node_modules", "dist"]
}
6 changes: 6 additions & 0 deletions typescript/packages/plugins/opengradient/tsup.config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import { defineConfig } from "tsup";
import { treeShakableConfig } from "../../../tsup.config.base";

export default defineConfig({
...treeShakableConfig,
});
11 changes: 11 additions & 0 deletions typescript/packages/plugins/opengradient/turbo.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"$schema": "https://turbo.build/schema.json",
"extends": ["//"],
"tasks": {
"build": {
"inputs": ["src/**", "tsup.config.ts", "!./**/*.test.{ts,tsx}", "tsconfig.json"],
"dependsOn": ["^build"],
"outputs": ["dist/**"]
}
}
}
Loading