Create an iRoPE embedding in PyTorch
Proposed solution, to be tested once the PyTorch Cerebros model is ready:
import torch
import torch.nn as nn
# ------------- RotaryEmbedding -------------
class RotaryEmbedding(nn.Module):
    """
    Generates the (sin, cos) tensors used for the interleaved RoPE
    described in the original RoPE paper (RoFormer; Su et al., 2021).
    """

    def __init__(self, dim: int, max_seq_len: int = 1024, temperature: float = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(
                f"Embedding dimension `dim` ({dim}) must be even for RotaryEmbedding."
            )
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.temperature = float(temperature)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args
        ----
        x : FloatTensor (B, T, C) – the incoming activations

        Returns
        -------
        sin, cos : (B, T, C) – broadcast-ready sin / cos tensors
        """
        B, T, _ = x.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len ({self.max_seq_len})."
            )
        device, dtype = x.device, x.dtype
        # Inverse frequencies, one per 2-D rotation plane: [dim/2]
        inv_freq = 1.0 / (
            self.temperature
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
        # Positions [T]
        position = torch.arange(T, device=device, dtype=torch.float32)
        # Outer product of positions and frequencies → angles [T, dim/2]
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
        # Duplicate each angle for its (even, odd) pair: [a, b] → [a, a, b, b, ...]
        sin = torch.sin(sinusoid_inp).repeat_interleave(2, dim=-1)
        cos = torch.cos(sinusoid_inp).repeat_interleave(2, dim=-1)
        # Add batch dimension and broadcast to (B, T, C)
        sin = sin.unsqueeze(0).expand(B, T, -1).to(dtype)
        cos = cos.unsqueeze(0).expand(B, T, -1).to(dtype)
        return sin, cos
# ------------- InterleavedRoPE -------------
class InterleavedRoPE(nn.Module):
    """
    Applies rotary positional embeddings to the input tensor.
    """

    def __init__(self, dim: int, max_seq_len: int = 1024):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(
                f"Embedding dimension `dim` ({dim}) must be even for InterleavedRoPE."
            )
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.rotary_emb = RotaryEmbedding(dim, max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args
        ----
        x : FloatTensor (B, T, C)

        Returns
        -------
        FloatTensor (B, T, C)
        """
        sin, cos = self.rotary_emb(x)
        return apply_rotary_pos_emb(x, sin, cos)
# ------------- Helpers (interleaved RoPE) -------------
def split_alternate(x: torch.Tensor) -> torch.Tensor:
    """
    De-interleaves the last dimension so the even- and odd-indexed
    entries are grouped: [a0, b0, a1, b1, ...] -> [a0, a1, ..., b0, b1, ...]
    """
    B, T, C = x.size()
    x = x.view(B, T, C // 2, 2).transpose(-2, -1).contiguous()
    return x.view(B, T, C)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate each (even, odd) pair by 90°: (x0, x1) -> (-x1, x0).
    The result is returned in interleaved layout so that it lines up
    with the repeat-interleaved sin/cos tensors from RotaryEmbedding.
    """
    x = split_alternate(x)  # interleaved -> [evens | odds]
    d = x.size(-1)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    # Re-pair as (-odd, even) and flatten back to interleaved layout.
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """
    Apply the RoPE formula: x' = x * cos + rotate_half(x) * sin
    """
    dtype = x.dtype
    x = x.float()  # compute in fp32 for numerical stability
    out = (x * cos.float()) + (rotate_half(x) * sin.float())
    return out.to(dtype)  # cast back to the input dtype
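For quick verification, a minimal self-test could be appended to the bottom of the module (the sizes, seed, and tolerances here are illustrative assumptions, not part of the proposal). It checks that the output shape matches the input, that per-position norms are preserved (each (even, odd) pair is rotated, so lengths cannot change), and RoPE's relative-position property: dot products between rotated copies of a vector depend only on the positional offset.

# ------------- Sanity check (illustrative) -------------
if __name__ == "__main__":
    torch.manual_seed(0)
    rope = InterleavedRoPE(dim=64, max_seq_len=128)  # arbitrary test sizes
    x = torch.randn(2, 16, 64)
    y = rope(x)
    # Shape is unchanged.
    assert y.shape == x.shape
    # Rotations preserve the norm at every position.
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)
    # Same vector at every position: dot products of rotated copies
    # depend only on the offset, not on the absolute positions.
    v = torch.randn(1, 1, 64).repeat(1, 8, 1)
    r = rope(v)
    d01 = torch.dot(r[0, 0], r[0, 1])  # offset 1, positions 0 and 1
    d34 = torch.dot(r[0, 3], r[0, 4])  # offset 1, positions 3 and 4
    assert torch.allclose(d01, d34, atol=1e-4)
    print("iRoPE sanity checks passed")

If these checks fail after integration, the first place to look is the layout agreement between rotate_half and the repeat-interleaved sin/cos tensors, since the two must use the same (even, odd) pairing.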