Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type Checking #141

Merged
merged 19 commits into from
Aug 28, 2024
12 changes: 11 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
# Copyright 2024 MosaicML MegaBlocks authors
# Copyright 2024 Databricks authors
# SPDX-License-Identifier: Apache-2.0

default_language_version:
python: python3
repos:
# - repo: local
# hooks:
# - id: pyright
# name: pyright
# entry: pyright
# language: node
# types: [python]
# pass_filenames: false
# args: [--warnings]
# additional_dependencies: ["pyright@1.1.310"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
hooks:
Expand Down
153 changes: 106 additions & 47 deletions megablocks/backend/kernels.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional

import torch
import triton
import triton.language as tl


def assert_is_tensor(x: torch.Tensor, ndim: int) -> None:
    """Check that ``x`` has exactly ``ndim`` dimensions.

    Args:
        x: Tensor whose number of dimensions is validated.
        ndim: Required number of dimensions.

    Raises:
        ValueError: If ``x.ndim`` differs from ``ndim``.
    """
    if x.ndim != ndim:
        raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


def assert_is_matrix(x: torch.Tensor) -> None:
    """Check that ``x`` is two-dimensional.

    Delegates to ``assert_is_tensor`` with a required rank of 2, so any
    error raised by that helper propagates to the caller.
    """
    assert_is_tensor(x, 2)


def assert_is_vector(x: torch.Tensor) -> None:
    """Check that ``x`` is one-dimensional.

    Args:
        x: Tensor whose number of dimensions is validated.

    Raises:
        ValueError: If ``x.ndim`` is not 1.
    """
    if x.ndim != 1:
        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


def assert_equal(a: Any, b: Any) -> None:
    """Check that ``a`` and ``b`` compare equal.

    Args:
        a: First value to compare.
        b: Second value to compare.

    Raises:
        ValueError: If ``a != b`` evaluates truthy.
    """
    if a != b:
        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.')

Expand All @@ -43,13 +44,13 @@ def assert_equal(a, b):
)
@triton.jit
def _padded_copy(
a,
b,
indices,
bin_ids,
weights,
bins,
padded_bins,
a: torch.Tensor,
b: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Any,
bins: torch.Tensor,
padded_bins: torch.Tensor,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
Expand Down Expand Up @@ -93,7 +94,8 @@ def _padded_copy(
iptr = a if A_TO_B else b
optr = b if A_TO_B else a

for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)):
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
for _ in range(iterations):
mask = offsets < NUM_COLUMNS
x = tl.load(iptr + offsets, mask=mask)
x = x.to(tl.float32) * scale.to(tl.float32)
Expand All @@ -103,7 +105,15 @@ def _padded_copy(
offsets += BLOCK_X


def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
def padded_gather(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
Expand All @@ -119,7 +129,7 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):

# NOTE: Because of the padding, the output size is dynamic.
# We load the final padded bin bound to get the output rows.
output_rows = padded_bins[-1].cpu().item()
output_rows = int(padded_bins[-1].cpu().item())
out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
_padded_copy[(indices.shape[0],)](
x,
Expand All @@ -137,7 +147,14 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
return out


def gather(x, indices, bin_ids, weights, bins, top_k):
def gather(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
top_k: int,
):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
Expand Down Expand Up @@ -169,7 +186,15 @@ def gather(x, indices, bin_ids, weights, bins, top_k):
return out


def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
def padded_scatter(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
) -> torch.Tensor:
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
Expand Down Expand Up @@ -202,7 +227,14 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


def scatter(
    x: torch.Tensor,
    indices: torch.Tensor,
    bin_ids: torch.Tensor,
    weights: Optional[torch.Tensor],
    bins: torch.Tensor,
    top_k: int,
) -> torch.Tensor:
    """Scatter ``x`` by delegating to ``padded_scatter``.

    ``bins`` is passed for both the ``bins`` and ``padded_bins`` arguments
    of ``padded_scatter``; every other argument is forwarded unchanged and
    the result of that call is returned as-is.
    """
    return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)


Expand All @@ -225,13 +257,13 @@ def scatter(x, indices, bin_ids, weights, bins, top_k):
)
@triton.jit
def _padded_copy_wgrad(
x,
grad,
wgrad,
indices,
bin_ids,
bins,
padded_bins,
x: torch.Tensor,
grad: torch.Tensor,
wgrad: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
bins: torch.Tensor,
padded_bins: torch.Tensor,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
Expand Down Expand Up @@ -263,7 +295,7 @@ def _padded_copy_wgrad(

acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
for i in range(iterations):
for _ in range(iterations):
mask = offsets < NUM_COLUMNS
data = tl.load(x + offsets, mask=mask).to(tl.float32)
scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
Expand All @@ -275,7 +307,15 @@ def _padded_copy_wgrad(
tl.store(wgrad, out)


def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
def padded_scatter_wgrad(
x: torch.Tensor,
grad: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_matrix(grad)
Expand All @@ -302,7 +342,14 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
return out


def scatter_wgrad(
    x: torch.Tensor,
    grad: torch.Tensor,
    indices: torch.Tensor,
    bin_ids: torch.Tensor,
    bins: torch.Tensor,
    top_k: int,
):
    """Compute the scatter weight gradient via ``padded_scatter_wgrad``.

    ``bins`` is passed for both the ``bins`` and ``padded_bins`` arguments
    of ``padded_scatter_wgrad``; every other argument is forwarded
    unchanged and the result of that call is returned as-is.
    """
    return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)


Expand All @@ -323,13 +370,13 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
)
@triton.jit
def _binned_copy(
a,
b,
num_experts,
expert_capacity,
indices,
weights,
bins,
a: torch.Tensor,
b: torch.Tensor,
num_experts: int,
expert_capacity: int,
indices: torch.Tensor,
weights, #: Optional[torch.Tensor],
bins: torch.Tensor,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
Expand Down Expand Up @@ -378,7 +425,7 @@ def _binned_copy(
optr = b if A_TO_B else a

iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
for i in range(iterations):
for _ in range(iterations):
mask = offsets < NUM_COLUMNS
x = tl.load(iptr + offsets, mask=mask)
x = x.to(tl.float32) * scale.to(tl.float32)
Expand All @@ -388,7 +435,14 @@ def _binned_copy(
offsets += BLOCK_X


def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
def binned_gather(
x: torch.Tensor,
indices: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
expert_capacity: int,
top_k: int,
):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
Expand All @@ -400,7 +454,6 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):

num_experts = bins.shape[0]
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

_binned_copy[(num_experts, expert_capacity)](
x,
out,
Expand All @@ -417,7 +470,13 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
return out


def binned_scatter(x, indices, weights, bins, top_k):
def binned_scatter(
x: torch.Tensor,
indices: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
top_k: int,
):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_vector(indices)
Expand Down Expand Up @@ -465,13 +524,13 @@ def binned_scatter(x, indices, weights, bins, top_k):
)
@triton.jit
def _binned_copy_wgrad(
x,
grad,
wgrad,
num_experts,
expert_capacity,
indices,
bins,
x: torch.Tensor,
grad: torch.Tensor,
wgrad: torch.Tensor,
num_experts: int,
expert_capacity: int,
indices: torch.Tensor,
bins: torch.Tensor,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
Expand Down Expand Up @@ -505,7 +564,7 @@ def _binned_copy_wgrad(

acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
for i in range(iterations):
for _ in range(iterations):
mask = offsets < NUM_COLUMNS
data = tl.load(x + offsets, mask=mask).to(tl.float32)
scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
Expand All @@ -517,7 +576,7 @@ def _binned_copy_wgrad(
tl.store(wgrad, out)


def binned_scatter_wgrad(x, grad, indices, bins, top_k):
def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_matrix(grad)
Expand Down
17 changes: 11 additions & 6 deletions megablocks/grouped_gemm_util.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import warnings

# Records whether the optional `grouped_gemm` package could be imported.
_grouped_gemm_is_available: bool = False
try:
    import grouped_gemm
except ImportError:
    # The dependency is optional: warn instead of failing at import time.
    warnings.warn('Grouped GEMM not available.')
else:
    _grouped_gemm_is_available = True


def grouped_gemm_is_available() -> bool:
    """Return the module-level flag recording whether `grouped_gemm` imported."""
    return _grouped_gemm_is_available


def assert_grouped_gemm_is_available() -> None:
    """Assert that the optional `grouped_gemm` package was imported.

    Raises:
        AssertionError: If the module-level availability flag is false. The
            message tells the user how to install the package.
    """
    # Implicit string concatenation only: a trailing comma inside the
    # parentheses would turn the message into a 1-tuple.
    msg = (
        'Grouped GEMM not available. Please run '
        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
    )
    assert _grouped_gemm_is_available, msg


# `grouped_gemm.backend` when the package is available, otherwise None.
# The availability check guards the attribute access on `grouped_gemm`.
backend = grouped_gemm.backend if grouped_gemm_is_available() else None
Expand Down
12 changes: 6 additions & 6 deletions megablocks/layers/activation_fn.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Callable
from typing import Any, Callable, Union

import stk
import torch
from stk import Matrix


def act_fn(
x: stk.Matrix,
x: Matrix,
function: Callable,
return_grad_fn: bool = False,
**kwargs,
):
assert isinstance(x, stk.Matrix)
) -> Union[tuple[Matrix, Any] | Matrix]:
assert isinstance(x, Matrix)
with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
if return_grad_fn:
x.data.requires_grad = True
out = function(x.data, **kwargs)
y = stk.Matrix(
y = Matrix(
x.size(),
out,
x.row_indices,
Expand Down
Loading
Loading