@@ -284,45 +284,68 @@ CHIP_ERROR PsaKdf::Init(const ByteSpan & secret, const ByteSpan & salt, const By
284
284
psa_reset_key_attributes (&attrs);
285
285
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
286
286
287
- return InitOperation (mSecretKeyId , salt, info);
287
+ PsaHkdfKeyHandle hkdfKeyHandle = { .mKeyId = mSecretKeyId , .mIsKeyId = true };
288
+
289
+ return InitOperation (hkdfKeyHandle, salt, info);
288
290
}
289
291
290
292
CHIP_ERROR PsaKdf::Init (const HkdfKeyHandle & hkdfKey, const ByteSpan & salt, const ByteSpan & info)
291
293
{
292
- return InitOperation (hkdfKey.As <psa_key_id_t >(), salt, info);
294
+ return InitOperation (hkdfKey.As <PsaHkdfKeyHandle >(), salt, info);
293
295
}
294
296
295
- CHIP_ERROR PsaKdf::InitOperation (psa_key_id_t hkdfKey, const ByteSpan & salt, const ByteSpan & info)
297
+ CHIP_ERROR PsaKdf::InitOperation (PsaHkdfKeyHandle hkdfKey, const ByteSpan & salt, const ByteSpan & info)
296
298
{
297
- psa_status_t status = psa_key_derivation_setup (&mOperation , PSA_ALG_HKDF (PSA_ALG_SHA_256));
298
- VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
299
-
300
- if (salt.size () > 0 )
299
+ psa_status_t status;
300
+ if (hkdfKey.mIsKeyId )
301
301
{
302
- status = psa_key_derivation_input_bytes (&mOperation , PSA_KEY_DERIVATION_INPUT_SALT, salt. data (), salt. size ( ));
302
+ status = psa_key_derivation_setup (&mOperation , PSA_ALG_HKDF (PSA_ALG_SHA_256 ));
303
303
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
304
+
305
+ if (salt.size () > 0 )
306
+ {
307
+ status = psa_key_derivation_input_bytes (&mOperation , PSA_KEY_DERIVATION_INPUT_SALT, salt.data (), salt.size ());
308
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
309
+ }
310
+
311
+ status = psa_key_derivation_input_key (&mOperation , PSA_KEY_DERIVATION_INPUT_SECRET, hkdfKey.mKeyId );
312
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
313
+
314
+ status = psa_key_derivation_input_bytes (&mOperation , PSA_KEY_DERIVATION_INPUT_INFO, info.data (), info.size ());
315
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
316
+
317
+ mDerivationOperation = &mOperation ;
304
318
}
319
+ else
320
+ {
321
+ mDerivationOperation = hkdfKey.mKeyDerivationOp ;
305
322
306
- status = psa_key_derivation_input_key (&mOperation , PSA_KEY_DERIVATION_INPUT_SECRET, hkdfKey);
307
- VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
323
+ if (salt.size () > 0 )
324
+ {
325
+ status = psa_key_derivation_input_bytes (mDerivationOperation , PSA_KEY_DERIVATION_INPUT_SALT, salt.data (), salt.size ());
326
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
327
+ }
308
328
309
- status = psa_key_derivation_input_bytes (&mOperation , PSA_KEY_DERIVATION_INPUT_INFO, info.data (), info.size ());
310
- VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
329
+ status = psa_key_derivation_input_bytes (mDerivationOperation , PSA_KEY_DERIVATION_INPUT_INFO, info.data (), info.size ());
330
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
331
+ }
311
332
312
333
return CHIP_NO_ERROR;
313
334
}
314
335
315
336
CHIP_ERROR PsaKdf::DeriveBytes (const MutableByteSpan & output)
316
337
{
317
- psa_status_t status = psa_key_derivation_output_bytes (&mOperation , output.data (), output.size ());
338
+ psa_status_t status = psa_key_derivation_output_bytes (mDerivationOperation , output.data (), output.size ());
339
+
318
340
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
319
341
320
342
return CHIP_NO_ERROR;
321
343
}
322
344
323
345
CHIP_ERROR PsaKdf::DeriveKey (const psa_key_attributes_t & attributes, psa_key_id_t & keyId)
324
346
{
325
- psa_status_t status = psa_key_derivation_output_key (&attributes, &mOperation , &keyId);
347
+ psa_status_t status = psa_key_derivation_output_key (&attributes, mDerivationOperation , &keyId);
348
+
326
349
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
327
350
328
351
return CHIP_NO_ERROR;
0 commit comments