Skip to content

Commit 941b8d1

Browse files
authored
Merge pull request #2039 from Phala-Network/fix/tee-derive-key
fix: Update Key Derive in TEE
2 parents 0551c8a + 7e18a9e commit 941b8d1

File tree

9 files changed

+354
-8
lines changed

9 files changed

+354
-8
lines changed

docs/docs/advanced/eliza-in-tee.md

+8-2
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,22 @@ Example usage:
5757
const provider = new DeriveKeyProvider(teeMode);
5858
// For Solana
5959
const { keypair, attestation } = await provider.deriveEd25519Keypair(
60-
"/",
6160
secretSalt,
61+
"solana",
6262
agentId,
6363
);
6464
// For EVM
6565
const { keypair, attestation } = await provider.deriveEcdsaKeypair(
66-
"/",
6766
secretSalt,
67+
"evm",
6868
agentId,
6969
);
70+
71+
// For raw key derivation
72+
const rawKey = await provider.deriveRawKey(
73+
secretSalt,
74+
"raw",
75+
);
7076
```
7177

7278
---

packages/plugin-evm/src/providers/wallet.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ export const initWalletProvider = async (runtime: IAgentRuntime) => {
292292

293293
const deriveKeyProvider = new DeriveKeyProvider(teeMode);
294294
const deriveKeyResult = await deriveKeyProvider.deriveEcdsaKeypair(
295-
"/",
296295
walletSecretSalt,
296+
"evm",
297297
runtime.agentId
298298
);
299299
return new WalletProvider(

packages/plugin-solana/src/keypairUtils.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ export async function getWalletKey(
3030

3131
const deriveKeyProvider = new DeriveKeyProvider(teeMode);
3232
const deriveKeyResult = await deriveKeyProvider.deriveEd25519Keypair(
33-
"/",
3433
walletSecretSalt,
34+
"solana",
3535
runtime.agentId
3636
);
3737

packages/plugin-tee/package.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
"scripts": {
3333
"build": "tsup --format esm --dts",
3434
"dev": "tsup --format esm --dts --watch",
35-
"lint": "eslint --fix --cache ."
35+
"lint": "eslint --fix --cache .",
36+
"test": "vitest run"
3637
},
3738
"peerDependencies": {
3839
"whatwg-url": "7.1.0"

packages/plugin-tee/src/providers/deriveKeyProvider.ts

+22-2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ class DeriveKeyProvider {
7272
return quote;
7373
}
7474

75+
/**
76+
* Derives a raw key from the given path and subject.
77+
* @param path - The path to derive the key from. This is used to derive the key from the root of trust.
78+
* @param subject - The subject to derive the key from. This is used for the certificate chain.
79+
* @returns The derived key.
80+
*/
7581
async rawDeriveKey(
7682
path: string,
7783
subject: string
@@ -94,6 +100,13 @@ class DeriveKeyProvider {
94100
}
95101
}
96102

103+
/**
104+
* Derives an Ed25519 keypair from the given path and subject.
105+
* @param path - The path to derive the key from. This is used to derive the key from the root of trust.
106+
* @param subject - The subject to derive the key from. This is used for the certificate chain.
107+
* @param agentId - The agent ID to generate an attestation for.
108+
* @returns An object containing the derived keypair and attestation.
109+
*/
97110
async deriveEd25519Keypair(
98111
path: string,
99112
subject: string,
@@ -130,6 +143,13 @@ class DeriveKeyProvider {
130143
}
131144
}
132145

146+
/**
147+
* Derives an ECDSA keypair from the given path and subject.
148+
* @param path - The path to derive the key from. This is used to derive the key from the root of trust.
149+
* @param subject - The subject to derive the key from. This is used for the certificate chain.
150+
* @param agentId - The agent ID to generate an attestation for. This is used for the certificate chain.
151+
* @returns An object containing the derived keypair and attestation.
152+
*/
133153
async deriveEcdsaKeypair(
134154
path: string,
135155
subject: string,
@@ -184,13 +204,13 @@ const deriveKeyProvider: Provider = {
184204
const secretSalt =
185205
runtime.getSetting("WALLET_SECRET_SALT") || "secret_salt";
186206
const solanaKeypair = await provider.deriveEd25519Keypair(
187-
"/",
188207
secretSalt,
208+
"solana",
189209
agentId
190210
);
191211
const evmKeypair = await provider.deriveEcdsaKeypair(
192-
"/",
193212
secretSalt,
213+
"evm",
194214
agentId
195215
);
196216
return JSON.stringify({

packages/plugin-tee/src/providers/walletProvider.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,8 @@ const walletProvider: Provider = {
299299
keypair: Keypair;
300300
attestation: RemoteAttestationQuote;
301301
} = await deriveKeyProvider.deriveEd25519Keypair(
302-
"/",
303302
runtime.getSetting("WALLET_SECRET_SALT"),
303+
"solana",
304304
agentId
305305
);
306306
publicKey = derivedKeyPair.keypair.publicKey;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import { describe, it, expect, vi, beforeEach } from 'vitest';
2+
import { DeriveKeyProvider } from '../providers/deriveKeyProvider';
3+
import { TappdClient } from '@phala/dstack-sdk';
4+
import { TEEMode } from '../types/tee';
5+
import { Keypair } from '@solana/web3.js';
6+
import { privateKeyToAccount } from 'viem/accounts';
7+
8+
// Mock dependencies
9+
vi.mock('@phala/dstack-sdk', () => ({
10+
TappdClient: vi.fn().mockImplementation(() => ({
11+
deriveKey: vi.fn().mockResolvedValue({
12+
asUint8Array: () => new Uint8Array([1, 2, 3, 4, 5])
13+
}),
14+
tdxQuote: vi.fn().mockResolvedValue({
15+
quote: 'mock-quote-data',
16+
replayRtmrs: () => ['rtmr0', 'rtmr1', 'rtmr2', 'rtmr3']
17+
}),
18+
rawDeriveKey: vi.fn()
19+
}))
20+
}));
21+
22+
vi.mock('@solana/web3.js', () => ({
23+
Keypair: {
24+
fromSeed: vi.fn().mockReturnValue({
25+
publicKey: {
26+
toBase58: () => 'mock-solana-public-key'
27+
}
28+
})
29+
}
30+
}));
31+
32+
vi.mock('viem/accounts', () => ({
33+
privateKeyToAccount: vi.fn().mockReturnValue({
34+
address: 'mock-evm-address'
35+
})
36+
}));
37+
38+
describe('DeriveKeyProvider', () => {
39+
beforeEach(() => {
40+
vi.clearAllMocks();
41+
});
42+
43+
describe('constructor', () => {
44+
it('should initialize with LOCAL mode', () => {
45+
const _provider = new DeriveKeyProvider(TEEMode.LOCAL);
46+
expect(TappdClient).toHaveBeenCalledWith('http://localhost:8090');
47+
});
48+
49+
it('should initialize with DOCKER mode', () => {
50+
const _provider = new DeriveKeyProvider(TEEMode.DOCKER);
51+
expect(TappdClient).toHaveBeenCalledWith('http://host.docker.internal:8090');
52+
});
53+
54+
it('should initialize with PRODUCTION mode', () => {
55+
const _provider = new DeriveKeyProvider(TEEMode.PRODUCTION);
56+
expect(TappdClient).toHaveBeenCalledWith();
57+
});
58+
59+
it('should throw error for invalid mode', () => {
60+
expect(() => new DeriveKeyProvider('INVALID_MODE')).toThrow('Invalid TEE_MODE');
61+
});
62+
});
63+
64+
describe('rawDeriveKey', () => {
65+
let _provider: DeriveKeyProvider;
66+
67+
beforeEach(() => {
68+
_provider = new DeriveKeyProvider(TEEMode.LOCAL);
69+
});
70+
71+
it('should derive raw key successfully', async () => {
72+
const path = 'test-path';
73+
const subject = 'test-subject';
74+
const result = await _provider.rawDeriveKey(path, subject);
75+
76+
const client = TappdClient.mock.results[0].value;
77+
expect(client.deriveKey).toHaveBeenCalledWith(path, subject);
78+
expect(result.asUint8Array()).toEqual(new Uint8Array([1, 2, 3, 4, 5]));
79+
});
80+
81+
it('should handle errors during raw key derivation', async () => {
82+
const mockError = new Error('Key derivation failed');
83+
vi.mocked(TappdClient).mockImplementationOnce(() => {
84+
const instance = new TappdClient();
85+
instance.deriveKey = vi.fn().mockRejectedValueOnce(mockError);
86+
instance.tdxQuote = vi.fn();
87+
instance.rawDeriveKey = vi.fn();
88+
return instance;
89+
});
90+
91+
const provider = new DeriveKeyProvider(TEEMode.LOCAL);
92+
await expect(provider.rawDeriveKey('path', 'subject')).rejects.toThrow(mockError);
93+
});
94+
});
95+
96+
describe('deriveEd25519Keypair', () => {
97+
let provider: DeriveKeyProvider;
98+
99+
beforeEach(() => {
100+
provider = new DeriveKeyProvider(TEEMode.LOCAL);
101+
});
102+
103+
it('should derive Ed25519 keypair successfully', async () => {
104+
const path = 'test-path';
105+
const subject = 'test-subject';
106+
const agentId = 'test-agent';
107+
108+
const result = await provider.deriveEd25519Keypair(path, subject, agentId);
109+
110+
expect(result).toHaveProperty('keypair');
111+
expect(result).toHaveProperty('attestation');
112+
expect(result.keypair.publicKey.toBase58()).toBe('mock-solana-public-key');
113+
});
114+
});
115+
116+
describe('deriveEcdsaKeypair', () => {
117+
let provider: DeriveKeyProvider;
118+
119+
beforeEach(() => {
120+
provider = new DeriveKeyProvider(TEEMode.LOCAL);
121+
});
122+
123+
it('should derive ECDSA keypair successfully', async () => {
124+
const path = 'test-path';
125+
const subject = 'test-subject';
126+
const agentId = 'test-agent';
127+
128+
const result = await provider.deriveEcdsaKeypair(path, subject, agentId);
129+
130+
expect(result).toHaveProperty('keypair');
131+
expect(result).toHaveProperty('attestation');
132+
expect(result.keypair.address).toBe('mock-evm-address');
133+
});
134+
});
135+
});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import { describe, it, expect, vi, beforeEach } from 'vitest';
2+
import { RemoteAttestationProvider } from '../providers/remoteAttestationProvider';
3+
import { TappdClient } from '@phala/dstack-sdk';
4+
import { TEEMode } from '../types/tee';
5+
6+
// Mock TappdClient
7+
vi.mock('@phala/dstack-sdk', () => ({
8+
TappdClient: vi.fn().mockImplementation(() => ({
9+
tdxQuote: vi.fn().mockResolvedValue({
10+
quote: 'mock-quote-data',
11+
replayRtmrs: () => ['rtmr0', 'rtmr1', 'rtmr2', 'rtmr3']
12+
}),
13+
deriveKey: vi.fn()
14+
}))
15+
}));
16+
17+
describe('RemoteAttestationProvider', () => {
18+
beforeEach(() => {
19+
vi.clearAllMocks();
20+
});
21+
22+
describe('constructor', () => {
23+
it('should initialize with LOCAL mode', () => {
24+
const _provider = new RemoteAttestationProvider(TEEMode.LOCAL);
25+
expect(TappdClient).toHaveBeenCalledWith('http://localhost:8090');
26+
});
27+
28+
it('should initialize with DOCKER mode', () => {
29+
const _provider = new RemoteAttestationProvider(TEEMode.DOCKER);
30+
expect(TappdClient).toHaveBeenCalledWith('http://host.docker.internal:8090');
31+
});
32+
33+
it('should initialize with PRODUCTION mode', () => {
34+
const _provider = new RemoteAttestationProvider(TEEMode.PRODUCTION);
35+
expect(TappdClient).toHaveBeenCalledWith();
36+
});
37+
38+
it('should throw error for invalid mode', () => {
39+
expect(() => new RemoteAttestationProvider('INVALID_MODE')).toThrow('Invalid TEE_MODE');
40+
});
41+
});
42+
43+
describe('generateAttestation', () => {
44+
let provider: RemoteAttestationProvider;
45+
46+
beforeEach(() => {
47+
provider = new RemoteAttestationProvider(TEEMode.LOCAL);
48+
});
49+
50+
it('should generate attestation successfully', async () => {
51+
const reportData = 'test-report-data';
52+
const quote = await provider.generateAttestation(reportData);
53+
54+
expect(quote).toEqual({
55+
quote: 'mock-quote-data',
56+
timestamp: expect.any(Number)
57+
});
58+
});
59+
60+
it('should handle errors during attestation generation', async () => {
61+
const mockError = new Error('TDX Quote generation failed');
62+
const mockTdxQuote = vi.fn().mockRejectedValue(mockError);
63+
vi.mocked(TappdClient).mockImplementationOnce(() => ({
64+
tdxQuote: mockTdxQuote,
65+
deriveKey: vi.fn()
66+
}));
67+
68+
const provider = new RemoteAttestationProvider(TEEMode.LOCAL);
69+
await expect(provider.generateAttestation('test-data')).rejects.toThrow('Failed to generate TDX Quote');
70+
});
71+
72+
it('should pass hash algorithm to tdxQuote when provided', async () => {
73+
const reportData = 'test-report-data';
74+
const hashAlgorithm = 'raw';
75+
await provider.generateAttestation(reportData, hashAlgorithm);
76+
77+
const client = TappdClient.mock.results[0].value;
78+
expect(client.tdxQuote).toHaveBeenCalledWith(reportData, hashAlgorithm);
79+
});
80+
});
81+
});

0 commit comments

Comments
 (0)