Skip to content

Commit 4317e36

Browse files
devin-ai-integration[bot]joycieland0xaguspunk
authored
Add OpenGradient plugin (#451)
* Add OpenGradient plugin for on-chain ML and LLM inference Co-Authored-By: joyce@paella.dev <joyce@paella.dev> * Update OpenGradient service file to use secure credential handling Co-Authored-By: joyce@paella.dev <joyce@paella.dev> * Update pnpm-lock.yaml for OpenGradient plugin Co-Authored-By: joyce@paella.dev <joyce@paella.dev> * Update readme * Fix package * Add changeset --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: joyce@paella.dev <joyce@paella.dev> Co-authored-by: Agustin Armellini Fischer <armellini13@gmail.com>
1 parent 7d49511 commit 4317e36

12 files changed

+641
-0
lines changed
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@goat-sdk/plugin-opengradient": patch
3+
---
4+
5+
Release package
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
<div align="center">
2+
<a href="https://github.com/goat-sdk/goat">
3+
<img src="https://github.com/user-attachments/assets/5fc7f121-259c-492c-8bca-f15fe7eb830c" alt="GOAT" width="100px" height="auto" style="object-fit: contain;">
4+
</a>
5+
</div>
6+
7+
# OpenGradient GOAT Plugin
8+
9+
This plugin integrates the OpenGradient service with the GOAT SDK, providing on-chain ML model inference and LLM interactions.
10+
11+
## Installation
12+
```bash
13+
npm install @goat-sdk/plugin-opengradient
14+
yarn add @goat-sdk/plugin-opengradient
15+
pnpm add @goat-sdk/plugin-opengradient
16+
```
17+
18+
## Usage
19+
```typescript
20+
import { opengradient } from "@goat-sdk/plugin-opengradient";
21+
22+
const tools = await getOnChainTools({
23+
wallet: viem(wallet),
24+
plugins: [
25+
opengradient(),
26+
],
27+
});
28+
```
29+
30+
## Tools
31+
* `opengradient_model_inference` - Run inference on machine learning models using OpenGradient
32+
* `opengradient_llm_completion` - Generate text completions using LLMs through OpenGradient
33+
* `opengradient_llm_chat` - Interact with LLMs using a chat interface through OpenGradient
34+
35+
36+
<footer>
37+
<br/>
38+
<br/>
39+
<div>
40+
<a href="https://github.com/goat-sdk/goat">
41+
<img src="https://github.com/user-attachments/assets/59fa5ddc-9d47-4d41-a51a-64f6798f94bd" alt="GOAT" width="100%" height="auto" style="object-fit: contain; max-width: 800px;">
42+
</a>
43+
</div>
44+
</footer>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"name": "@goat-sdk/plugin-opengradient",
3+
"version": "0.1.0",
4+
"files": ["dist/**/*", "README.md", "package.json"],
5+
"scripts": {
6+
"build": "tsup",
7+
"clean": "rm -rf dist",
8+
"test": "vitest run --passWithNoTests"
9+
},
10+
"main": "./dist/index.js",
11+
"module": "./dist/index.mjs",
12+
"types": "./dist/index.d.ts",
13+
"sideEffects": false,
14+
"homepage": "https://ohmygoat.dev",
15+
"repository": {
16+
"type": "git",
17+
"url": "git+https://github.com/goat-sdk/goat.git"
18+
},
19+
"license": "MIT",
20+
"bugs": {
21+
"url": "https://github.com/goat-sdk/goat/issues"
22+
},
23+
"keywords": ["ai", "agents", "web3"],
24+
"dependencies": {
25+
"@goat-sdk/core": "workspace:*",
26+
"opengradient-sdk": "1.0.0",
27+
"web3": "4.16.0",
28+
"zod": "catalog:"
29+
},
30+
"peerDependencies": {
31+
"@goat-sdk/core": "workspace:*"
32+
}
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export * from "./opengradient.plugin";
2+
export * from "./parameters";
3+
export * from "./types";
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import { type Chain, PluginBase } from "@goat-sdk/core";
2+
import { OpengradientService } from "./opengradient.service";
3+
import { OpenGradientConfig } from "./types";
4+
5+
export type OpengradientPluginCtorParams = {
6+
config: OpenGradientConfig;
7+
};
8+
9+
export class OpengradientPlugin extends PluginBase {
10+
constructor({ config }: OpengradientPluginCtorParams) {
11+
super("opengradient", [new OpengradientService(config)]);
12+
}
13+
14+
supportsChain = (_chain: Chain) => true;
15+
}
16+
17+
export function opengradient(params: OpengradientPluginCtorParams) {
18+
return new OpengradientPlugin(params);
19+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import { Tool } from "@goat-sdk/core";
2+
import { WalletClientBase } from "@goat-sdk/core";
3+
import { Client } from "opengradient-sdk";
4+
import { LLMChatParameters, LLMCompletionParameters, ModelInferenceParameters } from "./parameters";
5+
import { OpenGradientConfig } from "./types";
6+
7+
export class OpengradientService {
8+
private client: Client | null = null;
9+
10+
constructor(private config: OpenGradientConfig) {}
11+
12+
private getClient(clientConfig: OpenGradientConfig): Client {
13+
if (!this.client) {
14+
this.client = new Client({
15+
privateKey: clientConfig.privateKey,
16+
});
17+
}
18+
return this.client;
19+
}
20+
21+
@Tool({
22+
name: "opengradient_model_inference",
23+
description: "Run inference on a machine learning model using OpenGradient",
24+
})
25+
async runModelInference(walletClient: WalletClientBase, parameters: ModelInferenceParameters) {
26+
const client = this.getClient(this.config);
27+
const [txHash, modelOutput] = await client.infer(
28+
parameters.modelCid,
29+
parameters.inferenceMode,
30+
parameters.modelInput,
31+
);
32+
33+
return {
34+
transactionHash: txHash,
35+
output: modelOutput,
36+
};
37+
}
38+
39+
@Tool({
40+
name: "opengradient_llm_completion",
41+
description: "Generate text completions using an LLM through OpenGradient",
42+
})
43+
async runLLMCompletion(walletClient: WalletClientBase, parameters: LLMCompletionParameters) {
44+
const client = this.getClient(this.config);
45+
const [txHash, completion] = await client.llmCompletion(
46+
parameters.modelCid,
47+
parameters.inferenceMode,
48+
parameters.prompt,
49+
parameters.maxTokens,
50+
parameters.stopSequence,
51+
parameters.temperature,
52+
);
53+
54+
return {
55+
transactionHash: txHash,
56+
completion,
57+
};
58+
}
59+
60+
@Tool({
61+
name: "opengradient_llm_chat",
62+
description: "Interact with an LLM using a chat interface through OpenGradient",
63+
})
64+
async runLLMChat(walletClient: WalletClientBase, parameters: LLMChatParameters) {
65+
const client = this.getClient(this.config);
66+
const [txHash, finishReason, message] = await client.llmChat(
67+
parameters.modelCid,
68+
parameters.inferenceMode,
69+
parameters.messages,
70+
parameters.maxTokens,
71+
parameters.stopSequence,
72+
parameters.temperature,
73+
parameters.tools,
74+
parameters.toolChoice,
75+
);
76+
77+
return {
78+
transactionHash: txHash,
79+
finishReason,
80+
message,
81+
};
82+
}
83+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import { createToolParameters } from "@goat-sdk/core";
2+
import { z } from "zod";
3+
4+
export class ModelInferenceParameters extends createToolParameters(
5+
z.object({
6+
modelCid: z.string().describe("CID of the model to run inference with"),
7+
inferenceMode: z.number().describe("Inference mode to use (0 for VANILLA, 1 for TEE)"),
8+
modelInput: z.record(z.any()).describe("Input to the model in the form of a JSON object"),
9+
}),
10+
) {}
11+
12+
export class LLMCompletionParameters extends createToolParameters(
13+
z.object({
14+
modelCid: z.string().describe("CID of the LLM model to use"),
15+
inferenceMode: z.number().describe("Inference mode to use (0 for VANILLA, 1 for TEE)"),
16+
prompt: z.string().describe("Text prompt for the LLM"),
17+
maxTokens: z.number().default(100).describe("Maximum number of tokens to generate"),
18+
stopSequence: z.array(z.string()).default([]).describe("Sequences that will stop generation if encountered"),
19+
temperature: z.number().default(0).describe("Temperature for sampling (0.0 to 1.0)"),
20+
}),
21+
) {}
22+
23+
export class LLMChatParameters extends createToolParameters(
24+
z.object({
25+
modelCid: z.string().describe("CID of the LLM model to use"),
26+
inferenceMode: z.number().describe("Inference mode to use (0 for VANILLA, 1 for TEE)"),
27+
messages: z
28+
.array(
29+
z.object({
30+
role: z.enum(["system", "user", "assistant", "tool"]).describe("Role of the message sender"),
31+
content: z.string().describe("Content of the message"),
32+
toolCalls: z.array(z.any()).optional().describe("Tool calls made in this message"),
33+
toolCallId: z.string().optional().describe("ID of the tool call this message is responding to"),
34+
name: z.string().optional().describe("Name of the entity sending the message"),
35+
}),
36+
)
37+
.describe("Array of conversation messages"),
38+
maxTokens: z.number().default(100).describe("Maximum number of tokens to generate"),
39+
stopSequence: z.array(z.string()).default([]).describe("Sequences that will stop generation if encountered"),
40+
temperature: z.number().default(0).describe("Temperature for sampling (0.0 to 1.0)"),
41+
tools: z.array(z.any()).default([]).describe("Tools available to the model"),
42+
toolChoice: z.string().optional().describe("Tool choice strategy ('auto', 'none', or specific tool name)"),
43+
}),
44+
) {}
45+
46+
export class ClientConfigParameters extends createToolParameters(
47+
z.object({
48+
privateKey: z.string().describe("Private key for authentication with OpenGradient"),
49+
}),
50+
) {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export type OpenGradientConfig = {
2+
privateKey: string;
3+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"$schema": "https://json.schemastore.org/tsconfig",
3+
"extends": "../../../tsconfig.base.json",
4+
"include": ["src/**/*"],
5+
"exclude": ["node_modules", "dist"]
6+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import { defineConfig } from "tsup";
2+
import { treeShakableConfig } from "../../../tsup.config.base";
3+
4+
export default defineConfig({
5+
...treeShakableConfig,
6+
});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"$schema": "https://turbo.build/schema.json",
3+
"extends": ["//"],
4+
"tasks": {
5+
"build": {
6+
"inputs": ["src/**", "tsup.config.ts", "!./**/*.test.{ts,tsx}", "tsconfig.json"],
7+
"dependsOn": ["^build"],
8+
"outputs": ["dist/**"]
9+
}
10+
}
11+
}

0 commit comments

Comments
 (0)