
Commit 79fbb29

jiayulu authored and facebook-github-bot committed
torchrec changes v2 (#3274)
Summary:
Pull Request resolved: #3274

add utility function to facilitate embedding quantization

Reviewed By: iamzainhuda, liangbeixu

Differential Revision: D80051435

fbshipit-source-id: 3a28da01b1a48f859e6fdc33117055e301de4893
1 parent 3795566 commit 79fbb29

File tree

6 files changed (+338, -10 lines)


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 34 additions & 8 deletions
@@ -1104,10 +1104,22 @@ def init_parameters(self) -> None:
             self.split_embedding_weights(),
         ):
             assert param.shape == (rows, emb_dim)  # pyre-ignore[16]
-            param.data.uniform_(  # pyre-ignore[16]
-                weight_init_min,
-                weight_init_max,
-            )
+            if param.data.dtype in [  # pyre-ignore[16]
+                torch.float8_e4m3fn,
+                torch.float8_e5m2,
+            ]:
+                tmp_param = torch.zeros(
+                    param.shape, device=param.device  # pyre-ignore[16]
+                )
+                tmp_param.uniform_(weight_init_min, weight_init_max).to(
+                    param.data.dtype
+                )
+                param.data.copy_(tmp_param)
+            else:
+                param.data.uniform_(
+                    weight_init_min,
+                    weight_init_max,
+                )
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         return self.emb_module(
@@ -1914,10 +1926,22 @@ def init_parameters(self) -> None:
             self.split_embedding_weights(),
         ):
             assert param.shape == (rows, emb_dim)  # pyre-ignore[16]
-            param.data.uniform_(  # pyre-ignore[16]
-                weight_init_min,
-                weight_init_max,
-            )
+            if param.data.dtype in [  # pyre-ignore[16]
+                torch.float8_e4m3fn,
+                torch.float8_e5m2,
+            ]:
+                tmp_param = torch.zeros(
+                    param.shape, device=param.device  # pyre-ignore[16]
+                )
+                tmp_param.uniform_(weight_init_min, weight_init_max).to(
+                    param.data.dtype
+                )
+                param.data.copy_(tmp_param)
+            else:
+                param.data.uniform_(
+                    weight_init_min,
+                    weight_init_max,
+                )
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         weights = features.weights_or_none()
@@ -2552,6 +2576,8 @@ def __init__(
         fused_params = config.fused_params or {}
         if "cache_precision" not in fused_params:
             fused_params["cache_precision"] = weights_precision
+            if weights_precision == SparseType.NFP8:
+                fused_params["cache_precision"] = SparseType.FP16
         self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
            SplitTableBatchedEmbeddingBagsCodegen(
                embedding_specs=list(
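Taken together, these hunks handle two float8 details: init_parameters cannot call uniform_ directly on float8 weights (the in-place random initializers are not supported for torch.float8_e4m3fn / torch.float8_e5m2), so values are drawn in float32 and copied into the parameter, and the default cache_precision is kept at FP16 when the weight precision is NFP8 rather than inheriting the float8 precision. A minimal, self-contained sketch of the initialization pattern; the shape and init bounds below are illustrative, not taken from the diff:

import torch

# Sketch of the float8 initialization branch added above: draw values in
# float32, then copy them into the float8 parameter; copy_ performs the cast.
weight_init_min, weight_init_max = -0.01, 0.01  # illustrative bounds
param = torch.zeros(4, 8, dtype=torch.float8_e4m3fn)

if param.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    tmp_param = torch.zeros(param.shape, device=param.device)  # float32 scratch
    tmp_param.uniform_(weight_init_min, weight_init_max)
    param.copy_(tmp_param)  # float32 -> float8 cast happens during the copy
else:
    param.uniform_(weight_init_min, weight_init_max)

print(param.dtype)  # torch.float8_e4m3fn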

torchrec/distributed/embeddingbag.py

Lines changed: 7 additions & 2 deletions
@@ -1306,8 +1306,13 @@ def reset_parameters(self) -> None:
                 continue
             assert table_config.init_fn is not None
             param = self.embedding_bags[f"{table_config.name}"].weight
-            # pyre-ignore
-            table_config.init_fn(param)
+            if param.data.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                tmp_param = torch.zeros(param.shape, device=param.device)  # pyre-ignore
+                table_config.init_fn(tmp_param).to(param.data.dtype)  # pyre-ignore
+                param.data.copy_(tmp_param)  # pyre-ignore
+            else:
+                # pyre-ignore
+                table_config.init_fn(param)
 
             sharding_type = self.module_sharding_plan[table_config.name].sharding_type
             if sharding_type == ShardingType.DATA_PARALLEL.value:
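reset_parameters gets the same treatment for a table's init_fn: with float8 weights the initializer runs on a float32 scratch tensor, which is then copied into the parameter, with copy_ performing the dtype conversion. A hedged, standalone sketch of that contract; apply_init_fn is an illustrative helper name, not a torchrec API:

from typing import Callable, Optional

import torch


def apply_init_fn(
    param: torch.Tensor,
    init_fn: Callable[[torch.Tensor], Optional[torch.Tensor]],
) -> None:
    # Mirrors the branch above: float8 tensors cannot be initialized in place,
    # so init_fn fills a float32 scratch tensor that is copied (and cast) back.
    if param.data.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        tmp_param = torch.zeros(param.shape, device=param.device)
        init_fn(tmp_param)
        param.data.copy_(tmp_param)
    else:
        init_fn(param)


weight = torch.zeros(16, 8, dtype=torch.float8_e4m3fn)
apply_init_fn(weight, lambda t: torch.nn.init.uniform_(t, -0.05, 0.05))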

torchrec/distributed/tests/test_utils.py

Lines changed: 227 additions & 0 deletions
@@ -13,9 +13,11 @@
 import random
 import unittest
 from typing import cast, List, Optional, Tuple
+from unittest.mock import Mock, patch
 
 import torch
 import torch.distributed as dist
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import SparseType
 from hypothesis import given, settings, strategies as st, Verbosity
 from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -33,6 +35,7 @@
     ShardMetadata,
 )
 from torchrec.distributed.utils import (
+    _quantize_embedding_modules,
     add_params_from_parameter_sharding,
     convert_to_fbgemm_types,
     get_bucket_metadata_from_shard_metadata,
@@ -79,6 +82,7 @@ def test_get_unsharded_module_names(self) -> None:
             dense_device=device,
             sparse_device=device,
         )
+
         dmp = DistributedModelParallel(
             module=m,
             init_data_parallel=False,
@@ -95,6 +99,229 @@ def test_get_unsharded_module_names(self) -> None:
         dist.destroy_process_group()
 
 
+class QuantizeEmbeddingModulesTest(unittest.TestCase):
+    def test_quantize_embedding_modules(self) -> None:
+        """Test that _quantize_embedding_modules correctly converts embedding weight tensors."""
+        # Create a mock embedding module that mimics SplitTableBatchedEmbeddingBagsCodegen
+        mock_emb = Mock()
+
+        # Create mock tensors that support the operations we need
+        mock_weights_dev = Mock()
+        mock_weights_dev.dtype = torch.float32
+        mock_weights_dev.to.return_value = Mock()
+        mock_weights_dev.to.return_value.dtype = torch.float16
+        storage_mock_dev = Mock()
+        storage_mock_dev.resize_ = Mock()
+        mock_weights_dev.untyped_storage.return_value = storage_mock_dev
+
+        mock_weights_host = Mock()
+        mock_weights_host.dtype = torch.float32
+        mock_weights_host.to.return_value = Mock()
+        mock_weights_host.to.return_value.dtype = torch.float16
+        storage_mock_host = Mock()
+        storage_mock_host.resize_ = Mock()
+        mock_weights_host.untyped_storage.return_value = storage_mock_host
+
+        mock_weights_uvm = Mock()
+        mock_weights_uvm.dtype = torch.float32
+        mock_weights_uvm.to.return_value = Mock()
+        mock_weights_uvm.to.return_value.dtype = torch.float16
+        storage_mock_uvm = Mock()
+        storage_mock_uvm.resize_ = Mock()
+        mock_weights_uvm.untyped_storage.return_value = storage_mock_uvm
+
+        mock_emb.weights_dev = mock_weights_dev
+        mock_emb.weights_host = mock_weights_host
+        mock_emb.weights_uvm = mock_weights_uvm
+        mock_emb.weights_precision = SparseType.FP32
+
+        # Create a module that contains the mock embedding
+        module = torch.nn.Module()
+
+        # Mock the _group_sharded_modules function to return our mock embedding
+        with patch(
+            "torchrec.distributed.utils._group_sharded_modules"
+        ) as mock_group_sharded:
+            mock_group_sharded.return_value = [mock_emb]
+
+            # Mock the data_type_to_sparse_type function
+            with patch(
+                "torchrec.distributed.utils.data_type_to_sparse_type"
+            ) as mock_convert:
+                mock_sparse_type = Mock()
+                mock_sparse_type.as_dtype.return_value = torch.float16
+                mock_convert.return_value = mock_sparse_type
+
+                # Mock the logger
+                with patch("torchrec.distributed.utils.logger") as mock_logger:
+                    # Call the function with FP16 data type
+                    _quantize_embedding_modules(module, DataType.FP16)
+
+                    # Verify that _group_sharded_modules was called with the module
+                    mock_group_sharded.assert_called_once_with(module)
+
+                    # Verify that data_type_to_sparse_type was called with FP16
+                    mock_convert.assert_called_once_with(DataType.FP16)
+
+                    # Verify that logger.info was called with the expected message
+                    mock_logger.info.assert_called_once_with(
+                        f"convert embedding modules to converted_dtype={DataType.FP16.value} quantization"
+                    )
+
+                    # Verify that .to() was called on each tensor with the correct dtype
+                    mock_weights_dev.to.assert_called_once_with(torch.float16)
+                    mock_weights_host.to.assert_called_once_with(torch.float16)
+                    mock_weights_uvm.to.assert_called_once_with(torch.float16)
+
+                    # Verify that the storage resize was called for each tensor
+                    storage_mock_dev.resize_.assert_called_once_with(0)
+                    storage_mock_host.resize_.assert_called_once_with(0)
+                    storage_mock_uvm.resize_.assert_called_once_with(0)
+
+                    # Verify that weights_precision is correctly set to the converted sparse type
+                    self.assertEqual(mock_emb.weights_precision, mock_sparse_type)
+
+    def test_quantize_embedding_modules_no_sharded_modules(self) -> None:
+        """Test that _quantize_embedding_modules handles modules with no sharded embeddings."""
+        # Create a module with no sharded embeddings
+        module = torch.nn.Module()
+
+        # Mock the _group_sharded_modules function to return empty list
+        with patch(
+            "torchrec.distributed.utils._group_sharded_modules"
+        ) as mock_group_sharded:
+            mock_group_sharded.return_value = []
+
+            # Mock the data_type_to_sparse_type function
+            with patch(
+                "torchrec.distributed.utils.data_type_to_sparse_type"
+            ) as mock_convert:
+                mock_sparse_type = Mock()
+                mock_convert.return_value = mock_sparse_type
+
+                # Mock the logger
+                with patch("torchrec.distributed.utils.logger") as mock_logger:
+                    # Call the function - should not raise any errors
+                    _quantize_embedding_modules(module, DataType.FP16)
+
+                    # Verify that _group_sharded_modules was called
+                    mock_group_sharded.assert_called_once_with(module)
+
+                    # Verify that data_type_to_sparse_type was called
+                    mock_convert.assert_called_once_with(DataType.FP16)
+
+                    # Verify that logger.info was called
+                    mock_logger.info.assert_called_once()
+
+    def test_quantize_embedding_modules_multiple_embeddings(self) -> None:
+        """Test that _quantize_embedding_modules handles multiple embedding modules."""
+        # Create multiple mock embedding modules
+        mock_emb1 = Mock()
+        mock_emb2 = Mock()
+
+        # Create fully mocked tensors for first embedding
+        mock_weights_dev1 = Mock()
+        mock_weights_dev1.dtype = torch.float32
+        mock_weights_dev1.to.return_value = Mock()
+        mock_weights_dev1.to.return_value.dtype = torch.int8
+        storage_mock_dev1 = Mock()
+        storage_mock_dev1.resize_ = Mock()
+        mock_weights_dev1.untyped_storage.return_value = storage_mock_dev1
+
+        mock_weights_host1 = Mock()
+        mock_weights_host1.dtype = torch.float32
+        mock_weights_host1.to.return_value = Mock()
+        mock_weights_host1.to.return_value.dtype = torch.int8
+        storage_mock_host1 = Mock()
+        storage_mock_host1.resize_ = Mock()
+        mock_weights_host1.untyped_storage.return_value = storage_mock_host1
+
+        mock_weights_uvm1 = Mock()
+        mock_weights_uvm1.dtype = torch.float32
+        mock_weights_uvm1.to.return_value = Mock()
+        mock_weights_uvm1.to.return_value.dtype = torch.int8
+        storage_mock_uvm1 = Mock()
+        storage_mock_uvm1.resize_ = Mock()
+        mock_weights_uvm1.untyped_storage.return_value = storage_mock_uvm1
+
+        mock_emb1.weights_dev = mock_weights_dev1
+        mock_emb1.weights_host = mock_weights_host1
+        mock_emb1.weights_uvm = mock_weights_uvm1
+        mock_emb1.weights_precision = SparseType.FP32
+
+        # Create fully mocked tensors for second embedding
+        mock_weights_dev2 = Mock()
+        mock_weights_dev2.dtype = torch.float32
+        mock_weights_dev2.to.return_value = Mock()
+        mock_weights_dev2.to.return_value.dtype = torch.int8
+        storage_mock_dev2 = Mock()
+        storage_mock_dev2.resize_ = Mock()
+        mock_weights_dev2.untyped_storage.return_value = storage_mock_dev2
+
+        mock_weights_host2 = Mock()
+        mock_weights_host2.dtype = torch.float32
+        mock_weights_host2.to.return_value = Mock()
+        mock_weights_host2.to.return_value.dtype = torch.int8
+        storage_mock_host2 = Mock()
+        storage_mock_host2.resize_ = Mock()
+        mock_weights_host2.untyped_storage.return_value = storage_mock_host2
+
+        mock_weights_uvm2 = Mock()
+        mock_weights_uvm2.dtype = torch.float32
+        mock_weights_uvm2.to.return_value = Mock()
+        mock_weights_uvm2.to.return_value.dtype = torch.int8
+        storage_mock_uvm2 = Mock()
+        storage_mock_uvm2.resize_ = Mock()
+        mock_weights_uvm2.untyped_storage.return_value = storage_mock_uvm2
+
+        mock_emb2.weights_dev = mock_weights_dev2
+        mock_emb2.weights_host = mock_weights_host2
+        mock_emb2.weights_uvm = mock_weights_uvm2
+        mock_emb2.weights_precision = SparseType.FP32
+
+        # Create a module
+        module = torch.nn.Module()
+
+        # Mock the _group_sharded_modules function to return both mock embeddings
+        with patch(
+            "torchrec.distributed.utils._group_sharded_modules"
+        ) as mock_group_sharded:
+            mock_group_sharded.return_value = [mock_emb1, mock_emb2]
+
+            # Mock the data_type_to_sparse_type function
+            with patch(
+                "torchrec.distributed.utils.data_type_to_sparse_type"
+            ) as mock_convert:
+                mock_sparse_type = Mock()
+                mock_sparse_type.as_dtype.return_value = torch.int8
+                mock_convert.return_value = mock_sparse_type
+
+                # Call the function
+                _quantize_embedding_modules(module, DataType.INT8)
+
+                # Verify that .to() was called on each tensor with the correct dtype
+                mock_weights_dev1.to.assert_called_once_with(torch.int8)
+                mock_weights_host1.to.assert_called_once_with(torch.int8)
+                mock_weights_uvm1.to.assert_called_once_with(torch.int8)
+
+                mock_weights_dev2.to.assert_called_once_with(torch.int8)
+                mock_weights_host2.to.assert_called_once_with(torch.int8)
+                mock_weights_uvm2.to.assert_called_once_with(torch.int8)
+
+                # Verify that the storage resize was called for each tensor
+                storage_mock_dev1.resize_.assert_called_once_with(0)
+                storage_mock_host1.resize_.assert_called_once_with(0)
+                storage_mock_uvm1.resize_.assert_called_once_with(0)
+
+                storage_mock_dev2.resize_.assert_called_once_with(0)
+                storage_mock_host2.resize_.assert_called_once_with(0)
+                storage_mock_uvm2.resize_.assert_called_once_with(0)
+
+                # Verify that weights_precision is correctly set to the converted sparse type
+                self.assertEqual(mock_emb1.weights_precision, mock_sparse_type)
+                self.assertEqual(mock_emb2.weights_precision, mock_sparse_type)
+
+
 def _compute_translated_lengths(
     row_indices: List[int],
     indices_offsets: List[int],
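The new tests drive _quantize_embedding_modules entirely through mocks, so the behavior they pin down can be summarized in a small standalone sketch. Everything below (the _convert_weight_buffers helper and the _FakeEmb class) is illustrative naming, not torchrec API; the real utility in torchrec/distributed/utils.py additionally collects sharded embedding modules via _group_sharded_modules and maps the requested DataType to an fbgemm SparseType, both of which the tests patch out. What the assertions fix is: each of weights_dev, weights_host, and weights_uvm is cast with .to(), its old untyped storage is freed with resize_(0), and weights_precision is updated to the converted sparse type.

import torch


def _convert_weight_buffers(emb_modules, target_dtype, target_precision):
    # Hypothetical core of the quantization utility, reconstructed only from
    # the test assertions above.
    for emb in emb_modules:
        for name in ("weights_dev", "weights_host", "weights_uvm"):
            old = getattr(emb, name)
            new = old.to(target_dtype)          # cast buffer to the quantized dtype
            old.untyped_storage().resize_(0)    # release the original storage
            setattr(emb, name, new)
        emb.weights_precision = target_precision  # e.g. an fbgemm SparseType


class _FakeEmb:
    """Stand-in for a SplitTableBatchedEmbeddingBagsCodegen-like module."""

    def __init__(self) -> None:
        self.weights_dev = torch.randn(8, 4)
        self.weights_host = torch.empty(0)
        self.weights_uvm = torch.empty(0)
        self.weights_precision = "FP32"


emb = _FakeEmb()
_convert_weight_buffers([emb], torch.float16, "FP16")
print(emb.weights_dev.dtype)  # torch.float16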
