@@ -13,9 +13,11 @@
 import random
 import unittest
 from typing import cast, List, Optional, Tuple
+from unittest.mock import Mock, patch
 
 import torch
 import torch.distributed as dist
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import SparseType
 from hypothesis import given, settings, strategies as st, Verbosity
 from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -33,6 +35,7 @@
     ShardMetadata,
 )
 from torchrec.distributed.utils import (
+    _quantize_embedding_modules,
     add_params_from_parameter_sharding,
     convert_to_fbgemm_types,
     get_bucket_metadata_from_shard_metadata,
@@ -79,6 +82,7 @@ def test_get_unsharded_module_names(self) -> None:
             dense_device=device,
             sparse_device=device,
         )
+
         dmp = DistributedModelParallel(
             module=m,
             init_data_parallel=False,
@@ -95,6 +99,229 @@ def test_get_unsharded_module_names(self) -> None:
         dist.destroy_process_group()
 
 
+class QuantizeEmbeddingModulesTest(unittest.TestCase):
+    def test_quantize_embedding_modules(self) -> None:
+        """Test that _quantize_embedding_modules correctly converts embedding weight tensors."""
+        # Create a mock embedding module that mimics SplitTableBatchedEmbeddingBagsCodegen
+        mock_emb = Mock()
+
+        # Create mock tensors that support the operations we need
+        mock_weights_dev = Mock()
+        mock_weights_dev.dtype = torch.float32
+        mock_weights_dev.to.return_value = Mock()
+        mock_weights_dev.to.return_value.dtype = torch.float16
+        storage_mock_dev = Mock()
+        storage_mock_dev.resize_ = Mock()
+        mock_weights_dev.untyped_storage.return_value = storage_mock_dev
+
+        mock_weights_host = Mock()
+        mock_weights_host.dtype = torch.float32
+        mock_weights_host.to.return_value = Mock()
+        mock_weights_host.to.return_value.dtype = torch.float16
+        storage_mock_host = Mock()
+        storage_mock_host.resize_ = Mock()
+        mock_weights_host.untyped_storage.return_value = storage_mock_host
+
+        mock_weights_uvm = Mock()
+        mock_weights_uvm.dtype = torch.float32
+        mock_weights_uvm.to.return_value = Mock()
+        mock_weights_uvm.to.return_value.dtype = torch.float16
+        storage_mock_uvm = Mock()
+        storage_mock_uvm.resize_ = Mock()
+        mock_weights_uvm.untyped_storage.return_value = storage_mock_uvm
+
+        mock_emb.weights_dev = mock_weights_dev
+        mock_emb.weights_host = mock_weights_host
+        mock_emb.weights_uvm = mock_weights_uvm
+        mock_emb.weights_precision = SparseType.FP32
+
+        # Create a module that contains the mock embedding
+        module = torch.nn.Module()
+
+        # Mock the _group_sharded_modules function to return our mock embedding
+        with patch(
+            "torchrec.distributed.utils._group_sharded_modules"
+        ) as mock_group_sharded:
+            mock_group_sharded.return_value = [mock_emb]
+
+            # Mock the data_type_to_sparse_type function
+            with patch(
+                "torchrec.distributed.utils.data_type_to_sparse_type"
+            ) as mock_convert:
+                mock_sparse_type = Mock()
+                mock_sparse_type.as_dtype.return_value = torch.float16
+                mock_convert.return_value = mock_sparse_type
+
+                # Mock the logger
+                with patch("torchrec.distributed.utils.logger") as mock_logger:
+                    # Call the function with FP16 data type
+                    _quantize_embedding_modules(module, DataType.FP16)
+
+                    # Verify that _group_sharded_modules was called with the module
+                    mock_group_sharded.assert_called_once_with(module)
+
+                    # Verify that data_type_to_sparse_type was called with FP16
+                    mock_convert.assert_called_once_with(DataType.FP16)
+
+                    # Verify that logger.info was called with the expected message
+                    mock_logger.info.assert_called_once_with(
+                        f"convert embedding modules to converted_dtype={DataType.FP16.value} quantization"
+                    )
+
+                    # Verify that .to() was called on each tensor with the correct dtype
+                    mock_weights_dev.to.assert_called_once_with(torch.float16)
+                    mock_weights_host.to.assert_called_once_with(torch.float16)
+                    mock_weights_uvm.to.assert_called_once_with(torch.float16)
+
+                    # Verify that the storage resize was called for each tensor
+                    storage_mock_dev.resize_.assert_called_once_with(0)
+                    storage_mock_host.resize_.assert_called_once_with(0)
+                    storage_mock_uvm.resize_.assert_called_once_with(0)
+
+                    # Verify that weights_precision is correctly set to the converted sparse type
+                    self.assertEqual(mock_emb.weights_precision, mock_sparse_type)
+
+    def test_quantize_embedding_modules_no_sharded_modules(self) -> None:
+        """Test that _quantize_embedding_modules handles modules with no sharded embeddings."""
+        # Create a module with no sharded embeddings
+        module = torch.nn.Module()
+
+        # Mock the _group_sharded_modules function to return an empty list
+        with patch(
+            "torchrec.distributed.utils._group_sharded_modules"
+        ) as mock_group_sharded:
+            mock_group_sharded.return_value = []
+
+            # Mock the data_type_to_sparse_type function
+            with patch(
+                "torchrec.distributed.utils.data_type_to_sparse_type"
+            ) as mock_convert:
+                mock_sparse_type = Mock()
+                mock_convert.return_value = mock_sparse_type
+
+                # Mock the logger
+                with patch("torchrec.distributed.utils.logger") as mock_logger:
+                    # Call the function - should not raise any errors
+                    _quantize_embedding_modules(module, DataType.FP16)
+
+                    # Verify that _group_sharded_modules was called
+                    mock_group_sharded.assert_called_once_with(module)
+
+                    # Verify that data_type_to_sparse_type was called
+                    mock_convert.assert_called_once_with(DataType.FP16)
+
+                    # Verify that logger.info was called
+                    mock_logger.info.assert_called_once()
+
+    def test_quantize_embedding_modules_multiple_embeddings(self) -> None:
+        """Test that _quantize_embedding_modules handles multiple embedding modules."""
+        # Create multiple mock embedding modules
+        mock_emb1 = Mock()
+        mock_emb2 = Mock()
+
+        # Create fully mocked tensors for the first embedding
+        mock_weights_dev1 = Mock()
+        mock_weights_dev1.dtype = torch.float32
+        mock_weights_dev1.to.return_value = Mock()
+        mock_weights_dev1.to.return_value.dtype = torch.int8
+        storage_mock_dev1 = Mock()
+        storage_mock_dev1.resize_ = Mock()
+        mock_weights_dev1.untyped_storage.return_value = storage_mock_dev1
+
+        mock_weights_host1 = Mock()
+        mock_weights_host1.dtype = torch.float32
+        mock_weights_host1.to.return_value = Mock()
+        mock_weights_host1.to.return_value.dtype = torch.int8
+        storage_mock_host1 = Mock()
+        storage_mock_host1.resize_ = Mock()
+        mock_weights_host1.untyped_storage.return_value = storage_mock_host1
+
+        mock_weights_uvm1 = Mock()
+        mock_weights_uvm1.dtype = torch.float32
+        mock_weights_uvm1.to.return_value = Mock()
+        mock_weights_uvm1.to.return_value.dtype = torch.int8
+        storage_mock_uvm1 = Mock()
+        storage_mock_uvm1.resize_ = Mock()
+        mock_weights_uvm1.untyped_storage.return_value = storage_mock_uvm1
+
+        mock_emb1.weights_dev = mock_weights_dev1
+        mock_emb1.weights_host = mock_weights_host1
+        mock_emb1.weights_uvm = mock_weights_uvm1
+        mock_emb1.weights_precision = SparseType.FP32
+
+        # Create fully mocked tensors for the second embedding
+        mock_weights_dev2 = Mock()
+        mock_weights_dev2.dtype = torch.float32
+        mock_weights_dev2.to.return_value = Mock()
+        mock_weights_dev2.to.return_value.dtype = torch.int8
+        storage_mock_dev2 = Mock()
+        storage_mock_dev2.resize_ = Mock()
+        mock_weights_dev2.untyped_storage.return_value = storage_mock_dev2
+
+        mock_weights_host2 = Mock()
+        mock_weights_host2.dtype = torch.float32
+        mock_weights_host2.to.return_value = Mock()
+        mock_weights_host2.to.return_value.dtype = torch.int8
+        storage_mock_host2 = Mock()
+        storage_mock_host2.resize_ = Mock()
+        mock_weights_host2.untyped_storage.return_value = storage_mock_host2
+
+        mock_weights_uvm2 = Mock()
+        mock_weights_uvm2.dtype = torch.float32
+        mock_weights_uvm2.to.return_value = Mock()
+        mock_weights_uvm2.to.return_value.dtype = torch.int8
+        storage_mock_uvm2 = Mock()
+        storage_mock_uvm2.resize_ = Mock()
+        mock_weights_uvm2.untyped_storage.return_value = storage_mock_uvm2
+
+        mock_emb2.weights_dev = mock_weights_dev2
+        mock_emb2.weights_host = mock_weights_host2
+        mock_emb2.weights_uvm = mock_weights_uvm2
+        mock_emb2.weights_precision = SparseType.FP32
+
+        # Create a module
+        module = torch.nn.Module()
+
+        # Mock the _group_sharded_modules function to return both mock embeddings
+        with patch(
+            "torchrec.distributed.utils._group_sharded_modules"
+        ) as mock_group_sharded:
+            mock_group_sharded.return_value = [mock_emb1, mock_emb2]
+
+            # Mock the data_type_to_sparse_type function
+            with patch(
+                "torchrec.distributed.utils.data_type_to_sparse_type"
+            ) as mock_convert:
+                mock_sparse_type = Mock()
+                mock_sparse_type.as_dtype.return_value = torch.int8
+                mock_convert.return_value = mock_sparse_type
+
+                # Call the function
+                _quantize_embedding_modules(module, DataType.INT8)
+
+                # Verify that .to() was called on each tensor with the correct dtype
+                mock_weights_dev1.to.assert_called_once_with(torch.int8)
+                mock_weights_host1.to.assert_called_once_with(torch.int8)
+                mock_weights_uvm1.to.assert_called_once_with(torch.int8)
+
+                mock_weights_dev2.to.assert_called_once_with(torch.int8)
+                mock_weights_host2.to.assert_called_once_with(torch.int8)
+                mock_weights_uvm2.to.assert_called_once_with(torch.int8)
+
+                # Verify that the storage resize was called for each tensor
+                storage_mock_dev1.resize_.assert_called_once_with(0)
+                storage_mock_host1.resize_.assert_called_once_with(0)
+                storage_mock_uvm1.resize_.assert_called_once_with(0)
+
+                storage_mock_dev2.resize_.assert_called_once_with(0)
+                storage_mock_host2.resize_.assert_called_once_with(0)
+                storage_mock_uvm2.resize_.assert_called_once_with(0)
+
+                # Verify that weights_precision is correctly set to the converted sparse type
+                self.assertEqual(mock_emb1.weights_precision, mock_sparse_type)
+                self.assertEqual(mock_emb2.weights_precision, mock_sparse_type)
+
+
 def _compute_translated_lengths(
     row_indices: List[int],
     indices_offsets: List[int],