@@ -112,10 +112,10 @@ CHIP_ERROR AES_CCM_encrypt(const uint8_t * plaintext, size_t plaintext_length, c
112
112
VerifyOrReturnError (aad != nullptr || aad_length == 0 , CHIP_ERROR_INVALID_ARGUMENT);
113
113
114
114
const psa_algorithm_t algorithm = PSA_ALG_AEAD_WITH_SHORTENED_TAG (PSA_ALG_CCM, tag_length);
115
- psa_status_t status = PSA_SUCCESS;
116
115
psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT;
117
- size_t out_length;
118
- size_t tag_out_length;
116
+ psa_status_t status = PSA_SUCCESS;
117
+ size_t out_length = 0 ;
118
+ size_t tag_out_length = 0 ;
119
119
120
120
status = psa_aead_encrypt_setup (&operation, key.As <psa_key_id_t >(), algorithm);
121
121
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
@@ -126,30 +126,71 @@ CHIP_ERROR AES_CCM_encrypt(const uint8_t * plaintext, size_t plaintext_length, c
126
126
status = psa_aead_set_nonce (&operation, nonce, nonce_length);
127
127
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
128
128
129
- if (aad_length != 0 )
129
+ if (0 == aad_length )
130
130
{
131
- status = psa_aead_update_ad (&operation, aad, aad_length);
132
- VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
131
+ ChipLogDetail (Crypto, " AES_CCM_encrypt: Using aad == null path" );
133
132
}
134
133
else
135
134
{
136
- ChipLogDetail (Crypto, " AES_CCM_encrypt: Using aad == null path" );
135
+ status = psa_aead_update_ad (&operation, aad, aad_length);
136
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
137
137
}
138
138
139
- if (plaintext_length != 0 )
139
+ if (0 == plaintext_length )
140
140
{
141
- status = psa_aead_update (&operation, plaintext, plaintext_length, ciphertext,
142
- PSA_AEAD_UPDATE_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm, plaintext_length), &out_length);
143
- VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
144
-
145
- ciphertext += out_length;
146
-
147
- status = psa_aead_finish (&operation, ciphertext, PSA_AEAD_FINISH_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm), &out_length, tag,
148
- tag_length, &tag_out_length);
141
+ // Empty plaintext
142
+ status = psa_aead_finish (&operation, nullptr , 0 , &out_length, tag, tag_length, &tag_out_length);
149
143
}
150
144
else
151
145
{
152
- status = psa_aead_finish (&operation, nullptr , 0 , &out_length, tag, tag_length, &tag_out_length);
146
+ // psa_aead_update() requires use of the macro PSA_AEAD_UPDATE_OUTPUT_SIZE to determine the output buffer size.
147
+ // For AES-CCM, PSA_AEAD_UPDATE_OUTPUT_SIZE will round up the size to the next multiple of the block size (16).
148
+ // If the ciphertext length is not a multiple of the block size, we will encrypt in two steps, first with the
149
+ // block_aligned_length, and then with a rounded up partial_block_length, where a temporary buffer will be used for the
150
+ // output.
151
+ constexpr uint8_t kBlockSize = PSA_BLOCK_CIPHER_BLOCK_LENGTH (PSA_KEY_TYPE_AES);
152
+ size_t block_aligned_length = (plaintext_length / kBlockSize ) * kBlockSize ;
153
+ size_t partial_block_length = plaintext_length % kBlockSize ;
154
+ size_t ciphertext_length = 0 ;
155
+ uint8_t temp[kBlockSize ] = { 0 };
156
+
157
+ // Make sure the calculated block_aligned_length is compliant with PSA's output size requirements.
158
+ VerifyOrReturnError (block_aligned_length == PSA_AEAD_UPDATE_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm, block_aligned_length),
159
+ CHIP_ERROR_INTERNAL);
160
+
161
+ // Add the aligned part of the plaintext
162
+ status = psa_aead_update (&operation, plaintext, block_aligned_length, ciphertext, block_aligned_length, &out_length);
163
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
164
+ VerifyOrReturnError (out_length == block_aligned_length, CHIP_ERROR_INTERNAL);
165
+ ciphertext_length += out_length;
166
+
167
+ if (partial_block_length > 0 )
168
+ {
169
+ // The update output should fit in the temp buffer
170
+ size_t max_output = PSA_AEAD_UPDATE_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm, partial_block_length);
171
+
172
+ // Add the non-aligned end of the plaintext
173
+ status =
174
+ psa_aead_update (&operation, &plaintext[block_aligned_length], partial_block_length, temp, max_output, &out_length);
175
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
176
+ VerifyOrReturnError (ciphertext_length + out_length <= plaintext_length, CHIP_ERROR_INTERNAL);
177
+ // Add the encrypted output, if any
178
+ memcpy (&ciphertext[ciphertext_length], temp, out_length);
179
+ ciphertext_length += out_length;
180
+ }
181
+
182
+ // The finish output should fit in the temp buffer
183
+ size_t max_finish = PSA_AEAD_FINISH_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm);
184
+ VerifyOrReturnError (max_finish <= sizeof (temp), CHIP_ERROR_BUFFER_TOO_SMALL);
185
+
186
+ // The finish may return the last part of the ciphertext
187
+ status = psa_aead_finish (&operation, temp, max_finish, &out_length, tag, tag_length, &tag_out_length);
188
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
189
+ VerifyOrReturnError (ciphertext_length + out_length <= plaintext_length, CHIP_ERROR_INTERNAL);
190
+ // Add the encrypted output, if any
191
+ memcpy (&ciphertext[ciphertext_length], temp, out_length);
192
+ ciphertext_length += out_length;
193
+ VerifyOrReturnError (ciphertext_length == plaintext_length, CHIP_ERROR_INTERNAL);
153
194
}
154
195
VerifyOrReturnError (status == PSA_SUCCESS && tag_length == tag_out_length, CHIP_ERROR_INTERNAL);
155
196
@@ -166,9 +207,9 @@ CHIP_ERROR AES_CCM_decrypt(const uint8_t * ciphertext, size_t ciphertext_length,
166
207
VerifyOrReturnError (aad != nullptr || aad_length == 0 , CHIP_ERROR_INVALID_ARGUMENT);
167
208
168
209
const psa_algorithm_t algorithm = PSA_ALG_AEAD_WITH_SHORTENED_TAG (PSA_ALG_CCM, tag_length);
169
- psa_status_t status = PSA_SUCCESS;
170
210
psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT;
171
- size_t outLength;
211
+ psa_status_t status = PSA_SUCCESS;
212
+ size_t out_length = 0 ;
172
213
173
214
status = psa_aead_decrypt_setup (&operation, key.As <psa_key_id_t >(), algorithm);
174
215
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
@@ -179,32 +220,71 @@ CHIP_ERROR AES_CCM_decrypt(const uint8_t * ciphertext, size_t ciphertext_length,
179
220
status = psa_aead_set_nonce (&operation, nonce, nonce_length);
180
221
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
181
222
182
- if (aad_length != 0 )
223
+ if (0 == aad_length )
183
224
{
184
- status = psa_aead_update_ad (&operation, aad, aad_length);
185
- VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
225
+ ChipLogDetail (Crypto, " AES_CCM_decrypt: Using aad == null path" );
186
226
}
187
227
else
188
228
{
189
- ChipLogDetail (Crypto, " AES_CCM_decrypt: Using aad == null path" );
229
+ status = psa_aead_update_ad (&operation, aad, aad_length);
230
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
190
231
}
191
232
192
- if (ciphertext_length != 0 )
233
+ if (0 == ciphertext_length )
193
234
{
194
- status = psa_aead_update (&operation, ciphertext, ciphertext_length, plaintext,
195
- PSA_AEAD_UPDATE_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm, ciphertext_length), &outLength);
196
- VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
197
-
198
- plaintext += outLength;
199
-
200
- status = psa_aead_verify (&operation, plaintext, PSA_AEAD_VERIFY_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm), &outLength, tag,
201
- tag_length);
235
+ status = psa_aead_verify (&operation, nullptr , 0 , &out_length, tag, tag_length);
202
236
}
203
237
else
204
238
{
205
- status = psa_aead_verify (&operation, nullptr , 0 , &outLength, tag, tag_length);
206
- }
239
+ // psa_aead_update() requires use of the macro PSA_AEAD_UPDATE_OUTPUT_SIZE to determine the output buffer size.
240
+ // For AES-CCM, PSA_AEAD_UPDATE_OUTPUT_SIZE will round up the size to the next multiple of the block size (16).
241
+ // If the plaintext length is not a multiple of the block size, we will encrypt in two steps, first with the
242
+ // block_aligned_length, and then with a rounded up partial_block_length, where a temporary buffer will be used for the
243
+ // output.
244
+ constexpr uint8_t kBlockSize = PSA_BLOCK_CIPHER_BLOCK_LENGTH (PSA_KEY_TYPE_AES);
245
+ size_t block_aligned_length = (ciphertext_length / kBlockSize ) * kBlockSize ;
246
+ size_t partial_block_length = ciphertext_length % kBlockSize ;
247
+ size_t plaintext_length = 0 ;
248
+ uint8_t temp[kBlockSize ] = { 0 };
249
+
250
+ // Make sure the calculated block_aligned_length is compliant with PSA's output size requirements.
251
+ VerifyOrReturnError (block_aligned_length == PSA_AEAD_UPDATE_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm, block_aligned_length),
252
+ CHIP_ERROR_INTERNAL);
253
+
254
+ // Add the aligned part of the ciphertext
255
+ status = psa_aead_update (&operation, ciphertext, block_aligned_length, plaintext, block_aligned_length, &out_length);
256
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
257
+ VerifyOrReturnError (out_length == block_aligned_length, CHIP_ERROR_INTERNAL);
258
+ plaintext_length += out_length;
259
+
260
+ if (partial_block_length > 0 )
261
+ {
262
+ // The update output should fit in the temp buffer
263
+ size_t max_output = PSA_AEAD_UPDATE_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm, partial_block_length);
264
+
265
+ // Add the non-aligned end of the ciphertext
266
+ status =
267
+ psa_aead_update (&operation, &ciphertext[block_aligned_length], partial_block_length, temp, max_output, &out_length);
268
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
269
+ VerifyOrReturnError (plaintext_length + out_length <= ciphertext_length, CHIP_ERROR_INTERNAL);
270
+ // Add the decrypted output, if any
271
+ memcpy (&plaintext[plaintext_length], temp, out_length);
272
+ plaintext_length += out_length;
273
+ }
274
+
275
+ // The finish output should fit in the temp buffer
276
+ size_t max_verify = PSA_AEAD_VERIFY_OUTPUT_SIZE (PSA_KEY_TYPE_AES, algorithm);
277
+ VerifyOrReturnError (max_verify <= sizeof (temp), CHIP_ERROR_BUFFER_TOO_SMALL);
207
278
279
+ // Complete verification
280
+ status = psa_aead_verify (&operation, temp, max_verify, &out_length, tag, tag_length);
281
+ VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
282
+ VerifyOrReturnError (plaintext_length + out_length <= ciphertext_length, CHIP_ERROR_INTERNAL);
283
+ // Add the decrypted output, if any
284
+ memcpy (&plaintext[plaintext_length], temp, out_length);
285
+ plaintext_length += out_length;
286
+ VerifyOrReturnError (ciphertext_length == plaintext_length, CHIP_ERROR_INTERNAL);
287
+ }
208
288
VerifyOrReturnError (status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
209
289
210
290
return CHIP_NO_ERROR;
0 commit comments