
Create an iRoPE embedding in PyTorch #194

@david-thrower

Description


Proposed solution to test once the PyTorch Cerebros model is ready. A quick sanity check is sketched after the listing below.

import math
import torch
import torch.nn as nn


# -------------  RotaryEmbedding  -------------
class RotaryEmbedding(nn.Module):
    """
    Generates the (sin, cos) tensors used for the interleaved RoPE
    described in the original RoPE paper.
    """
    def __init__(self, dim: int, max_seq_len: int = 1024, temperature: float = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(
                f"Embedding dimension `dim` ({dim}) must be even for RotaryEmbedding."
            )
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.temperature = float(temperature)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args
        ----
        x : FloatTensor (B, T, C)  – the incoming activations
        Returns
        -------
        sin, cos : (B, T, C)       – broadcast-ready sin / cos tensors
        """
        B, T, _ = x.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len ({self.max_seq_len})."
            )
        device, dtype = x.device, x.dtype

        # Compute inverse frequencies [dim/2]
        inv_freq = 1.0 / (
            self.temperature ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

        # Positions [T]
        position = torch.arange(T, device=device, dtype=torch.float32)

        # Outer product → [T, dim/2]
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)

        # Repeat to interleave: [a, b] → [a, a, b, b, ...]
        sin = torch.sin(sinusoid_inp).repeat_interleave(2, dim=-1)
        cos = torch.cos(sinusoid_inp).repeat_interleave(2, dim=-1)

        # Add batch dimension and broadcast
        sin = sin.unsqueeze(0).expand(B, T, -1).to(dtype)
        cos = cos.unsqueeze(0).expand(B, T, -1).to(dtype)

        return sin, cos


# -------------  InterleavedRoPE  -------------
class InterleavedRoPE(nn.Module):
    """
    Applies rotary positional embeddings to the input tensor.
    """
    def __init__(self, dim: int, max_seq_len: int = 1024):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(
                f"Embedding dimension `dim` ({dim}) must be even for InterleavedRoPE."
            )
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.rotary_emb = RotaryEmbedding(dim, max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args
        ----
        x : FloatTensor (B, T, C)
        Returns
        -------
        FloatTensor (B, T, C)
        """
        sin, cos = self.rotary_emb(x)
        return apply_rotary_pos_emb(x, sin, cos)


# -------------  Helpers (exact Keras semantics) -------------
def split_alternate(x: torch.Tensor) -> torch.Tensor:
    """
    Re-arranges the last dimension so that the even and odd halves
    are swapped:  [a0, b0, a1, b1, ...] -> [a0, a1, ..., b0, b1, ...]
    """
    B, T, C = x.size()
    x = x.view(B, T, C // 2, 2).transpose(-2, -1).contiguous()
    return x.view(B, T, C)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate every (even, odd) pair a quarter turn:
    [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...],
    so the result lines up with the repeat-interleaved sin / cos tensors
    produced by RotaryEmbedding.
    """
    x = split_alternate(x)
    d = x.size(-1)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    # Pair (-x_{2i+1}, x_{2i}) and re-interleave back to the original layout.
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """
    Apply the RoPE formula:  x' = x * cos + rotate_half(x) * sin
    """
    orig_dtype = x.dtype
    # Compute in fp32 for numerical stability, then cast back to the input dtype.
    x, sin, cos = x.float(), sin.float(), cos.float()
    return ((x * cos) + (rotate_half(x) * sin)).to(orig_dtype)
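
A quick way to exercise this once the PyTorch Cerebros model is ready: the minimal sketch below (the sizes `batch`, `seq_len`, and `dim` are arbitrary choices, not part of the proposal) checks that the output shape and dtype pass through unchanged, that per-pair rotation preserves vector norms, and that position 0 is left untouched.

# -------------  Quick sanity check (illustrative sketch) -------------
if __name__ == "__main__":
    torch.manual_seed(0)

    batch, seq_len, dim = 2, 16, 64          # arbitrary test sizes
    rope = InterleavedRoPE(dim, max_seq_len=128)

    x = torch.randn(batch, seq_len, dim)
    y = rope(x)

    # Shape and dtype should pass through unchanged.
    assert y.shape == x.shape
    assert y.dtype == x.dtype

    # RoPE rotates each (even, odd) pair, so vector norms are preserved.
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)

    # Position 0 has angle 0 for every frequency, so token 0 is untouched.
    assert torch.allclose(x[:, 0], y[:, 0], atol=1e-6)

    print("InterleavedRoPE sanity checks passed.")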
