
[PyTorch] Refactor activation offloading of quantized tensors. #1738

Open: wants to merge 17 commits into main

16 changes: 12 additions & 4 deletions tests/pytorch/test_cpu_offloading.py
@@ -14,18 +14,22 @@
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_available, reason_for_no_fp8_block = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)

fp8_recipes = [
None, # non-fp8
# recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet
recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
recipe.DelayedScaling(),
recipe.Float8BlockScaling(),
]

SIZE = 512
SIZE = 64
NUM_HEADS = 8
NUM_LAYERS = 5
EPSILON = 0.1
EPSILON = 0.05

# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
@@ -94,6 +98,9 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload
tensor = model(tensor)
tensor = sync_function(tensor)

import gc

gc.collect()
max_mem_used = torch.cuda.memory_allocated() / (1024**2)
torch.cuda.synchronize()

@@ -115,7 +122,6 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
the difference being the size of the FP8 cache that is not offloaded to the CPU.
We also expect this memory consumption to be smaller than in scenario (1).
"""

model_cls = model_types[model_key]
models_list = [model_cls() for _ in range(NUM_LAYERS)]

@@ -124,6 +130,8 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
if fp8_recipe is not None:
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_available:
pytest.skip(reason_for_no_fp8_block)

without_offloading = _measure_memory_between_forward_and_backward(
models_list, fp8_recipe, False
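For reference, the pattern this test measures is roughly the sketch below: each layer's forward runs inside the offload context, and the returned synchronization function is called between layers so the current group of activations can be committed for device-to-host copies. The import path and keyword arguments are assumptions for illustration only; the test relies on get_cpu_offload_context returning a (context, sync_function) pair.

import torch
import transformer_engine.pytorch as te

num_layers = 5
layers = [te.Linear(64, 64).cuda() for _ in range(num_layers)]

# Assumed keyword names: offload activations of all but the last layer.
offload_ctx, sync_function = te.get_cpu_offload_context(
    enabled=True,
    num_layers=num_layers - 1,
    model_layers=num_layers,
)

x = torch.randn(64, 64, device="cuda", requires_grad=True)
for layer in layers:
    with offload_ctx:
        x = layer(x)
    # Commit this layer's offload group; tensors marked for offloading are
    # copied to pinned CPU memory on a side stream and reloaded in backward.
    x = sync_function(x)
x.sum().backward()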
125 changes: 15 additions & 110 deletions transformer_engine/pytorch/cpu_offload.py
@@ -9,10 +9,6 @@

import torch

from .tensor.quantized_tensor import QuantizedTensorBase

from .tensor.float8_tensor import Float8Tensor

__all__ = ["get_cpu_offload_context"]

CPUOffloadEnabled = False
@@ -23,17 +19,13 @@ def mark_activation_offload(*tensors):
for tensor in tensors:
if tensor is None:
continue
if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
if isinstance(tensor, torch.Tensor):
tensor.activation_offloading = True
else:
data_tensors = tensor.get_data_tensors()
data_tensors = tensor.get_data_tensors(scaling_factors=True)
for tensor in data_tensors:
if tensor is not None:
tensor.activation_offloading = True
# This is a hack to force clear the tensor after it is offloaded.
# It is needed, because .*TensorBase classes are saved in the ctx,
# and they contain the reference to their data tensors.
tensor.needs_force_clear = True


def is_cpu_offload_enabled() -> bool:
@@ -240,14 +232,11 @@ def on_group_commit_backward(self):
def offload(src_tensor, pin_memory=True):
"""Offload."""

cpu_backup = torch.empty(
src_tensor.size(),
dtype=src_tensor.dtype,
layout=src_tensor.layout,
cpu_backup = torch.empty_like(
src_tensor,
device="cpu",
pin_memory=pin_memory,
)

cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup)
return state
@@ -311,9 +300,6 @@ def __init__(
self.num_layers = num_model_group
# Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {}
# Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {}
# Tracking the number of layers offloaded
self.offloaded_group_count = 0
# Core data structure that decides the window for offloading
@@ -344,46 +330,19 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
),
)

is_quantized_tensor = isinstance(tensor, QuantizedTensorBase)

if not torch_stray_tensor:

# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1

assert tensor_tag not in self.tensor_tag_to_state

if is_quantized_tensor:
tensor_list, _ = tensor.prepare_for_saving()

self.tensor_tag_to_state[tensor_tag] = []
self.tensor_tag_to_buf[tensor_tag] = []
self.tensor_tag_to_state[tensor_tag] = tensor

self.fp8_tensor_object_map[tensor_tag] = tensor
if isinstance(tensor, Float8Tensor):
self.float8_transpose_cache_valid[tensor_tag] = getattr(
tensor, "_transpose_invalid"
)
else:
tensor_list = [tensor]

for t in tensor_list:
if is_quantized_tensor:
self.tensor_tag_to_state[tensor_tag].append(t)
else:
self.tensor_tag_to_state[tensor_tag] = t

if (
self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(t)
):
if is_quantized_tensor:
self.tensor_tag_to_buf[tensor_tag].append(t)
# Need to clear the internal data reference for the quantized tensors
tensor.clear()
else:
self.tensor_tag_to_buf[tensor_tag] = t
if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
tensor
):
self.tensor_tag_to_buf[tensor_tag] = tensor
else:
tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1
@@ -395,12 +354,6 @@ def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag)

# Handling the quantized tensor case specially here
if isinstance(tensor, list):
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
tensor = self.fp8_tensor_object_map.pop(tensor_tag)

self.tensor_tag_to_buf.pop(tensor_tag, None)

# the tensor should have been copied back in on_group_commit_backward()
@@ -416,36 +369,11 @@ def bulk_offload_group(self, group_to_offload):
if group_id == group_to_offload:
assert not isinstance(state, tuple)

is_quantized_tensor = isinstance(state, list)

if is_quantized_tensor:
tensor_list = state
self.tensor_tag_to_state[tensor_tag] = []
else:
tensor_list = [state]

for tensor_on_device in tensor_list:
# `tensor_offloaded` is a hacky way of dealing with columnwise-only
# quantized tensors for CPU offloading. The complication is due to
# the `rowwise_data` being `None`. The offloading checker incorrectly
# returns `False` and the entire `state` ([None, columnwise_tensor])
# is added to the tensor tag state dict. A better design would change
# how quantized tensors are kept track of in the offload handler.
# Currently at every stage it is ensured that a quantized tensor is a
# list whereas a non-quantized tensor is standalone object, which is
# not good! TODO(@sanandaraj5597)
tensor_offloaded = False
# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
tensor_offloaded = True
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
if is_quantized_tensor:
if tensor_offloaded:
self.tensor_tag_to_state[tensor_tag].append(state)
else:
self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
else:
self.tensor_tag_to_state[tensor_tag] = state
tensor_on_device = state
# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
self.tensor_tag_to_state[tensor_tag] = state

def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
@@ -465,14 +393,8 @@ def synchronize_on_group_commit_forward(self, current_group):
torch.cuda.current_stream().wait_stream(self.d2h_stream)

# Time to free the activation memory after usage
for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
for tensor_tag, _ in self.tensor_tag_to_buf.items():
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorBase class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
# Release the pointer to the tensor
self.tensor_tag_to_buf[tensor_tag] = None

@@ -502,23 +424,6 @@ def bulk_reload_group(self, group_to_reload):
if isinstance(state, tuple):
recovered_tensor = SynchronizedGroupOffloadHandler.reload(state)
self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if isinstance(state_tuple, tuple):
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(state_tuple)
)
else:
tensor_list.append(state_tuple)
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
self.float8_transpose_cache_valid.pop(tensor_label)
)
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
tensor_label
)

def on_group_commit_backward(self):
# first decrement the current group.
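The net effect of this refactor is that the offload handler stores each quantized activation as a single object instead of unpacking it into a list of component tensors tracked through fp8_tensor_object_map. Offloading then only needs empty_like and in-place copy to work on the quantized tensor itself (see the Float8BlockwiseQTensor changes below). A hedged sketch of the resulting data path follows; the reload helper is an assumption about the unchanged synchronous handler, not part of this diff.

import torch

def offload(src_tensor, pin_memory=True):
    # Works for plain torch.Tensors and, with this PR, for TE quantized
    # tensors as well: the CPU backup keeps the same structure and lives in
    # pinned host memory so the device-to-host copy can be asynchronous.
    cpu_backup = torch.empty_like(src_tensor, device="cpu", pin_memory=pin_memory)
    cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
    return (src_tensor.device, cpu_backup)

def reload(state, non_blocking=True):
    # Assumed counterpart used during backward: move the pinned CPU backup
    # back to its original device.
    device, cpu_backup = state
    return cpu_backup.to(device, non_blocking=non_blocking)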
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/optimizers/fused_adam.py
@@ -374,7 +374,7 @@ def _initialize_state(
if store_param_remainders:
data = torch.zeros_like(param, dtype=torch.int16)
else:
data = torch.empty_like(param, dtype=dtype)
data = torch.empty_like(param.detach().dequantize(), dtype=dtype)
if zero_buffer:
data.zero_()

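The fused_adam change avoids allocating optimizer state as a quantized tensor: the parameter may be a TE quantized tensor, and calling empty_like on its detached, dequantized view always yields a plain high-precision buffer of the requested dtype. A hedged illustration of the intent; the helper name and the isinstance guard are mine, not the PR's code.

import torch
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase

def alloc_state_like(param: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical helper mirroring the fused_adam change above.
    base = param.detach()
    if isinstance(base, QuantizedTensorBase):
        # Dequantizing yields a regular tensor of the same shape, so the
        # optimizer-state buffer is never a quantized subclass.
        base = base.dequantize()
    return torch.empty_like(base, dtype=dtype)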
@@ -112,8 +112,15 @@ def restore_from_saved(
self._columnwise_scale_inv = tensors[3]
return tensors[4:]

def get_data_tensors(self):
def get_data_tensors(self, scaling_factors=False):
"""Get this Tensor's data."""
if scaling_factors:
return (
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
)
return self._rowwise_data, self._columnwise_data

def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
@@ -128,8 +128,10 @@ def restore_from_saved(
self._scale_inv = tensors[2]
return tensors[3:]

def get_data_tensors(self):
def get_data_tensors(self, scaling_factors=False):
"""Get this Tensor's data."""
if scaling_factors:
return self._data, self._transpose, self._scale_inv
return self._data, self._transpose

def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
@@ -131,8 +131,15 @@ def restore_from_saved(
self._columnwise_scale_inv = tensors[3]
return tensors[4:]

def get_data_tensors(self):
def get_data_tensors(self, scaling_factors=False):
"""Get this Tensor's data."""
if scaling_factors:
return (
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
)
return self._rowwise_data, self._columnwise_data

def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
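The three get_data_tensors hunks above add the same optional keyword so the offload path can tag the scale-inverse tensors together with the payload; previously only the data tensors were tagged, which is why scale-inverse offloading (and the MXFP8 recipe in the test) was disabled. A small, hedged usage sketch, where qt stands for any of the quantized tensor classes touched above:

def tag_for_offload(qt) -> None:
    # Float8Tensor returns (data, transpose, scale_inv); the block-scaled and
    # MXFP8 classes return (rowwise_data, columnwise_data, rowwise_scale_inv,
    # columnwise_scale_inv). Components that are absent come back as None.
    for t in qt.get_data_tensors(scaling_factors=True):
        if t is not None:
            t.activation_offloading = True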
55 changes: 55 additions & 0 deletions transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -325,6 +325,42 @@ def clone(self) -> Float8BlockwiseQTensor:
},
)

def empty_like(self, *args, **kwargs):
"""Create a new empty tensor with the same shape and type as this tensor"""
new_rowwise_data = (
torch.empty_like(self._rowwise_data, *args, **kwargs)
if self._rowwise_data is not None
else None
)
new_columnwise_data = (
torch.empty_like(self._columnwise_data, *args, **kwargs)
if self._columnwise_data is not None
else None
)
new_rowwise_scale_inv = (
torch.empty_like(self._rowwise_scale_inv, *args, **kwargs)
if self._rowwise_scale_inv is not None
else None
)
new_columnwise_scale_inv = (
torch.empty_like(self._columnwise_scale_inv, *args, **kwargs)
if self._columnwise_scale_inv is not None
else None
)

return Float8BlockwiseQTensor(
shape=self.shape,
dtype=self.dtype,
fp8_dtype=self._fp8_dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=new_rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=new_columnwise_scale_inv,
quantizer=self._quantizer,
is_2D_scaled=self._is_2D_scaled,
requires_grad=self.requires_grad,
)

def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
return _ViewFunc.apply(self, shape)
@@ -357,6 +393,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
)
return Float8BlockwiseQTensor.make_like(tensor)

if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
if isinstance(src, Float8BlockwiseQTensor) and isinstance(dst, Float8BlockwiseQTensor):
if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data)
if dst._rowwise_scale_inv is not None:
dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv)
if dst._columnwise_data is not None:
dst._columnwise_data.copy_(src._columnwise_data)
if dst._columnwise_scale_inv is not None:
dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv)
return dst
elif func == torch.ops.aten.is_pinned.default:
if args[0]._rowwise_data is not None:
return args[0]._rowwise_data.is_pinned()
if args[0]._columnwise_data is not None:
return args[0]._columnwise_data.is_pinned()
raise RuntimeError("Cannot check if pinned for Float8BlockwiseQTensor with no data.")

# Default case
return super().__torch_dispatch__(func, types, args, kwargs)

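Together with the new empty_like method, the copy_ and is_pinned dispatch handlers are what let the simplified offload path treat a Float8BlockwiseQTensor like any other activation. A hedged sketch of how they compose; qt stands for an existing CUDA Float8BlockwiseQTensor, since building one requires a quantizer and is omitted here.

def offload_blockwise(qt, pin_memory=True):
    # empty_like() forwards its kwargs to torch.empty_like for every non-None
    # component, so device="cpu" and pin_memory apply to the payload and the
    # scale-inverse tensors alike.
    cpu_copy = qt.empty_like(device="cpu", pin_memory=pin_memory)
    # aten.copy_ is intercepted above and copies each non-None component.
    cpu_copy.copy_(qt)
    # aten.is_pinned answers based on whichever payload tensor is present.
    assert cpu_copy.is_pinned() == pin_memory
    return cpu_copy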