
Layers

Base

CliffordModule

Bases: Module

Base module for Clifford algebra layers.

This module securely stores a reference to a shared CliffordAlgebra instance without registering it as a PyTorch submodule. In the Versor architecture, a single algebra instance (which contains precomputed geometric tensors) is heavily shared across multiple layers.

By bypassing standard submodule registration (via object.__setattr__) and overriding _apply, this base class ensures that:

1. No ownership conflicts occur in PyTorch's computational graph.
2. Device and dtype casting (e.g., .to(device), .cuda(), .half()) are automatically and safely propagated to the shared algebra buffers.

Source code in layers/primitives/base.py
class CliffordModule(nn.Module):
    """Base module for Clifford algebra layers.

    This module securely stores a reference to a shared ``CliffordAlgebra`` instance 
    without registering it as a PyTorch submodule. In the Versor architecture, 
    a single algebra instance (which contains precomputed geometric tensors) 
    is heavily shared across multiple layers. 

    By bypassing standard submodule registration (via ``object.__setattr__``) and 
    overriding ``_apply``, this base class ensures that:
    1. No ownership conflicts occur in PyTorch's computational graph.
    2. Device and dtype casting (e.g., ``.to(device)``, ``.cuda()``, ``.half()``) 
       are automatically and safely propagated to the shared algebra buffers.
    """

    def __init__(self, algebra: CliffordAlgebra):
        """Sets up the module.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
        """
        super().__init__()
        # Bypass nn.Module.__setattr__ to avoid registering algebra as submodule.
        # Multiple layers share the same algebra - only one should "own" it.
        object.__setattr__(self, '_algebra', algebra)

    @property
    def algebra(self) -> CliffordAlgebra:
        """Return the algebra instance."""
        return self._algebra

    @property
    def p(self):
        return self._algebra.p

    @property
    def q(self):
        return self._algebra.q

    @property
    def r(self):
        return self._algebra.r

    def _apply(self, fn):
        """Override to also move the shared algebra tables."""
        result = super()._apply(fn)
        if self._algebra is not None:
            self._algebra._apply(fn)
        return result

    def forward(self, x):
        """Performs the forward pass computation."""
        raise NotImplementedError

algebra property

Return the algebra instance.

__init__(algebra)

Sets up the module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| algebra | CliffordAlgebra | The algebra instance. | required |
Source code in layers/primitives/base.py
def __init__(self, algebra: CliffordAlgebra):
    """Sets up the module.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
    """
    super().__init__()
    # Bypass nn.Module.__setattr__ to avoid registering algebra as submodule.
    # Multiple layers share the same algebra - only one should "own" it.
    object.__setattr__(self, '_algebra', algebra)

forward(x)

Performs the forward pass computation.

Source code in layers/primitives/base.py
def forward(self, x):
    """Performs the forward pass computation."""
    raise NotImplementedError
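The shared-reference mechanics can be sketched with a toy stand-in (`SharedTables` and `SharedRefModule` below are illustrative names, not part of the library): bypassing `nn.Module.__setattr__` keeps the shared object out of `named_modules()`, while the `_apply` override still forwards device/dtype casts to its buffers.

```python
import torch
import torch.nn as nn

class SharedTables(nn.Module):
    """Stand-in for CliffordAlgebra: holds precomputed buffers."""
    def __init__(self):
        super().__init__()
        self.register_buffer("cayley", torch.eye(4))

class SharedRefModule(nn.Module):
    """Mimics CliffordModule's shared-reference pattern."""
    def __init__(self, tables):
        super().__init__()
        # Bypass nn.Module.__setattr__: no submodule registration.
        object.__setattr__(self, "_tables", tables)

    def _apply(self, fn):
        result = super()._apply(fn)
        self._tables._apply(fn)  # forward casts to the shared buffers
        return result

tables = SharedTables()
layer_a = SharedRefModule(tables)
layer_b = SharedRefModule(tables)

# The shared object is not a registered submodule of either layer...
assert all(m is not tables for _, m in layer_a.named_modules())
# ...but dtype casts still propagate to its buffers.
layer_a.double()
assert tables.cayley.dtype == torch.float64
```

Because both layers hold the same reference, a cast issued through either one touches the single shared buffer set exactly as if the algebra were owned by that layer.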

Primitives

RotorLayer

Bases: CliffordModule

Learnable versor layer with universal grade parameterization.

For grade=2 (default): learns R = exp(-B/2) and applies the isometry x' = RxR~. For grade=k: learns a grade-k element V and applies the versor product x' = hat(V) x V^{-1}, where hat denotes grade involution.

Preserves origin. For grade=2, also preserves lengths and angles (isometry).

The exp strategy (closed-form vs decomposition) is controlled by algebra.exp_policy; see core.decomposition.ExpPolicy.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| channels | int | Number of versors. |
| grade | int | Grade of the learnable parameter. Default 2 (bivector → rotor). |
| grade_weights | Parameter | Learnable grade-k coefficients [channels, num_grade_elements]. |

Source code in layers/primitives/rotor.py
class RotorLayer(CliffordModule):
    """Learnable versor layer with universal grade parameterization.

    For grade=2 (default): learns R = exp(-B/2) and applies the isometry x' = RxR~.
    For grade=k: learns a grade-k element V and applies the versor product
    x' = hat(V) x V^{-1}, where hat denotes grade involution.

    Preserves origin. For grade=2, also preserves lengths and angles (isometry).

    The exp strategy (closed-form vs decomposition) is controlled by
    ``algebra.exp_policy`` -- see :class:`core.decomposition.ExpPolicy`.

    Attributes:
        channels (int): Number of versors.
        grade (int): Grade of the learnable parameter. Default 2 (bivector → rotor).
        grade_weights (nn.Parameter): Learnable grade-k coefficients [channels, num_grade_elements].
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        grade: int = 2,
    ):
        """Initialize the versor layer.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Number of features.
            grade (int): Grade of the learnable parameter.
                grade=2 (default): bivectors → rotors via exp(-B/2), Spin group.
                grade=1: vectors → reflections via hat(n) x n^{-1}, Pin group.
                grade=k: general grade-k versor product.
        """
        super().__init__(algebra)
        self.channels = channels
        self.grade = grade

        grade_mask = algebra.grade_masks[grade]
        self.register_buffer('grade_indices', grade_mask.nonzero(as_tuple=False).squeeze(-1))
        self.num_grade_elements = len(self.grade_indices)

        self.grade_weights = nn.Parameter(torch.Tensor(channels, self.num_grade_elements))
        if grade == 2:
            self.grade_weights._manifold = 'spin'

        # Versor cache for eval mode
        self._cached_V_left = None
        self._cached_V_right = None

        self.reset_parameters()

    # --- Backward-compat aliases (grade == 2 usage) ---

    @property
    def bivector_indices(self):
        return self.grade_indices

    @property
    def num_bivectors(self):
        return self.num_grade_elements

    @property
    def bivector_weights(self):
        return self.grade_weights

    # ---------------------------------------------------

    def reset_parameters(self):
        """Initialize with near-identity transform (small weights)."""
        nn.init.normal_(self.grade_weights, std=0.01)

    def _build_grade_element(self, device, dtype):
        """Scatter grade_weights into full multivector dimension [channels, dim]."""
        V = torch.zeros(self.channels, self.algebra.dim, device=device, dtype=dtype)
        indices = self.grade_indices.unsqueeze(0).expand(self.channels, -1)
        V.scatter_(1, indices, self.grade_weights)
        return V

    def _compute_versors(self, device, dtype):
        """Compute left and right factors for per_channel_sandwich.

        For grade=2: left = R = exp(-B/2), right = R~ (reverse).
        For grade=k: left = hat(V) (grade involution), right = V^{-1} (blade inverse).
          V is L2-normalized per channel before inversion so that blade_inverse
          remains exact (norm_sq is purely scalar for unit-norm grade-k elements).

        Returns:
            Tuple[Tensor, Tensor]: (V_left [C, dim], V_right [C, dim])
        """
        V = self._build_grade_element(device, dtype)
        if self.grade == 2:
            R = self.algebra.exp(-0.5 * V)
            return R, self.algebra.reverse(R)
        else:
            # Normalize per channel so blade_inverse is exact.
            # For a unit-norm grade-k element, V * V_rev = scalar everywhere.
            norm = V.norm(dim=-1, keepdim=True).clamp(min=1e-8)
            V = V / norm
            return self.algebra.grade_involution(V), self.algebra.blade_inverse(V)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply versor product x' = hat(V) x V^{-1} (= RxR~ for grade=2).

        Caches versors during eval mode for faster inference.

        Args:
            x (torch.Tensor): Input [Batch, Channels, Dim].

        Returns:
            torch.Tensor: Transformed input [Batch, Channels, Dim].
        """
        check_multivector(x, self.algebra, "RotorLayer input")
        check_channels(x, self.channels, "RotorLayer input")

        if not self.training and self._cached_V_left is not None:
            V_left, V_right = self._cached_V_left, self._cached_V_right
        else:
            V_left, V_right = self._compute_versors(x.device, x.dtype)
            if not self.training:
                self._cached_V_left = V_left
                self._cached_V_right = V_right

        return self.algebra.per_channel_sandwich(V_left, x, V_right)

    def train(self, mode: bool = True):
        """Invalidate versor cache when switching to train mode."""
        if mode:
            self._cached_V_left = None
            self._cached_V_right = None
        return super().train(mode)

    def prune_bivectors(self, threshold: float = 1e-4) -> int:
        """Zero out grade weights below threshold.

        Args:
            threshold (float): Cutoff magnitude.

        Returns:
            int: Number of pruned parameters.
        """
        with torch.no_grad():
            mask = torch.abs(self.grade_weights) >= threshold
            num_pruned = (~mask).sum().item()
            self.grade_weights.data.mul_(mask.float())
        return num_pruned

    def sparsity_loss(self) -> torch.Tensor:
        """Compute L1 sparsity regularization on grade weights."""
        return torch.norm(self.grade_weights, p=1)
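The scatter step in `_build_grade_element` can be illustrated standalone. The indices below assume a Cl(3) algebra with dim = 2^3 = 8 blades and its three bivector blades at positions [4, 5, 6]; the actual positions depend on the algebra's basis ordering, so treat them as placeholders.

```python
import torch

channels, dim = 2, 8
grade_indices = torch.tensor([4, 5, 6])  # assumed bivector slots for Cl(3)

# Compact learnable coefficients, one row per channel.
grade_weights = torch.randn(channels, len(grade_indices))

# Scatter the compact grade-k coefficients into full multivectors.
V = torch.zeros(channels, dim)
idx = grade_indices.unsqueeze(0).expand(channels, -1)
V.scatter_(1, idx, grade_weights)

# Only the grade-k slots are populated; every other blade stays zero.
assert torch.equal(V[:, grade_indices], grade_weights)
mask = torch.ones(dim, dtype=torch.bool)
mask[grade_indices] = False
assert torch.all(V[:, mask] == 0)
```

For grade=2 this V is the bivector B that feeds exp(-B/2); for other grades it is normalized per channel before the versor product.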

__init__(algebra, channels, grade=2)

Initialize the versor layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| algebra | CliffordAlgebra | The algebra instance. | required |
| channels | int | Number of features. | required |
| grade | int | Grade of the learnable parameter. grade=2 (default): bivectors → rotors via exp(-B/2), Spin group. grade=1: vectors → reflections via hat(n) x n^{-1}, Pin group. grade=k: general grade-k versor product. | 2 |
Source code in layers/primitives/rotor.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    channels: int,
    grade: int = 2,
):
    """Initialize the versor layer.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Number of features.
        grade (int): Grade of the learnable parameter.
            grade=2 (default): bivectors → rotors via exp(-B/2), Spin group.
            grade=1: vectors → reflections via hat(n) x n^{-1}, Pin group.
            grade=k: general grade-k versor product.
    """
    super().__init__(algebra)
    self.channels = channels
    self.grade = grade

    grade_mask = algebra.grade_masks[grade]
    self.register_buffer('grade_indices', grade_mask.nonzero(as_tuple=False).squeeze(-1))
    self.num_grade_elements = len(self.grade_indices)

    self.grade_weights = nn.Parameter(torch.Tensor(channels, self.num_grade_elements))
    if grade == 2:
        self.grade_weights._manifold = 'spin'

    # Versor cache for eval mode
    self._cached_V_left = None
    self._cached_V_right = None

    self.reset_parameters()

reset_parameters()

Initialize with near-identity transform (small weights).

Source code in layers/primitives/rotor.py
def reset_parameters(self):
    """Initialize with near-identity transform (small weights)."""
    nn.init.normal_(self.grade_weights, std=0.01)

forward(x)

Apply versor product x' = hat(V) x V^{-1} (= RxR~ for grade=2).

Caches versors during eval mode for faster inference.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input [Batch, Channels, Dim]. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Transformed input [Batch, Channels, Dim]. |

Source code in layers/primitives/rotor.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply versor product x' = hat(V) x V^{-1} (= RxR~ for grade=2).

    Caches versors during eval mode for faster inference.

    Args:
        x (torch.Tensor): Input [Batch, Channels, Dim].

    Returns:
        torch.Tensor: Transformed input [Batch, Channels, Dim].
    """
    check_multivector(x, self.algebra, "RotorLayer input")
    check_channels(x, self.channels, "RotorLayer input")

    if not self.training and self._cached_V_left is not None:
        V_left, V_right = self._cached_V_left, self._cached_V_right
    else:
        V_left, V_right = self._compute_versors(x.device, x.dtype)
        if not self.training:
            self._cached_V_left = V_left
            self._cached_V_right = V_right

    return self.algebra.per_channel_sandwich(V_left, x, V_right)

train(mode=True)

Invalidate versor cache when switching to train mode.

Source code in layers/primitives/rotor.py
def train(self, mode: bool = True):
    """Invalidate versor cache when switching to train mode."""
    if mode:
        self._cached_V_left = None
        self._cached_V_right = None
    return super().train(mode)

prune_bivectors(threshold=0.0001)

Zero out grade weights below threshold.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| threshold | float | Cutoff magnitude. | 0.0001 |

Returns:

| Type | Description |
| --- | --- |
| int | Number of pruned parameters. |

Source code in layers/primitives/rotor.py
def prune_bivectors(self, threshold: float = 1e-4) -> int:
    """Zero out grade weights below threshold.

    Args:
        threshold (float): Cutoff magnitude.

    Returns:
        int: Number of pruned parameters.
    """
    with torch.no_grad():
        mask = torch.abs(self.grade_weights) >= threshold
        num_pruned = (~mask).sum().item()
        self.grade_weights.data.mul_(mask.float())
    return num_pruned

sparsity_loss()

Compute L1 sparsity regularization on grade weights.

Source code in layers/primitives/rotor.py
def sparsity_loss(self) -> torch.Tensor:
    """Compute L1 sparsity regularization on grade weights."""
    return torch.norm(self.grade_weights, p=1)
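Both prune_bivectors and sparsity_loss act directly on the raw grade-weight tensor, so their logic can be sketched without an algebra instance (the tensor values here are made up for illustration):

```python
import torch

weights = torch.tensor([[0.5, 1e-5, -0.3],
                        [-2e-5, 0.0, 0.7]])

# Pruning: zero every entry with |w| < threshold, count how many were cleared.
threshold = 1e-4
mask = weights.abs() >= threshold
num_pruned = (~mask).sum().item()
weights = weights * mask.float()
assert num_pruned == 3  # the 1e-5, -2e-5, and 0.0 entries

# Sparsity penalty: L1 norm of the (now pruned) weights.
l1 = torch.norm(weights, p=1)
assert abs(l1.item() - 1.5) < 1e-5  # 0.5 + 0.3 + 0.7
```

Note that an exact zero also counts as "pruned" here, since 0 < threshold; the surviving entries are untouched.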

MultiRotorLayer

Bases: CliffordModule

Multi-versor layer with weighted superposition: x' = sum_k w_k hat(V_k) x V_k^{-1}.

For grade=2 (default): each V_k = exp(-B_k/2) is a rotor, reducing to x' = sum_k w_k R_k x R~_k. For grade=k: each V_k is a grade-k versor applied via the general versor product.

The exp strategy is controlled by algebra.exp_policy.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| channels | int | Input features. |
| num_rotors | int | Number of overlapping versors. |
| grade | int | Grade of the learnable parameters. Default 2 (rotors). |
| rotor_grade_weights | Parameter | Grade-k coefficients [num_rotors, num_grade_elements]. |
| weights | Parameter | Mixing weights [channels, num_rotors]. |

Source code in layers/primitives/multi_rotor.py
class MultiRotorLayer(CliffordModule):
    """Multi-versor layer with weighted superposition: x' = sum_k w_k hat(V_k) x V_k^{-1}.

    For grade=2 (default): each V_k = exp(-B_k/2) is a rotor, reducing to
    x' = sum_k w_k R_k x R~_k.
    For grade=k: each V_k is a grade-k versor applied via the general versor product.

    The exp strategy is controlled by ``algebra.exp_policy``.

    Attributes:
        channels (int): Input features.
        num_rotors (int): Number of overlapping versors.
        grade (int): Grade of the learnable parameters. Default 2 (rotors).
        rotor_grade_weights (nn.Parameter): Grade-k coefficients [num_rotors, num_grade_elements].
        weights (nn.Parameter): Mixing weights [channels, num_rotors].
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        num_rotors: int = 8,
        grade: int = 2,
    ):
        """Initialize Multi-Versor Layer.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Input features.
            num_rotors (int): Number of parallel versor heads.
            grade (int): Grade of the learnable parameter.
                grade=2 (default): bivectors → rotors via exp(-B/2), Spin group.
                grade=k: general grade-k versor product.
        """
        super().__init__(algebra)
        self.channels = channels
        self.num_rotors = num_rotors
        self.grade = grade

        grade_mask = algebra.grade_masks[grade]
        self.register_buffer('grade_indices', grade_mask.nonzero(as_tuple=False).squeeze(-1))
        self.num_grade_elements = len(self.grade_indices)

        self.rotor_grade_weights = nn.Parameter(torch.Tensor(num_rotors, self.num_grade_elements))
        if grade == 2:
            self.rotor_grade_weights._manifold = 'spin'

        # Mixing weights (Euclidean — intentionally untagged)
        self.weights = nn.Parameter(torch.Tensor(channels, num_rotors))

        # Versor cache for eval mode
        self._cached_V_left = None
        self._cached_V_right = None

        self.reset_parameters()

    # --- Backward-compat aliases (grade == 2 usage) ---

    @property
    def bivector_indices(self):
        return self.grade_indices

    @property
    def num_bivectors(self):
        return self.num_grade_elements

    @property
    def rotor_bivectors(self):
        return self.rotor_grade_weights

    # ---------------------------------------------------

    def reset_parameters(self):
        """Initialize with small transforms and uniform mixing weights."""
        nn.init.normal_(self.rotor_grade_weights, std=0.01)
        nn.init.xavier_uniform_(self.weights)

    def _compute_versors(self, device, dtype):
        """Compute left and right factors for all K versors.

        For grade=2: left = R_k = exp(-B_k/2), right = R~_k.
        For grade=k: left = hat(V_k), right = V_k^{-1}.

        Returns:
            Tuple[Tensor, Tensor]: (V_left [K, dim], V_right [K, dim])
        """
        V = torch.zeros(self.num_rotors, self.algebra.dim, device=device, dtype=dtype)
        indices = self.grade_indices.unsqueeze(0).expand(self.num_rotors, -1)
        V.scatter_(1, indices, self.rotor_grade_weights)

        if self.grade == 2:
            R = self.algebra.exp(-0.5 * V)  # [K, D]
            return R, self.algebra.reverse(R)
        else:
            norm = V.norm(dim=-1, keepdim=True).clamp(min=1e-8)
            V = V / norm
            return self.algebra.grade_involution(V), self.algebra.blade_inverse(V)

    def forward(self, x: torch.Tensor, return_invariants: bool = False) -> torch.Tensor:
        """Apply weighted multi-versor superposition.

        Caches versors during eval mode for faster inference.

        Args:
            x (torch.Tensor): Input [Batch, Channels, Dim].
            return_invariants (bool): If True, returns per-grade norms instead of output.

        Returns:
            torch.Tensor: Transformed output [Batch, Channels, Dim].
        """
        check_multivector(x, self.algebra, "MultiRotorLayer input")
        check_channels(x, self.channels, "MultiRotorLayer input")

        if not self.training and self._cached_V_left is not None:
            V_left, V_right = self._cached_V_left, self._cached_V_right
        else:
            V_left, V_right = self._compute_versors(x.device, x.dtype)
            if not self.training:
                self._cached_V_left = V_left
                self._cached_V_right = V_right

        # Action-matrix sandwich: build K matrices once, apply via einsum
        versored_x = self.algebra.multi_rotor_sandwich(
            V_left, x, V_right,
        )  # [B, C, K, D]

        # Weighted superposition
        out = torch.einsum('ck,bcke->bce', self.weights, versored_x)

        if return_invariants:
            return self.algebra.get_grade_norms(out)

        return out

    def train(self, mode: bool = True):
        """Invalidate versor cache when switching to train mode."""
        if mode:
            self._cached_V_left = None
            self._cached_V_right = None
        return super().train(mode)

    def sparsity_loss(self) -> torch.Tensor:
        """Compute L1 sparsity loss for versor weights and mixing weights."""
        return torch.norm(self.rotor_grade_weights, p=1) + torch.norm(self.weights, p=1)
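The superposition einsum 'ck,bcke->bce' contracts the rotor axis K against the per-channel mixing weights. A small sketch (shapes are arbitrary illustrative values) checks it against an explicit loop over rotors:

```python
import torch

B, C, K, D = 2, 3, 4, 8  # batch, channels, rotors, multivector dim
weights = torch.randn(C, K)           # mixing weights [C, K]
versored_x = torch.randn(B, C, K, D)  # per-rotor sandwich outputs [B, C, K, D]

# Weighted superposition, exactly as in forward():
out = torch.einsum('ck,bcke->bce', weights, versored_x)

# Reference: explicit sum over rotors k, weighted per channel c.
ref = torch.zeros(B, C, D)
for k in range(K):
    ref += weights[:, k].view(1, C, 1) * versored_x[:, :, k, :]

assert torch.allclose(out, ref, atol=1e-5)
```

Each output channel is thus a learned convex-like (but unconstrained) mixture of the K versor actions applied to that channel.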

__init__(algebra, channels, num_rotors=8, grade=2)

Initialize Multi-Versor Layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| algebra | CliffordAlgebra | The algebra instance. | required |
| channels | int | Input features. | required |
| num_rotors | int | Number of parallel versor heads. | 8 |
| grade | int | Grade of the learnable parameter. grade=2 (default): bivectors → rotors via exp(-B/2), Spin group. grade=k: general grade-k versor product. | 2 |
Source code in layers/primitives/multi_rotor.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    channels: int,
    num_rotors: int = 8,
    grade: int = 2,
):
    """Initialize Multi-Versor Layer.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Input features.
        num_rotors (int): Number of parallel versor heads.
        grade (int): Grade of the learnable parameter.
            grade=2 (default): bivectors → rotors via exp(-B/2), Spin group.
            grade=k: general grade-k versor product.
    """
    super().__init__(algebra)
    self.channels = channels
    self.num_rotors = num_rotors
    self.grade = grade

    grade_mask = algebra.grade_masks[grade]
    self.register_buffer('grade_indices', grade_mask.nonzero(as_tuple=False).squeeze(-1))
    self.num_grade_elements = len(self.grade_indices)

    self.rotor_grade_weights = nn.Parameter(torch.Tensor(num_rotors, self.num_grade_elements))
    if grade == 2:
        self.rotor_grade_weights._manifold = 'spin'

    # Mixing weights (Euclidean — intentionally untagged)
    self.weights = nn.Parameter(torch.Tensor(channels, num_rotors))

    # Versor cache for eval mode
    self._cached_V_left = None
    self._cached_V_right = None

    self.reset_parameters()

reset_parameters()

Initialize with small transforms and uniform mixing weights.

Source code in layers/primitives/multi_rotor.py
def reset_parameters(self):
    """Initialize with small transforms and uniform mixing weights."""
    nn.init.normal_(self.rotor_grade_weights, std=0.01)
    nn.init.xavier_uniform_(self.weights)

forward(x, return_invariants=False)

Apply weighted multi-versor superposition.

Caches versors during eval mode for faster inference.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input [Batch, Channels, Dim]. | required |
| return_invariants | bool | If True, returns per-grade norms instead of output. | False |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Transformed output [Batch, Channels, Dim]. |

Source code in layers/primitives/multi_rotor.py
def forward(self, x: torch.Tensor, return_invariants: bool = False) -> torch.Tensor:
    """Apply weighted multi-versor superposition.

    Caches versors during eval mode for faster inference.

    Args:
        x (torch.Tensor): Input [Batch, Channels, Dim].
        return_invariants (bool): If True, returns per-grade norms instead of output.

    Returns:
        torch.Tensor: Transformed output [Batch, Channels, Dim].
    """
    check_multivector(x, self.algebra, "MultiRotorLayer input")
    check_channels(x, self.channels, "MultiRotorLayer input")

    if not self.training and self._cached_V_left is not None:
        V_left, V_right = self._cached_V_left, self._cached_V_right
    else:
        V_left, V_right = self._compute_versors(x.device, x.dtype)
        if not self.training:
            self._cached_V_left = V_left
            self._cached_V_right = V_right

    # Action-matrix sandwich: build K matrices once, apply via einsum
    versored_x = self.algebra.multi_rotor_sandwich(
        V_left, x, V_right,
    )  # [B, C, K, D]

    # Weighted superposition
    out = torch.einsum('ck,bcke->bce', self.weights, versored_x)

    if return_invariants:
        return self.algebra.get_grade_norms(out)

    return out

train(mode=True)

Invalidate versor cache when switching to train mode.

Source code in layers/primitives/multi_rotor.py
def train(self, mode: bool = True):
    """Invalidate versor cache when switching to train mode."""
    if mode:
        self._cached_V_left = None
        self._cached_V_right = None
    return super().train(mode)

sparsity_loss()

Compute L1 sparsity loss for versor weights and mixing weights.

Source code in layers/primitives/multi_rotor.py
def sparsity_loss(self) -> torch.Tensor:
    """Compute L1 sparsity loss for versor weights and mixing weights."""
    return torch.norm(self.rotor_grade_weights, p=1) + torch.norm(self.weights, p=1)

CliffordLinear

Bases: CliffordModule

Fully connected layer with optional rotor-based backend.

Can use either:

- Traditional scalar weight matrix (default, backward compatible)
- Rotor-based transformation (new, parameter-efficient via RotorGadget)

The traditional backend uses O(in_channels x out_channels) parameters, while the rotor backend uses O(num_rotor_pairs x n(n-1)/2) parameters where n is the number of basis vectors.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| in_channels | int | Input features. |
| out_channels | int | Output features. |
| backend | str | 'traditional' or 'rotor'. |
| weight | Parameter \| None | Weights [Out, In] (traditional backend only). |
| bias | Parameter \| None | Bias multivector [Out, Dim] (traditional backend only). |
| gadget | Module \| None | Rotor transformation (rotor backend only). |

Source code in layers/primitives/linear.py
class CliffordLinear(CliffordModule):
    """Fully connected layer with optional rotor-based backend.

    Can use either:
    - Traditional scalar weight matrix (default, backward compatible)
    - Rotor-based transformation (new, parameter efficient via RotorGadget)

    The traditional backend uses O(in_channels x out_channels) parameters,
    while the rotor backend uses O(num_rotor_pairs x n(n-1)/2) parameters
    where n is the number of basis vectors.

    Attributes:
        in_channels (int): Input features.
        out_channels (int): Output features.
        backend (str): 'traditional' or 'rotor'
        weight (torch.nn.Parameter | None): Weights [Out, In] (traditional backend only).
        bias (torch.nn.Parameter | None): Bias multivector [Out, Dim] (traditional backend only).
        gadget (nn.Module | None): Rotor transformation (rotor backend only).
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        in_channels: int,
        out_channels: int,
        backend: Literal['traditional', 'rotor'] = 'traditional',
        num_rotor_pairs: int = 4,
        aggregation: Literal['mean', 'sum', 'learned'] = 'mean',
        shuffle: Literal['none', 'fixed', 'random'] = 'none',
    ):
        """Initialize Clifford Linear.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            in_channels (int): Input size.
            out_channels (int): Output size.
            backend (str): 'traditional' for standard linear layer,
                          'rotor' for rotor-based transformation
            num_rotor_pairs (int): Number of rotor pairs (rotor backend only)
            aggregation (str): Aggregation method (rotor backend only)
            shuffle (str): Input channel shuffle strategy (rotor backend only):
                - 'none': No shuffle (default)
                - 'fixed': Fixed random permutation
                - 'random': Random permutation each forward pass
        """
        super().__init__(algebra)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.backend = backend

        if backend == 'traditional':
            self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))
            self.bias = nn.Parameter(torch.Tensor(out_channels, algebra.dim))
            self.reset_parameters()
            self.gadget = None

        elif backend == 'rotor':
            from .rotor_gadget import RotorGadget
            self.gadget = RotorGadget(
                algebra=algebra,
                in_channels=in_channels,
                out_channels=out_channels,
                num_rotor_pairs=num_rotor_pairs,
                aggregation=aggregation,
                shuffle=shuffle,
                bias=True,  # Include bias in rotor gadget
            )
            self.weight = None
            self.bias = None

        else:
            raise ValueError(
                f"Unknown backend: {backend}. Must be 'traditional' or 'rotor'."
            )

    def reset_parameters(self):
        """Initialize weights with Xavier uniform and zero bias."""
        if self.backend == 'traditional':
            nn.init.xavier_uniform_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply channel-mixing linear transformation.

        Args:
            x (torch.Tensor): Input [Batch, In, Dim].

        Returns:
            torch.Tensor: Output [Batch, Out, Dim].
        """
        check_multivector(x, self.algebra, "CliffordLinear input")
        check_channels(x, self.in_channels, "CliffordLinear input")

        if self.backend == 'traditional':
            # Traditional linear transformation
            # x: [Batch, In, Dim]
            # weight: [Out, In]
            # out: [Batch, Out, Dim]
            out = torch.einsum('oi,bid->bod', self.weight, x)
            out = out + self.bias.unsqueeze(0)
            return out
        else:
            # Rotor-based transformation
            return self.gadget(x)

    def extra_repr(self) -> str:
        """String representation for debugging.

        Returns:
            str: Layer parameters description
        """
        if self.backend == 'traditional':
            return f"in_channels={self.in_channels}, out_channels={self.out_channels}, backend=traditional"
        else:
            return f"in_channels={self.in_channels}, out_channels={self.out_channels}, backend=rotor"
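The traditional backend's einsum 'oi,bid->bod' is an ordinary linear channel mix applied identically to every blade component. A minimal sketch (shapes chosen arbitrarily) verifies this reading:

```python
import torch

B, In, Out, D = 2, 5, 3, 8
x = torch.randn(B, In, D)      # [Batch, In, Dim]
weight = torch.randn(Out, In)  # [Out, In]
bias = torch.randn(Out, D)     # [Out, Dim]

# Channel mixing as in the traditional backend's forward():
out = torch.einsum('oi,bid->bod', weight, x) + bias.unsqueeze(0)

# Equivalent view: the same scalar weight matrix applied to each
# blade component d independently, then a per-blade bias.
ref = torch.stack([x[..., d] @ weight.T for d in range(D)], dim=-1) + bias
assert torch.allclose(out, ref, atol=1e-5)
```

This is why the traditional backend costs O(in_channels x out_channels) parameters: the mix is blade-agnostic, with only the bias living in the full multivector dimension.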

__init__(algebra, in_channels, out_channels, backend='traditional', num_rotor_pairs=4, aggregation='mean', shuffle='none')

Initialize Clifford Linear.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| algebra | CliffordAlgebra | The algebra instance. | required |
| in_channels | int | Input size. | required |
| out_channels | int | Output size. | required |
| backend | str | 'traditional' for standard linear layer, 'rotor' for rotor-based transformation. | 'traditional' |
| num_rotor_pairs | int | Number of rotor pairs (rotor backend only). | 4 |
| aggregation | str | Aggregation method (rotor backend only). | 'mean' |
| shuffle | str | Input channel shuffle strategy (rotor backend only): 'none' = no shuffle, 'fixed' = fixed random permutation, 'random' = random permutation each forward pass. | 'none' |
Source code in layers/primitives/linear.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    in_channels: int,
    out_channels: int,
    backend: Literal['traditional', 'rotor'] = 'traditional',
    num_rotor_pairs: int = 4,
    aggregation: Literal['mean', 'sum', 'learned'] = 'mean',
    shuffle: Literal['none', 'fixed', 'random'] = 'none',
):
    """Initialize Clifford Linear.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        in_channels (int): Input size.
        out_channels (int): Output size.
        backend (str): 'traditional' for standard linear layer,
                      'rotor' for rotor-based transformation
        num_rotor_pairs (int): Number of rotor pairs (rotor backend only)
        aggregation (str): Aggregation method (rotor backend only)
        shuffle (str): Input channel shuffle strategy (rotor backend only):
            - 'none': No shuffle (default)
            - 'fixed': Fixed random permutation
            - 'random': Random permutation each forward pass
    """
    super().__init__(algebra)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.backend = backend

    if backend == 'traditional':
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))
        self.bias = nn.Parameter(torch.Tensor(out_channels, algebra.dim))
        self.reset_parameters()
        self.gadget = None

    elif backend == 'rotor':
        from .rotor_gadget import RotorGadget
        self.gadget = RotorGadget(
            algebra=algebra,
            in_channels=in_channels,
            out_channels=out_channels,
            num_rotor_pairs=num_rotor_pairs,
            aggregation=aggregation,
            shuffle=shuffle,
            bias=True,  # Include bias in rotor gadget
        )
        self.weight = None
        self.bias = None

    else:
        raise ValueError(
            f"Unknown backend: {backend}. Must be 'traditional' or 'rotor'."
        )

reset_parameters()

Initialize weights with Xavier uniform and zero bias.

Source code in layers/primitives/linear.py
def reset_parameters(self):
    """Initialize weights with Xavier uniform and zero bias."""
    if self.backend == 'traditional':
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

forward(x)

Apply channel-mixing linear transformation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input `[Batch, In, Dim]`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Output `[Batch, Out, Dim]`. |

Source code in layers/primitives/linear.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply channel-mixing linear transformation.

    Args:
        x (torch.Tensor): Input [Batch, In, Dim].

    Returns:
        torch.Tensor: Output [Batch, Out, Dim].
    """
    check_multivector(x, self.algebra, "CliffordLinear input")
    check_channels(x, self.in_channels, "CliffordLinear input")

    if self.backend == 'traditional':
        # Traditional linear transformation
        # x: [Batch, In, Dim]
        # weight: [Out, In]
        # out: [Batch, Out, Dim]
        out = torch.einsum('oi,bid->bod', self.weight, x)
        out = out + self.bias.unsqueeze(0)
        return out
    else:
        # Rotor-based transformation
        return self.gadget(x)
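The `'traditional'` path is a pure channel mix: `torch.einsum('oi,bid->bod', weight, x)` applies the same `[Out, In]` weight matrix independently to every multivector component. A plain-Python sketch of that contraction (illustrative only, not the library code):

```python
def channel_mix(weight, x):
    """Loop form of torch.einsum('oi,bid->bod', weight, x)."""
    B, O, I, D = len(x), len(weight), len(weight[0]), len(x[0][0])
    out = [[[0.0] * D for _ in range(O)] for _ in range(B)]
    for b in range(B):
        for o in range(O):
            for i in range(I):
                for d in range(D):
                    # each output channel is a weighted sum of input channels,
                    # applied identically to every blade component d
                    out[b][o][d] += weight[o][i] * x[b][i][d]
    return out

x = [[[1.0, 2.0], [3.0, 4.0]]]                 # [B=1, In=2, Dim=2]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]       # [Out=3, In=2]
print(channel_mix(w, x))  # [[[1.0, 2.0], [3.0, 4.0], [4.0, 6.0]]]
```

Because the mix never couples blade components, the transformation commutes with grade projection, which is what makes it a safe primitive for equivariant pipelines.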

extra_repr()

String representation for debugging.

Returns:

| Type | Description |
| --- | --- |
| `str` | Layer parameters description. |

Source code in layers/primitives/linear.py
def extra_repr(self) -> str:
    """String representation for debugging.

    Returns:
        str: Layer parameters description
    """
    if self.backend == 'traditional':
        return f"in_channels={self.in_channels}, out_channels={self.out_channels}, backend=traditional"
    else:
        return f"in_channels={self.in_channels}, out_channels={self.out_channels}, backend=rotor"

RotorGadget

Bases: CliffordModule

Rotor-based linear transformation (Generalized Rotor Gadget).

Replaces standard linear layers with parameter-efficient rotor-sandwich transformations. Instead of using O(in_channels x out_channels) parameters, this uses O(num_rotor_pairs x n(n-1)/2) parameters where n is the number of basis vectors in the Clifford algebra.
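To make the efficiency claim concrete, here is a back-of-the-envelope parameter count for a hypothetical Cl(3,0) setup (n = 3 basis vectors, so dim = 2^3 = 8 and n(n-1)/2 = 3 bivectors) with 64 input and 64 output channels and 4 rotor pairs; the formulas follow the `weight`/`bias` and bivector-table shapes shown in the source listings:

```python
def traditional_params(in_ch, out_ch, dim):
    # CliffordLinear 'traditional': weight [out, in] + bias [out, dim]
    return out_ch * in_ch + out_ch * dim

def rotor_params(num_pairs, n, out_ch, dim, bias=True):
    # RotorGadget: two bivector tables of shape [num_pairs, n(n-1)/2],
    # plus an optional bias [out, dim]
    num_bivectors = n * (n - 1) // 2
    params = 2 * num_pairs * num_bivectors
    if bias:
        params += out_ch * dim
    return params

# Cl(3,0): n = 3 -> dim = 8, 3 bivectors; 64 -> 64 channels, 4 rotor pairs
print(traditional_params(64, 64, 8))  # 4608
print(rotor_params(4, 3, 64, 8))      # 536 (24 rotor params + 512 bias)
```

Even with the bias dominating the rotor budget here, the channel-mixing parameters drop from 4096 to 24; the gap widens as channel counts grow, since the rotor count is independent of `in_channels x out_channels`.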

Architecture:

1. Partition input channels into blocks
2. For each rotor pair (i, j): apply the rotor sandwich r_ij . x_i . s_ij.H
3. Pool/aggregate results to output channels

The transformation is: psi(x) = r.x.s.H where r, s are rotors (bivector exponentials).
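For intuition, in a Euclidean signature a single-plane rotor has the familiar closed form (a standard geometric-algebra identity, since e_12 squares to -1):

```latex
R = \exp\!\left(-\tfrac{\theta}{2}\, e_{12}\right)
  = \cos\tfrac{\theta}{2} - \sin\tfrac{\theta}{2}\, e_{12},
\qquad
x \mapsto R\, x\, \tilde{R}
```

which rotates vectors by theta in the e_1 e_2 plane. The layer's psi(x) = r.x.s.H generalizes this by learning independent left and right rotors, so the map need not be a single rotation.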

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `algebra` | `CliffordAlgebra` | CliffordAlgebra instance. |
| `in_channels` | `int` | Number of input channels. |
| `out_channels` | `int` | Number of output channels. |
| `num_rotor_pairs` | `int` | Number of rotor pairs to use. |
| `aggregation` | `str` | Aggregation method (`'mean'`, `'sum'`, or `'learned'`). |

Source code in layers/primitives/rotor_gadget.py
class RotorGadget(CliffordModule):
    """Rotor-based linear transformation (Generalized Rotor Gadget).

    Replaces standard linear layers with parameter-efficient rotor-sandwich
    transformations. Instead of using O(in_channels x out_channels) parameters,
    this uses O(num_rotor_pairs x n(n-1)/2) parameters where n is the number
    of basis vectors in the Clifford algebra.

    Architecture:
        1. Partition input channels into blocks
        2. For each rotor pair (i, j):
           - Apply rotor sandwich: r_ij . x_i . s_ij.H
        3. Pool/aggregate results to output channels

    The transformation is: psi(x) = r.x.s.H where r, s are rotors (bivector exponentials).

    Attributes:
        algebra: CliffordAlgebra instance
        in_channels: Number of input channels
        out_channels: Number of output channels
        num_rotor_pairs: Number of rotor pairs to use
        aggregation: Aggregation method ('mean', 'sum', or 'learned')
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        in_channels: int,
        out_channels: int,
        num_rotor_pairs: int = 4,
        aggregation: Literal['mean', 'sum', 'learned'] = 'mean',
        shuffle: Literal['none', 'fixed', 'random'] = 'none',
        bias: bool = False,
    ):
        """Initialize rotor gadget layer.

        Args:
            algebra: CliffordAlgebra instance
            in_channels: Number of input channels
            out_channels: Number of output channels
            num_rotor_pairs: Number of rotor pairs (higher = more expressive)
            aggregation: How to pool rotor outputs ('mean', 'sum', 'learned')
            shuffle: Input channel shuffle strategy:
                - 'none': No shuffle, sequential block assignment (default)
                - 'fixed': Random permutation at initialization (fixed during training)
                - 'random': Random permutation each forward pass (regularization)
            bias: Whether to include bias term (applied after transformation)
        """
        super().__init__(algebra)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_rotor_pairs = num_rotor_pairs
        self.aggregation = aggregation
        self.shuffle = shuffle

        # Use algebra's precomputed grade masks for bivector indices
        if algebra.num_grades > 2:
            bv_mask = algebra.grade_masks[2]
            self.register_buffer('bivector_indices', bv_mask.nonzero(as_tuple=False).squeeze(-1))
        else:
            self.register_buffer('bivector_indices', torch.tensor([], dtype=torch.long, device=algebra.device))
        self.num_bivectors = len(self.bivector_indices)

        if self.num_bivectors == 0:
            raise ValueError(
                "Algebra has no bivectors. RotorGadget requires "
                "at least one bivector for rotation."
            )

        # Rotor parameters: bivector coefficients for exponential map
        # Left rotors: [num_rotor_pairs, num_bivectors]
        self.bivector_left = nn.Parameter(
            torch.randn(num_rotor_pairs, self.num_bivectors) * 0.1
        )
        self.bivector_left._manifold = 'spin'
        # Right rotors: [num_rotor_pairs, num_bivectors]
        self.bivector_right = nn.Parameter(
            torch.randn(num_rotor_pairs, self.num_bivectors) * 0.1
        )
        self.bivector_right._manifold = 'spin'

        # Channel routing: block diagonal partitioning (paper style)
        # Each rotor pair processes a subset of input channels
        self._setup_channel_routing()

        # Aggregation weights (if learned)
        if aggregation == 'learned':
            # Learned weights for combining rotor outputs
            self.agg_weights = nn.Parameter(
                torch.ones(num_rotor_pairs, out_channels) / num_rotor_pairs
            )
        else:
            self.register_buffer('agg_weights', None)

        # Optional bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels, algebra.dim))
        else:
            self.register_buffer('bias', None)

        # Rotor cache for eval mode
        self._cached_rotors = None

    def _setup_channel_routing(self):
        """Set up block diagonal channel routing with optional shuffle.

        Partitions input and output channels into blocks, where each rotor
        pair operates on a specific block. Optionally shuffles input channels
        before routing for regularization.
        """
        # Compute block sizes
        in_block_size = max(1, self.in_channels // self.num_rotor_pairs)
        out_block_size = max(1, self.out_channels // self.num_rotor_pairs)

        # Create routing indices
        in_indices = []
        out_indices = []

        for i in range(self.num_rotor_pairs):
            # Input block for this rotor pair
            in_start = i * in_block_size
            in_end = min((i + 1) * in_block_size, self.in_channels)
            in_indices.append((in_start, in_end))

            # Output block for this rotor pair
            out_start = i * out_block_size
            out_end = min((i + 1) * out_block_size, self.out_channels)
            out_indices.append((out_start, out_end))

        self.in_indices = in_indices
        self.out_indices = out_indices

        # Precompute channel-to-rotor-pair mapping for vectorized forward
        ch2pair = torch.zeros(self.in_channels, dtype=torch.long)
        for i, (s, e) in enumerate(in_indices):
            if e > s:
                ch2pair[s:e] = i
        self.register_buffer('_ch2pair', ch2pair)

        # Set up channel shuffle permutation
        if self.shuffle == 'fixed':
            # Create fixed random permutation at initialization
            perm = torch.randperm(self.in_channels)
            self.register_buffer('channel_permutation', perm)
        elif self.shuffle == 'random':
            # Random shuffle each forward pass - no fixed permutation
            self.register_buffer('channel_permutation', None)
        else:  # 'none'
            # No shuffle - identity permutation
            self.register_buffer('channel_permutation', None)

    def _bivector_to_multivector(self, bivector_coeffs: torch.Tensor) -> torch.Tensor:
        """Convert bivector coefficients to full multivector via vectorized scatter.

        Args:
            bivector_coeffs: Tensor of shape [..., num_bivectors]

        Returns:
            Multivector tensor of shape [..., algebra.dim]
        """
        batch_shape = bivector_coeffs.shape[:-1]
        mv = torch.zeros(*batch_shape, self.algebra.dim,
                          device=bivector_coeffs.device, dtype=bivector_coeffs.dtype)
        # Expand indices to match batch shape for scatter_
        idx = self.bivector_indices.expand(*batch_shape, -1)
        mv.scatter_(-1, idx, bivector_coeffs)
        return mv

    def _compute_rotors(self):
        """Compute rotor multivectors from bivector parameters.

        Returns:
            Tuple of (left_rotors, right_rotors_reversed) where each is
            a tensor of shape [num_rotor_pairs, algebra.dim]
        """
        # Convert bivector parameters to multivectors
        B_left = self._bivector_to_multivector(self.bivector_left)  # [pairs, dim]
        B_right = self._bivector_to_multivector(self.bivector_right)  # [pairs, dim]

        # Compute rotors via exponential map: R = exp(-0.5 * B)
        R_left = self.algebra.exp(-0.5 * B_left)  # [pairs, dim]
        R_right = self.algebra.exp(-0.5 * B_right)  # [pairs, dim]

        # Compute reverse of right rotors for sandwich product
        R_right_rev = self.algebra.reverse(R_right)  # [pairs, dim]

        return R_left, R_right_rev

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply rotor-based transformation.

        Uses batched geometric products - all rotor pairs are applied in
        parallel via a single pair of GP calls.

        Args:
            x: Input tensor of shape [Batch, In_Channels, Dim]

        Returns:
            Output tensor of shape [Batch, Out_Channels, Dim]
        """
        check_multivector(x, self.algebra, "RotorGadget input")
        check_channels(x, self.in_channels, "RotorGadget input")

        # Apply input channel shuffle if enabled
        if self.shuffle == 'fixed':
            x = x[:, self.channel_permutation, :]
        elif self.shuffle == 'random':
            perm = torch.randperm(self.in_channels, device=x.device)
            x = x[:, perm, :]

        # Compute rotors (cached in eval mode)
        if not self.training and self._cached_rotors is not None:
            R_left, R_right_rev = self._cached_rotors
        else:
            R_left, R_right_rev = self._compute_rotors()
            if not self.training:
                self._cached_rotors = (R_left, R_right_rev)

        # Vectorized sandwich: map each channel to its rotor pair
        R_left_expanded = R_left[self._ch2pair].unsqueeze(0)         # [1, in_channels, D]
        R_right_expanded = R_right_rev[self._ch2pair].unsqueeze(0)   # [1, in_channels, D]

        # Two batched GPs instead of 2*K sequential GPs
        temp = self.algebra.geometric_product(R_left_expanded, x)
        concat_out = self.algebra.geometric_product(temp, R_right_expanded)

        # Map to output channels
        out = self._aggregate_to_output_channels(concat_out)

        if self.bias is not None:
            out = out + self.bias.unsqueeze(0)

        return out

    def _aggregate_to_output_channels(self, x: torch.Tensor) -> torch.Tensor:
        """Aggregate rotor pair outputs to match output channel count.

        Args:
            x: Concatenated outputs from rotor pairs [B, total_channels, dim]

        Returns:
            Aggregated output [B, out_channels, dim]
        """
        batch_size = x.shape[0]

        if self.aggregation == 'learned':
            # Weighted aggregation with learned weights
            # agg_weights: [num_pairs, out_channels]
            # Need to apply per-block
            outputs = []
            for i in range(self.num_rotor_pairs):
                in_start, in_end = self.in_indices[i]
                block_size = in_end - in_start
                if block_size == 0:
                    continue

                x_i = x[:, in_start:in_end, :]  # [B, block, dim]
                # Average over block channels and weight
                x_i_mean = x_i.mean(dim=1, keepdim=True)  # [B, 1, dim]
                # Expand to output channels with weights
                weighted = x_i_mean * self.agg_weights[i:i+1, :, None]  # [B, out_ch, dim]
                outputs.append(weighted)

            out = torch.stack(outputs, dim=0).sum(dim=0)  # [B, out_ch, dim]

        elif self.aggregation == 'sum':
            # Simple channel-wise sum with reshaping
            if x.shape[1] == self.out_channels:
                out = x
            elif x.shape[1] > self.out_channels:
                # Pool down by summing
                fold = x.shape[1] // self.out_channels
                out = x[:, :fold*self.out_channels, :].reshape(
                    batch_size, self.out_channels, fold, self.algebra.dim
                ).sum(dim=2)
            else:
                # Expand by tiling
                repeats = (self.out_channels + x.shape[1] - 1) // x.shape[1]
                out = x.repeat(1, repeats, 1)[:, :self.out_channels, :]

        else:  # 'mean'
            # Mean pooling
            if x.shape[1] == self.out_channels:
                out = x
            elif x.shape[1] > self.out_channels:
                # Pool down by averaging
                fold = x.shape[1] // self.out_channels
                out = x[:, :fold*self.out_channels, :].reshape(
                    batch_size, self.out_channels, fold, self.algebra.dim
                ).mean(dim=2)
            else:
                # Expand by tiling
                repeats = (self.out_channels + x.shape[1] - 1) // x.shape[1]
                out = x.repeat(1, repeats, 1)[:, :self.out_channels, :]

        return out

    def train(self, mode: bool = True):
        """Override to invalidate rotor cache when switching to train mode."""
        if mode:
            self._cached_rotors = None
        return super().train(mode)

    def extra_repr(self) -> str:
        """String representation for debugging."""
        return (
            f"in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, "
            f"num_rotor_pairs={self.num_rotor_pairs}, "
            f"aggregation={self.aggregation}, "
            f"shuffle={self.shuffle}, "
            f"bias={self.bias is not None}"
        )
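The block-diagonal routing built in `_setup_channel_routing` can be sketched in plain Python. This mirrors the `_ch2pair` buffer above; note that when the blocks do not cover all of `in_channels`, the leftover channels keep the zero-initialized mapping and are handled by pair 0:

```python
def channel_to_pair(in_channels, num_rotor_pairs):
    # Contiguous blocks of size in_channels // num_rotor_pairs, as in _ch2pair
    block = max(1, in_channels // num_rotor_pairs)
    ch2pair = [0] * in_channels          # zeros, like torch.zeros(...)
    for i in range(num_rotor_pairs):
        start = i * block
        end = min((i + 1) * block, in_channels)
        for c in range(start, end):
            ch2pair[c] = i
    return ch2pair

print(channel_to_pair(8, 4))  # [0, 0, 1, 1, 2, 2, 3, 3]
print(channel_to_pair(6, 4))  # [0, 1, 2, 3, 0, 0]  (channels 4-5 fall back to pair 0)
```

The `'fixed'` and `'random'` shuffle modes permute channels before this routing, so block membership is decoupled from channel order.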

__init__(algebra, in_channels, out_channels, num_rotor_pairs=4, aggregation='mean', shuffle='none', bias=False)

Initialize rotor gadget layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `algebra` | `CliffordAlgebra` | CliffordAlgebra instance. | *required* |
| `in_channels` | `int` | Number of input channels. | *required* |
| `out_channels` | `int` | Number of output channels. | *required* |
| `num_rotor_pairs` | `int` | Number of rotor pairs (higher = more expressive). | `4` |
| `aggregation` | `Literal['mean', 'sum', 'learned']` | How to pool rotor outputs. | `'mean'` |
| `shuffle` | `Literal['none', 'fixed', 'random']` | Input channel shuffle strategy: `'none'` = sequential block assignment (default), `'fixed'` = random permutation fixed at initialization, `'random'` = new random permutation each forward pass (regularization). | `'none'` |
| `bias` | `bool` | Whether to include a bias term (applied after the transformation). | `False` |
Source code in layers/primitives/rotor_gadget.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    in_channels: int,
    out_channels: int,
    num_rotor_pairs: int = 4,
    aggregation: Literal['mean', 'sum', 'learned'] = 'mean',
    shuffle: Literal['none', 'fixed', 'random'] = 'none',
    bias: bool = False,
):
    """Initialize rotor gadget layer.

    Args:
        algebra: CliffordAlgebra instance
        in_channels: Number of input channels
        out_channels: Number of output channels
        num_rotor_pairs: Number of rotor pairs (higher = more expressive)
        aggregation: How to pool rotor outputs ('mean', 'sum', 'learned')
        shuffle: Input channel shuffle strategy:
            - 'none': No shuffle, sequential block assignment (default)
            - 'fixed': Random permutation at initialization (fixed during training)
            - 'random': Random permutation each forward pass (regularization)
        bias: Whether to include bias term (applied after transformation)
    """
    super().__init__(algebra)

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.num_rotor_pairs = num_rotor_pairs
    self.aggregation = aggregation
    self.shuffle = shuffle

    # Use algebra's precomputed grade masks for bivector indices
    if algebra.num_grades > 2:
        bv_mask = algebra.grade_masks[2]
        self.register_buffer('bivector_indices', bv_mask.nonzero(as_tuple=False).squeeze(-1))
    else:
        self.register_buffer('bivector_indices', torch.tensor([], dtype=torch.long, device=algebra.device))
    self.num_bivectors = len(self.bivector_indices)

    if self.num_bivectors == 0:
        raise ValueError(
            "Algebra has no bivectors. RotorGadget requires "
            "at least one bivector for rotation."
        )

    # Rotor parameters: bivector coefficients for exponential map
    # Left rotors: [num_rotor_pairs, num_bivectors]
    self.bivector_left = nn.Parameter(
        torch.randn(num_rotor_pairs, self.num_bivectors) * 0.1
    )
    self.bivector_left._manifold = 'spin'
    # Right rotors: [num_rotor_pairs, num_bivectors]
    self.bivector_right = nn.Parameter(
        torch.randn(num_rotor_pairs, self.num_bivectors) * 0.1
    )
    self.bivector_right._manifold = 'spin'

    # Channel routing: block diagonal partitioning (paper style)
    # Each rotor pair processes a subset of input channels
    self._setup_channel_routing()

    # Aggregation weights (if learned)
    if aggregation == 'learned':
        # Learned weights for combining rotor outputs
        self.agg_weights = nn.Parameter(
            torch.ones(num_rotor_pairs, out_channels) / num_rotor_pairs
        )
    else:
        self.register_buffer('agg_weights', None)

    # Optional bias
    if bias:
        self.bias = nn.Parameter(torch.zeros(out_channels, algebra.dim))
    else:
        self.register_buffer('bias', None)

    # Rotor cache for eval mode
    self._cached_rotors = None

forward(x)

Apply rotor-based transformation.

Uses batched geometric products - all rotor pairs are applied in parallel via a single pair of GP calls.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor of shape `[Batch, In_Channels, Dim]`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Output tensor of shape `[Batch, Out_Channels, Dim]`. |

Source code in layers/primitives/rotor_gadget.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply rotor-based transformation.

    Uses batched geometric products - all rotor pairs are applied in
    parallel via a single pair of GP calls.

    Args:
        x: Input tensor of shape [Batch, In_Channels, Dim]

    Returns:
        Output tensor of shape [Batch, Out_Channels, Dim]
    """
    check_multivector(x, self.algebra, "RotorGadget input")
    check_channels(x, self.in_channels, "RotorGadget input")

    # Apply input channel shuffle if enabled
    if self.shuffle == 'fixed':
        x = x[:, self.channel_permutation, :]
    elif self.shuffle == 'random':
        perm = torch.randperm(self.in_channels, device=x.device)
        x = x[:, perm, :]

    # Compute rotors (cached in eval mode)
    if not self.training and self._cached_rotors is not None:
        R_left, R_right_rev = self._cached_rotors
    else:
        R_left, R_right_rev = self._compute_rotors()
        if not self.training:
            self._cached_rotors = (R_left, R_right_rev)

    # Vectorized sandwich: map each channel to its rotor pair
    R_left_expanded = R_left[self._ch2pair].unsqueeze(0)         # [1, in_channels, D]
    R_right_expanded = R_right_rev[self._ch2pair].unsqueeze(0)   # [1, in_channels, D]

    # Two batched GPs instead of 2*K sequential GPs
    temp = self.algebra.geometric_product(R_left_expanded, x)
    concat_out = self.algebra.geometric_product(temp, R_right_expanded)

    # Map to output channels
    out = self._aggregate_to_output_channels(concat_out)

    if self.bias is not None:
        out = out + self.bias.unsqueeze(0)

    return out
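After the sandwich products, `_aggregate_to_output_channels` maps the rotor outputs onto `out_channels`. The `'mean'` branch folds groups of `fold = in/out` consecutive channels; a minimal sketch for the exactly divisible case (the real code drops any remainder channels before folding):

```python
def mean_pool_channels(x, out_channels):
    # x: nested list [channels][dim]; assumes len(x) % out_channels == 0
    fold = len(x) // out_channels
    dim = len(x[0])
    out = []
    for o in range(out_channels):
        group = x[o * fold:(o + 1) * fold]
        # average each blade component over the `fold` channels in the group
        out.append([sum(g[d] for g in group) / fold for d in range(dim)])
    return out

x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]   # 4 channels, dim 2
print(mean_pool_channels(x, 2))  # [[2.0, 3.0], [6.0, 7.0]]
```

The `'sum'` branch is identical with the division removed, and `'learned'` replaces the uniform average with per-pair trainable weights.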

train(mode=True)

Override to invalidate rotor cache when switching to train mode.

Source code in layers/primitives/rotor_gadget.py
def train(self, mode: bool = True):
    """Override to invalidate rotor cache when switching to train mode."""
    if mode:
        self._cached_rotors = None
    return super().train(mode)

extra_repr()

String representation for debugging.

Source code in layers/primitives/rotor_gadget.py
def extra_repr(self) -> str:
    """String representation for debugging."""
    return (
        f"in_channels={self.in_channels}, "
        f"out_channels={self.out_channels}, "
        f"num_rotor_pairs={self.num_rotor_pairs}, "
        f"aggregation={self.aggregation}, "
        f"shuffle={self.shuffle}, "
        f"bias={self.bias is not None}"
    )

CliffordLayerNorm

Bases: CliffordModule

Geometric LayerNorm that preserves direction and recovers scale.

Normalizes the multivector to unit norm (preserving geometric direction), then injects the original log-magnitude into the scalar (grade-0) part via a learnable gate.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `weight` | `Parameter` | Per-channel direction scale `[C]`. |
| `bias` | `Parameter` | Per-channel scalar bias `[C]`. |
| `norm_scale` | `Parameter` | Per-channel gate for log-magnitude injection into grade-0. Initialized to zero so the layer starts identical to the old (scale-discarding) behaviour. |

Source code in layers/primitives/normalization.py
class CliffordLayerNorm(CliffordModule):
    """Geometric LayerNorm that preserves direction and recovers scale.

    Normalizes the multivector to unit norm (preserving geometric direction),
    then injects the original log-magnitude into the scalar (grade-0) part
    via a learnable gate.

    Attributes:
        weight (nn.Parameter): Per-channel direction scale [C].
        bias (nn.Parameter): Per-channel scalar bias [C].
        norm_scale (nn.Parameter): Per-channel gate for log-magnitude
            injection into grade-0.  Initialized to zero so the layer
            starts identical to the old (scale-discarding) behaviour.
    """

    def __init__(self, algebra: CliffordAlgebra, channels: int, eps: float = 1e-6, recover: bool = True):
        """Sets up normalization.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Features.
            eps (float): Stability term.
            recover (bool): Whether to inject original scale into the scalar part.
        """
        super().__init__(algebra)
        self.eps = eps
        self.recover = recover

        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        # Learnable gate: how much of the original log-magnitude to push
        # into the scalar part.  Zero-init -> backward compatible at start.
        if recover:
            self.norm_scale = nn.Parameter(torch.zeros(channels))
        else:
            self.register_buffer('norm_scale', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalizes energy, preserves direction, optionally recovers scale in grade-0.

        Args:
            x (torch.Tensor): Input [Batch, Channels, Dim].

        Returns:
            torch.Tensor: Normalized input.
        """
        # Per-channel magnitude
        norm = x.norm(dim=-1, keepdim=True)  # [B, C, 1]

        # Normalize direction
        x_normalized = x / (norm + self.eps)

        # Affine transform on direction
        out = x_normalized * self.weight.view(1, -1, 1)

        # Add bias and optional log-magnitude to grade-0 via mask
        g0 = self.algebra.grade_masks_float[0]  # [D], 1.0 at index 0
        if g0.dtype != x.dtype:
            g0 = g0.to(dtype=x.dtype)
        out = out + self.bias.view(1, -1, 1) * g0

        if self.recover:
            # Push original magnitude into scalar (grade-0) part.
            # log1p keeps the value bounded and well-behaved for gradients.
            log_norm = torch.log1p(norm.squeeze(-1)).unsqueeze(-1)  # [B, C, 1]
            out = out + self.norm_scale.view(1, -1, 1) * log_norm * g0

        return out
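For a single channel with the default `weight=1`, `bias=0`, the forward pass reduces to: normalize to unit norm, then gate `log1p(norm)` into the scalar slot. A plain-Python sketch (index 0 assumed to be the grade-0 component, matching the mask used above):

```python
import math

def clifford_layer_norm(mv, gate=0.0, eps=1e-6):
    # mv: one multivector [Dim]; mv[0] is the grade-0 (scalar) slot
    norm = math.sqrt(sum(c * c for c in mv))
    out = [c / (norm + eps) for c in mv]       # unit-norm direction
    out[0] += gate * math.log1p(norm)          # optional scale recovery
    return out

y = clifford_layer_norm([3.0, 4.0, 0.0, 0.0])         # norm = 5
print([round(c, 4) for c in y])  # [0.6, 0.8, 0.0, 0.0]
```

With `gate=0` (the zero-init of `norm_scale`) the magnitude is discarded entirely; as the gate trains away from zero, the scalar part carries a bounded log-encoding of the original scale.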

__init__(algebra, channels, eps=1e-06, recover=True)

Sets up normalization.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
channels int

Features.

required
eps float

Stability term.

1e-06
recover bool

Whether to inject original scale into the scalar part.

True
Source code in layers/primitives/normalization.py
def __init__(self, algebra: CliffordAlgebra, channels: int, eps: float = 1e-6, recover: bool = True):
    """Sets up normalization.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Features.
        eps (float): Stability term.
        recover (bool): Whether to inject original scale into the scalar part.
    """
    super().__init__(algebra)
    self.eps = eps
    self.recover = recover

    self.weight = nn.Parameter(torch.ones(channels))
    self.bias = nn.Parameter(torch.zeros(channels))
    # Learnable gate: how much of the original log-magnitude to push
    # into the scalar part.  Zero-init -> backward compatible at start.
    if recover:
        self.norm_scale = nn.Parameter(torch.zeros(channels))
    else:
        self.register_buffer('norm_scale', None)

forward(x)

Normalizes energy, preserves direction, optionally recovers scale in grade-0.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input `[Batch, Channels, Dim]`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Normalized input. |

Source code in layers/primitives/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Normalizes energy, preserves direction, optionally recovers scale in grade-0.

    Args:
        x (torch.Tensor): Input [Batch, Channels, Dim].

    Returns:
        torch.Tensor: Normalized input.
    """
    # Per-channel magnitude
    norm = x.norm(dim=-1, keepdim=True)  # [B, C, 1]

    # Normalize direction
    x_normalized = x / (norm + self.eps)

    # Affine transform on direction
    out = x_normalized * self.weight.view(1, -1, 1)

    # Add bias and optional log-magnitude to grade-0 via mask
    g0 = self.algebra.grade_masks_float[0]  # [D], 1.0 at index 0
    if g0.dtype != x.dtype:
        g0 = g0.to(dtype=x.dtype)
    out = out + self.bias.view(1, -1, 1) * g0

    if self.recover:
        # Push original magnitude into scalar (grade-0) part.
        # log1p keeps the value bounded and well-behaved for gradients.
        log_norm = torch.log1p(norm.squeeze(-1)).unsqueeze(-1)  # [B, C, 1]
        out = out + self.norm_scale.view(1, -1, 1) * log_norm * g0

    return out

BladeSelector

Bases: CliffordModule

Blade Selector. Filters insignificant components.

Learns to weigh geometric grades, suppressing less relevant ones.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `weights` | `Parameter` | Soft gates `[Channels, Dim]`. |

Source code in layers/primitives/projection.py
class BladeSelector(CliffordModule):
    """Blade Selector. Filters insignificant components.

    Learns to weigh geometric grades, suppressing less relevant ones.

    Attributes:
        weights (nn.Parameter): Soft gates [Channels, Dim].
    """

    def __init__(self, algebra: CliffordAlgebra, channels: int):
        """Sets up the selector.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Input features.
        """
        super().__init__(algebra)

        self.weights = nn.Parameter(torch.Tensor(channels, algebra.dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Initializes weights to one (a soft gate of sigmoid(1) ≈ 0.73)."""
        nn.init.ones_(self.weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gates the grades.

        Args:
            x (torch.Tensor): Input [Batch, Channels, Dim].

        Returns:
            torch.Tensor: Filtered input.
        """
        # Sigmoid gate
        w = torch.sigmoid(self.weights).unsqueeze(0)
        return x * w
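The gating is a per-blade sigmoid: `out = x * sigmoid(weights)`. A quick sketch showing the effect of different gate values (note that the ones-initialization gives `sigmoid(1) ≈ 0.731`, a uniform soft gate rather than an exact identity):

```python
import math

def blade_select(x, w):
    # Elementwise soft gate per blade: x * sigmoid(w)
    sigmoid = lambda v: 1.0 / (1.0 + math.exp(-v))
    return [xi * sigmoid(wi) for xi, wi in zip(x, w)]

x = [1.0, 2.0, -1.0]
out = blade_select(x, [10.0, -10.0, 0.0])   # keep, suppress, halve
print([round(c, 3) for c in out])  # [1.0, 0.0, -0.5]
```

Training pushes individual weights toward large positive values (pass the blade through) or large negative values (suppress it), giving a differentiable approximation of hard component selection.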

__init__(algebra, channels)

Sets up the selector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `algebra` | `CliffordAlgebra` | The algebra instance. | *required* |
| `channels` | `int` | Input features. | *required* |
Source code in layers/primitives/projection.py
def __init__(self, algebra: CliffordAlgebra, channels: int):
    """Sets up the selector.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Input features.
    """
    super().__init__(algebra)

    self.weights = nn.Parameter(torch.Tensor(channels, algebra.dim))

    self.reset_parameters()

reset_parameters()

Initializes weights to one (a soft gate of sigmoid(1) ≈ 0.73, near pass-through).

Source code in layers/primitives/projection.py
def reset_parameters(self):
    """Initializes weights to one (a soft gate of sigmoid(1) ≈ 0.73)."""
    nn.init.ones_(self.weights)

forward(x)

Gates the grades.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input `[Batch, Channels, Dim]`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Filtered input. |

Source code in layers/primitives/projection.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Gates the grades.

    Args:
        x (torch.Tensor): Input [Batch, Channels, Dim].

    Returns:
        torch.Tensor: Filtered input.
    """
    # Sigmoid gate
    w = torch.sigmoid(self.weights).unsqueeze(0)
    return x * w

Blocks

GeometricProductAttention

Bases: CliffordModule

Multi-head attention using geometric product scoring.

Standard attention: `score(Q, K) = <Q, K> / sqrt(d)` (scalar only)

GA attention:

    product = Q_c * reverse(K_c)    (geometric product per head-channel)
    score   = (<product>_0 + lambda * ||<product>_2||_F) / sqrt(H_c * dim)

The grade-0 (scalar) part measures alignment, like the dot product. The grade-2 (bivector) part measures relative orientation, a component standard attention lacks.

Memory: the naive [B, H, L, L, H_c, D] score tensor is too large, so queries are chunked over L_q in blocks of BLOCK_SIZE to bound peak VRAM.
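As a toy illustration of the scoring rule, assume a Cl(2,0) multivector layout `[scalar, e1, e2, e12]` (a hypothetical layout chosen for this sketch) and a precomputed product `P = Q_c * reverse(K_c)`:

```python
import math

def ga_score(P, lam=0.5, head_channels=1, dim=4):
    # P: product multivector [scalar, e1, e2, e12] (Cl(2,0) layout assumed)
    scalar = P[0]              # grade-0: alignment, like a dot product
    biv_norm = abs(P[3])       # grade-2: single bivector e12 in 2D
    return (scalar + lam * biv_norm) / math.sqrt(head_channels * dim)

# Aligned pair (scalar 2.0) with some relative orientation (bivector 1.0):
print(ga_score([2.0, 0.0, 0.0, 1.0]))  # (2 + 0.5 * 1) / sqrt(4) = 1.25
```

In higher dimensions the grade-2 part has n(n-1)/2 components and `biv_norm` becomes their Frobenius norm, summed with the scalar part over the head's channels.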

Attributes:

- num_heads (int): Number of attention heads.
- head_channels (int): Channels per head.
- causal (bool): If True, apply autoregressive causal mask.
- bivector_weight (float): lambda_ - weight of bivector score component.

Source code in layers/blocks/attention.py
class GeometricProductAttention(CliffordModule):
    """Multi-head attention using geometric product scoring.

    Standard attention: score(Q, K) = <Q, K> / sqrt(d)  (scalar only)

    GA attention:
        product = Q_c * reverse(K_c)    (geometric product per head-channel)
        score   = (<product>_0 + lambda_ * ||<product>_2||_F) / sqrt(H_c * dim)

    The grade-0 (scalar) part measures alignment (like dot product).
    The grade-2 (bivector) part measures relative orientation - novel.

    Memory: naive [B, H, L, L, H_c, D] is too large. We chunk over L_q
    in blocks of BLOCK_SIZE to bound peak VRAM.

    Attributes:
        num_heads (int): Number of attention heads.
        head_channels (int): Channels per head.
        causal (bool): If True, apply autoregressive causal mask.
        bivector_weight (float): lambda_ - weight of bivector score component.
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        num_heads: int,
        causal: bool = True,
        bivector_weight: float = 0.5,
        dropout: float = 0.0,
    ):
        """Sets up geometric product attention.

        Args:
            algebra: Clifford algebra instance.
            channels: Total number of multivector channels.
            num_heads: Number of attention heads.
            causal: Apply causal mask for autoregressive generation.
            bivector_weight: lambda_ weight on bivector score component.
            dropout: Dropout rate on attention weights.
        """
        super().__init__(algebra)
        assert channels % num_heads == 0, \
            f"channels ({channels}) must be divisible by num_heads ({num_heads})"

        self.channels = channels
        self.num_heads = num_heads
        self.head_channels = channels // num_heads
        self.causal = causal
        self.bivector_weight = bivector_weight

        # Q, K, V projections operate on [B*L, channels, dim]
        self.q_proj = CliffordLinear(algebra, channels, channels)
        self.k_proj = CliffordLinear(algebra, channels, channels)
        self.v_proj = CliffordLinear(algebra, channels, channels)
        self.out_proj = CliffordLinear(algebra, channels, channels)

        self.attn_dropout = nn.Dropout(dropout) if dropout > 0.0 else None

        # Precompute bilinear score tables (replaces pairwise geometric product)
        self._precompute_score_tables()

    def _precompute_score_tables(self):
        """Precomputes lookup tables for efficient attention scoring.

        Replaces the O(L**2) full pairwise geometric product with direct bilinear
        forms for grade-0 and grade-2 components of Q * reverse(K):

        Grade-0:  <Q * rev(K)>_0 = Sum_a Q[a] * K[a] * metric_rev[a]
                  -> simple weighted dot product, no pairwise expansion needed.

        Grade-2:  <Q * rev(K)>_r = Sum_a Q[a] * K[a^r] * g2_sign[r, a]
                  -> precompute K_g2 once, then batched matmul.

        Memory: ~4 MB peak vs ~256 MB for the naive B_gathered approach.
        """
        alg = self.algebra
        D = alg.dim

        # Grade-0 metric: metric_rev[a] = gp_signs[a, 0] * rev_signs[a]
        # gp_signs[a, 0] is the sign when A[a] * B[a] contributes to output blade 0
        metric_rev = alg.gp_signs[:, 0].float() * alg.rev_signs.float()
        self.register_buffer('_metric_rev', metric_rev)  # [D]

        # Grade-2 tables: for each grade-2 blade r, for each A-blade a:
        #   B-blade  = a XOR r
        #   sign     = rev_sign[a^r] * gp_signs[a, r]
        g2_blades = [i for i in range(D) if bin(i).count('1') == 2]
        n_g2 = len(g2_blades)
        self.n_g2 = n_g2

        if n_g2 > 0:
            a_idx = torch.arange(D, device=alg.device)
            r_vals = torch.tensor(g2_blades, dtype=torch.long, device=alg.device)  # [n_g2]

            # b_idx[r, a] = a XOR r_vals[r]
            b_idx = a_idx.unsqueeze(0) ^ r_vals.unsqueeze(1)  # [n_g2, D]

            # rev_sign at the B-blade position
            rev_b = alg.rev_signs.float()[b_idx]  # [n_g2, D]

            # gp_signs[a, r_val]: sign when A[a] pairs with B[a^r] to give output r
            # alg.gp_signs[:, r_vals] -> [D, n_g2]; transpose -> [n_g2, D]
            gp_ar = alg.gp_signs[:, r_vals].float().T  # [n_g2, D]

            g2_sign = rev_b * gp_ar  # [n_g2, D]
        else:
            b_idx = torch.zeros(0, D, dtype=torch.long, device=alg.device)
            g2_sign = torch.zeros(0, D, device=alg.device)

        self.register_buffer('_g2_b_idx', b_idx)   # [n_g2, D] long
        self.register_buffer('_g2_sign', g2_sign)  # [n_g2, D] float

    def _compute_score(
        self,
        q_head: torch.Tensor,
        k_head: torch.Tensor,
        k_g2: torch.Tensor,
    ) -> torch.Tensor:
        """Computes GA attention score using precomputed bilinear form tables.

        Avoids the O(B.H.Lq.Lk.Hc.D.BLOCK) memory of the full pairwise
        geometric product. Instead:

          Grade-0: score_g0 = Q_weighted @ K^T  (weighted dot product, peak ~1 MB)
          Grade-2: batched matmul via precomputed k_g2            (peak ~4 MB)

        Args:
            q_head: Query block [B, H, Lq, Hc, D]
            k_head: Keys        [B, H, Lk, Hc, D]
            k_g2:   Precomputed [B, H, Lk, Hc, n_g2, D]
                    k_g2[b,h,j,c,r,d] = K[b,h,j,c, d^r] * g2_sign[r, d]

        Returns:
            scores: [B, H, Lq, Lk]
        """
        B, H, Lq, Hc, D = q_head.shape
        Lk = k_head.shape[2]
        n_g2 = self.n_g2

        # == Grade-0 score ====================================================
        # <Q * rev(K)>_0 = Sum_c Sum_d  Q[c,d] * K[c,d] * metric_rev[d]
        # Implemented as a batched matrix multiply: [B,H,Lq,Hc*D] @ [B,H,Hc*D,Lk]
        q_weighted = q_head * self._metric_rev          # [B, H, Lq, Hc, D]
        q_flat = q_weighted.reshape(B, H, Lq, Hc * D)  # [B, H, Lq, Hc*D]
        k_flat = k_head.reshape(B, H, Lk, Hc * D)      # [B, H, Lk, Hc*D]
        score_g0 = torch.matmul(q_flat, k_flat.transpose(-2, -1))  # [B, H, Lq, Lk]

        # == Grade-2 score ====================================================
        # ||<Q * rev(K)>_2||_F = sqrt(Sum_c Sum_r (Sum_d Q[c,d]*k_g2[j,c,r,d])^2)
        # Batched matmul merging (B, H, Hc) into one batch dimension:
        #   q_2d:     [B*H*Hc, Lq, D]
        #   k_g2_2d:  [B*H*Hc, Lk*n_g2, D]   (Lk and n_g2 merged, n_g2 varies fast)
        #   comp:     [B*H*Hc, Lq, Lk*n_g2]
        # Peak ~4 MB vs ~256 MB for the naive B_gathered approach.
        if n_g2 > 0:
            q_2d = q_head.permute(0, 1, 3, 2, 4).reshape(B * H * Hc, Lq, D)
            # k_g2: [B, H, Lk, Hc, n_g2, D] -> permute to [B, H, Hc, Lk, n_g2, D]
            k_g2_t = k_g2.permute(0, 1, 3, 2, 4, 5)
            k_g2_2d = k_g2_t.reshape(B * H * Hc, Lk * n_g2, D)
            # [B*H*Hc, Lq, D] @ [B*H*Hc, D, Lk*n_g2] -> [B*H*Hc, Lq, Lk*n_g2]
            comp = torch.bmm(q_2d, k_g2_2d.transpose(-2, -1))
            # Sum squared components over n_g2, then sum over Hc -> [B, H, Lq, Lk]
            comp_sq = comp.reshape(B * H * Hc, Lq, Lk, n_g2).pow(2).sum(-1)  # [B*H*Hc, Lq, Lk]
            score_g2_sq = comp_sq.reshape(B, H, Hc, Lq, Lk).sum(2)           # [B, H, Lq, Lk]
            score_g2 = score_g2_sq.sqrt()
        else:
            score_g2 = torch.zeros_like(score_g0)

        # Combined score
        scale = math.sqrt(self.head_channels * self.algebra.dim)
        return (score_g0 + self.bivector_weight * score_g2) / scale

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """Computes geometric product attention.

        Args:
            x: Input multivectors [B, L, C, D].
            key_padding_mask: Optional [B, L] bool mask where True = padded (ignored).

        Returns:
            Output multivectors [B, L, C, D].
        """
        B, L, C, D = x.shape

        # Project Q, K, V (CliffordLinear expects [B, C, D])
        x_flat = x.reshape(B * L, C, D)
        Q = self.q_proj(x_flat).reshape(B, L, C, D)
        K = self.k_proj(x_flat).reshape(B, L, C, D)
        V = self.v_proj(x_flat).reshape(B, L, C, D)

        H = self.num_heads
        Hc = self.head_channels

        # Reshape to [B, H, L, Hc, D]
        Q = Q.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)  # [B, H, L, Hc, D]
        K = K.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)
        V = V.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)

        # Build causal mask once [L, L]
        if self.causal:
            causal_mask = torch.triu(
                torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1
            )  # True = masked (future)
        else:
            causal_mask = None

        # Precompute K_g2 once for all query blocks - much cheaper than recomputing
        # k_g2[b,h,j,c,r,d] = K[b,h,j,c, d^r_val] * g2_sign[r, d]
        # Shape: [B, H, L, Hc, n_g2, D]  ~= 768 KB for the small MPS config
        K_g2 = K[..., self._g2_b_idx] * self._g2_sign  # [B, H, L, Hc, n_g2, D]

        # Chunked attention over query positions to bound memory
        output_chunks = []
        for q_start in range(0, L, _BLOCK_SIZE):
            q_end = min(q_start + _BLOCK_SIZE, L)

            Q_block = Q[:, :, q_start:q_end]  # [B, H, Lq, Hc, D]

            # Compute scores: [B, H, Lq, L]
            scores = self._compute_score(Q_block, K, K_g2)

            # Apply causal mask
            if causal_mask is not None:
                mask_block = causal_mask[q_start:q_end, :]  # [Lq, L]
                scores = scores.masked_fill(
                    mask_block.unsqueeze(0).unsqueeze(0), float('-inf')
                )

            # Apply key padding mask: True = padded -> -inf
            if key_padding_mask is not None:
                # key_padding_mask: [B, L] -> [B, 1, 1, L]
                scores = scores.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
                )

            # Softmax + dropout
            attn_weights = F.softmax(scores, dim=-1)  # [B, H, Lq, L]
            if self.attn_dropout is not None:
                attn_weights = self.attn_dropout(attn_weights)

            # Aggregate values: sum_k attn[b,h,i,k] * V[b,h,k,Hc,D]
            # attn_weights: [B, H, Lq, L]
            # V:            [B, H, L,  Hc, D]
            # out:          [B, H, Lq, Hc, D]
            out_block = torch.einsum('bhij,bhjcd->bhicd', attn_weights, V)
            output_chunks.append(out_block)

        # Reassemble: [B, H, L, Hc, D]
        output = torch.cat(output_chunks, dim=2)

        # Merge heads back: [B, L, C, D]
        output = output.permute(0, 2, 1, 3, 4).reshape(B, L, C, D)

        # Output projection
        output = self.out_proj(output.reshape(B * L, C, D)).reshape(B, L, C, D)

        return output

__init__(algebra, channels, num_heads, causal=True, bivector_weight=0.5, dropout=0.0)

Sets up geometric product attention.

Parameters:

- algebra (CliffordAlgebra): Clifford algebra instance. Required.
- channels (int): Total number of multivector channels. Required.
- num_heads (int): Number of attention heads. Required.
- causal (bool): Apply causal mask for autoregressive generation. Default: True.
- bivector_weight (float): lambda_ weight on bivector score component. Default: 0.5.
- dropout (float): Dropout rate on attention weights. Default: 0.0.
Source code in layers/blocks/attention.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    channels: int,
    num_heads: int,
    causal: bool = True,
    bivector_weight: float = 0.5,
    dropout: float = 0.0,
):
    """Sets up geometric product attention.

    Args:
        algebra: Clifford algebra instance.
        channels: Total number of multivector channels.
        num_heads: Number of attention heads.
        causal: Apply causal mask for autoregressive generation.
        bivector_weight: lambda_ weight on bivector score component.
        dropout: Dropout rate on attention weights.
    """
    super().__init__(algebra)
    assert channels % num_heads == 0, \
        f"channels ({channels}) must be divisible by num_heads ({num_heads})"

    self.channels = channels
    self.num_heads = num_heads
    self.head_channels = channels // num_heads
    self.causal = causal
    self.bivector_weight = bivector_weight

    # Q, K, V projections operate on [B*L, channels, dim]
    self.q_proj = CliffordLinear(algebra, channels, channels)
    self.k_proj = CliffordLinear(algebra, channels, channels)
    self.v_proj = CliffordLinear(algebra, channels, channels)
    self.out_proj = CliffordLinear(algebra, channels, channels)

    self.attn_dropout = nn.Dropout(dropout) if dropout > 0.0 else None

    # Precompute bilinear score tables (replaces pairwise geometric product)
    self._precompute_score_tables()

forward(x, key_padding_mask=None)

Computes geometric product attention.

Parameters:

- x (Tensor): Input multivectors [B, L, C, D]. Required.
- key_padding_mask (Tensor): Optional [B, L] bool mask where True = padded (ignored). Default: None.

Returns:

- Tensor: Output multivectors [B, L, C, D].

Source code in layers/blocks/attention.py
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
    """Computes geometric product attention.

    Args:
        x: Input multivectors [B, L, C, D].
        key_padding_mask: Optional [B, L] bool mask where True = padded (ignored).

    Returns:
        Output multivectors [B, L, C, D].
    """
    B, L, C, D = x.shape

    # Project Q, K, V (CliffordLinear expects [B, C, D])
    x_flat = x.reshape(B * L, C, D)
    Q = self.q_proj(x_flat).reshape(B, L, C, D)
    K = self.k_proj(x_flat).reshape(B, L, C, D)
    V = self.v_proj(x_flat).reshape(B, L, C, D)

    H = self.num_heads
    Hc = self.head_channels

    # Reshape to [B, H, L, Hc, D]
    Q = Q.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)  # [B, H, L, Hc, D]
    K = K.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)
    V = V.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)

    # Build causal mask once [L, L]
    if self.causal:
        causal_mask = torch.triu(
            torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1
        )  # True = masked (future)
    else:
        causal_mask = None

    # Precompute K_g2 once for all query blocks - much cheaper than recomputing
    # k_g2[b,h,j,c,r,d] = K[b,h,j,c, d^r_val] * g2_sign[r, d]
    # Shape: [B, H, L, Hc, n_g2, D]  ~= 768 KB for the small MPS config
    K_g2 = K[..., self._g2_b_idx] * self._g2_sign  # [B, H, L, Hc, n_g2, D]

    # Chunked attention over query positions to bound memory
    output_chunks = []
    for q_start in range(0, L, _BLOCK_SIZE):
        q_end = min(q_start + _BLOCK_SIZE, L)

        Q_block = Q[:, :, q_start:q_end]  # [B, H, Lq, Hc, D]

        # Compute scores: [B, H, Lq, L]
        scores = self._compute_score(Q_block, K, K_g2)

        # Apply causal mask
        if causal_mask is not None:
            mask_block = causal_mask[q_start:q_end, :]  # [Lq, L]
            scores = scores.masked_fill(
                mask_block.unsqueeze(0).unsqueeze(0), float('-inf')
            )

        # Apply key padding mask: True = padded -> -inf
        if key_padding_mask is not None:
            # key_padding_mask: [B, L] -> [B, 1, 1, L]
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
            )

        # Softmax + dropout
        attn_weights = F.softmax(scores, dim=-1)  # [B, H, Lq, L]
        if self.attn_dropout is not None:
            attn_weights = self.attn_dropout(attn_weights)

        # Aggregate values: sum_k attn[b,h,i,k] * V[b,h,k,Hc,D]
        # attn_weights: [B, H, Lq, L]
        # V:            [B, H, L,  Hc, D]
        # out:          [B, H, Lq, Hc, D]
        out_block = torch.einsum('bhij,bhjcd->bhicd', attn_weights, V)
        output_chunks.append(out_block)

    # Reassemble: [B, H, L, Hc, D]
    output = torch.cat(output_chunks, dim=2)

    # Merge heads back: [B, L, C, D]
    output = output.permute(0, 2, 1, 3, 4).reshape(B, L, C, D)

    # Output projection
    output = self.out_proj(output.reshape(B * L, C, D)).reshape(B, L, C, D)

    return output

MultiRotorFFN

Bases: CliffordModule

Embedded Geometric Toolbox - Feed-Forward Network via rotor superposition.

Standard transformers use: Linear -> GELU -> Linear. This replaces that with:

CliffordLinear(expand) -> CliffordLayerNorm
    -> MultiRotorLayer(K rotors) -> GeometricGELU
    -> CliffordLinear(contract) -> BladeSelector

The expand step lifts x into a ffn_mult x channels toolbox subspace. MultiRotorLayer applies K parallel rotors, each exploring a different rotation plane - this IS the nonlinearity, not just a scalar gate. The contract step projects back to the original channel count.
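Shape-wise, the block is a standard bottleneck: [B, C, D] -> [B, C*ffn_mult, D] -> [B, C, D], with the rotor stage acting in the widened space. A minimal numpy sketch of just this channel bookkeeping (tanh stands in for the rotor and GeometricGELU stages; this is not the Versor implementation):

```python
import numpy as np

def toolbox_ffn_shapes(x, ffn_mult=4, seed=0):
    """Channel expand -> nonlinearity -> channel contract, applied per blade."""
    B, C, D = x.shape
    rng = np.random.default_rng(seed)
    W_up = rng.normal(scale=0.02, size=(ffn_mult * C, C))
    W_down = rng.normal(scale=0.02, size=(C, ffn_mult * C))
    h = np.einsum('fc,bcd->bfd', W_up, x)       # [B, C*ffn_mult, D]
    h = np.tanh(h)                              # stand-in nonlinearity
    return np.einsum('cf,bfd->bcd', W_down, h)  # back to [B, C, D]
```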

Designed as a standalone module so it can be reused in other tasks (md17, pdbbind, etc.) beyond the language model.

Parameters:

- algebra (CliffordAlgebra): The algebra instance. Required.
- channels (int): Input/output channel count. Required.
- ffn_mult (int): Expansion factor (ffn_channels = channels * ffn_mult). Default: 4.
- num_rotors (int): Number of parallel rotors K in the toolbox. Default: 8.
- use_rotor_backend (bool): Use RotorGadget backend for CliffordLinear. Default: False.

Input/Output shape: [B, C, D] where D = algebra.dim.

Source code in layers/blocks/multi_rotor_ffn.py
class MultiRotorFFN(CliffordModule):
    """Embedded Geometric Toolbox - Feed-Forward Network via rotor superposition.

    Standard transformers use: Linear -> GELU -> Linear.
    This replaces that with:

        CliffordLinear(expand) -> CliffordLayerNorm
            -> MultiRotorLayer(K rotors) -> GeometricGELU
            -> CliffordLinear(contract) -> BladeSelector

    The expand step lifts x into a ``ffn_mult x channels`` toolbox subspace.
    ``MultiRotorLayer`` applies K parallel rotors, each exploring a different
    rotation plane - this IS the nonlinearity, not just a scalar gate.
    The contract step projects back to the original channel count.

    Designed as a standalone module so it can be reused in other tasks
    (md17, pdbbind, etc.) beyond the language model.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Input/output channel count.
        ffn_mult (int): Expansion factor (ffn_channels = channels * ffn_mult).
        num_rotors (int): Number of parallel rotors K in the toolbox.
        use_rotor_backend (bool): Use RotorGadget backend for CliffordLinear.

    Input/Output shape: ``[B, C, D]`` where D = algebra.dim.
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        ffn_mult: int = 4,
        num_rotors: int = 8,
        use_rotor_backend: bool = False,
    ):
        super().__init__(algebra)
        self.channels = channels
        ffn_channels = channels * ffn_mult
        backend = 'rotor' if use_rotor_backend else 'traditional'

        self.expand = CliffordLinear(algebra, channels, ffn_channels, backend=backend)
        self.norm = CliffordLayerNorm(algebra, ffn_channels)
        self.toolbox = MultiRotorLayer(algebra, ffn_channels, num_rotors)
        self.act = GeometricGELU(algebra, channels=ffn_channels)
        self.contract = CliffordLinear(algebra, ffn_channels, channels, backend=backend)
        self.gate = BladeSelector(algebra, channels)

    def forward(self, x) -> torch.Tensor:
        """Applies the geometric toolbox FFN.

        Args:
            x (torch.Tensor): Input ``[B, C, D]``.

        Returns:
            torch.Tensor: Output ``[B, C, D]``.
        """
        h = self.expand(x)    # [B, ffn_channels, D]
        h = self.norm(h)      # [B, ffn_channels, D]
        h = self.toolbox(h)   # [B, ffn_channels, D]  - K-rotor superposition
        h = self.act(h)       # [B, ffn_channels, D]
        h = self.contract(h)  # [B, channels, D]
        h = self.gate(h)      # [B, channels, D]      - per-blade gating
        return h

forward(x)

Applies the geometric toolbox FFN.

Parameters:

- x (Tensor): Input [B, C, D]. Required.

Returns:

- torch.Tensor: Output [B, C, D].

Source code in layers/blocks/multi_rotor_ffn.py
def forward(self, x) -> torch.Tensor:
    """Applies the geometric toolbox FFN.

    Args:
        x (torch.Tensor): Input ``[B, C, D]``.

    Returns:
        torch.Tensor: Output ``[B, C, D]``.
    """
    h = self.expand(x)    # [B, ffn_channels, D]
    h = self.norm(h)      # [B, ffn_channels, D]
    h = self.toolbox(h)   # [B, ffn_channels, D]  - K-rotor superposition
    h = self.act(h)       # [B, ffn_channels, D]
    h = self.contract(h)  # [B, channels, D]
    h = self.gate(h)      # [B, channels, D]      - per-blade gating
    return h

GeometricTransformerBlock

Bases: CliffordModule

Modular Geometric Transformer block.

Architecture:

1. Pre-norm
2. Geometric Attention (Standard or Entropy-Gated)
3. Residual connection
4. Pre-norm
5. Multi-Rotor FFN
6. Residual connection

Source code in layers/blocks/transformer.py
class GeometricTransformerBlock(CliffordModule):
    """Modular Geometric Transformer block.

    Architecture:
    1. Pre-norm
    2. Geometric Attention (Standard or Entropy-Gated)
    3. Residual connection
    4. Pre-norm
    5. Multi-Rotor FFN
    6. Residual connection
    """
    def __init__(
        self, 
        algebra: CliffordAlgebra, 
        channels: int, 
        num_heads: int = 4, 
        num_rotors: int = 8, 
        dropout: float = 0.1, 
        use_entropy_gating: bool = False, 
        eta: float = 1.5, 
        H_base: float = 0.5
    ):
        """Initializes the Geometric Transformer Block.

        Args:
            algebra: Clifford algebra instance.
            channels: Total multivector channels.
            num_heads: Number of attention heads.
            num_rotors: Number of rotors in the FFN.
            dropout: Dropout rate.
            use_entropy_gating: If True, uses EntropyGatedAttention.
            eta: Gating multiplier for entropy attention.
            H_base: Base entropy threshold.
        """
        super().__init__(algebra)
        self.use_entropy_gating = use_entropy_gating
        self.norm1 = CliffordLayerNorm(algebra, channels)

        if use_entropy_gating:
            self.attn = EntropyGatedAttention(algebra, channels, num_heads, eta=eta, H_base=H_base)
        else:
            self.attn = GeometricProductAttention(algebra, channels, num_heads, causal=False, dropout=dropout)

        self.norm2 = CliffordLayerNorm(algebra, channels)

        # Local import of the FFN block from layers/blocks/multi_rotor_ffn.py
        from .multi_rotor_ffn import MultiRotorFFN
        self.ffn = MultiRotorFFN(algebra, channels, num_rotors=num_rotors)

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None,
                return_state: bool = False) -> torch.Tensor:
        """Forward pass through the transformer block.

        Args:
            x: Input multivectors [B, L, C, D].
            key_padding_mask: Optional [B, L] bool mask where True = padded.
            return_state: If True, returns intermediate entropy/gating states.

        Returns:
            Processed multivectors [B, L, C, D] (and optionally intermediate states).
        """
        B, L, C, D = x.shape

        # 1. Attention path
        res = x
        x_n = self.norm1(x.reshape(B*L, C, D)).reshape(B, L, C, D)

        if self.use_entropy_gating and return_state:
            attn_out, H, lambda_dyn = self.attn(x_n, key_padding_mask=key_padding_mask, return_gating=True)
        else:
            attn_out = self.attn(x_n, key_padding_mask=key_padding_mask)
            H, lambda_dyn = None, None

        x = res + attn_out

        # 2. FFN path
        res = x
        x_n = self.norm2(x.reshape(B*L, C, D)).reshape(B, L, C, D)
        f_out = self.ffn(x_n.reshape(B*L, C, D)).reshape(B, L, C, D)
        x = res + f_out

        if return_state:
            return x, H, lambda_dyn
        return x

__init__(algebra, channels, num_heads=4, num_rotors=8, dropout=0.1, use_entropy_gating=False, eta=1.5, H_base=0.5)

Initializes the Geometric Transformer Block.

Parameters:

- algebra (CliffordAlgebra): Clifford algebra instance. Required.
- channels (int): Total multivector channels. Required.
- num_heads (int): Number of attention heads. Default: 4.
- num_rotors (int): Number of rotors in the FFN. Default: 8.
- dropout (float): Dropout rate. Default: 0.1.
- use_entropy_gating (bool): If True, uses EntropyGatedAttention. Default: False.
- eta (float): Gating multiplier for entropy attention. Default: 1.5.
- H_base (float): Base entropy threshold. Default: 0.5.
Source code in layers/blocks/transformer.py
def __init__(
    self, 
    algebra: CliffordAlgebra, 
    channels: int, 
    num_heads: int = 4, 
    num_rotors: int = 8, 
    dropout: float = 0.1, 
    use_entropy_gating: bool = False, 
    eta: float = 1.5, 
    H_base: float = 0.5
):
    """Initializes the Geometric Transformer Block.

    Args:
        algebra: Clifford algebra instance.
        channels: Total multivector channels.
        num_heads: Number of attention heads.
        num_rotors: Number of rotors in the FFN.
        dropout: Dropout rate.
        use_entropy_gating: If True, uses EntropyGatedAttention.
        eta: Gating multiplier for entropy attention.
        H_base: Base entropy threshold.
    """
    super().__init__(algebra)
    self.use_entropy_gating = use_entropy_gating
    self.norm1 = CliffordLayerNorm(algebra, channels)

    if use_entropy_gating:
        self.attn = EntropyGatedAttention(algebra, channels, num_heads, eta=eta, H_base=H_base)
    else:
        self.attn = GeometricProductAttention(algebra, channels, num_heads, causal=False, dropout=dropout)

    self.norm2 = CliffordLayerNorm(algebra, channels)

    # Local import of the FFN block from layers/blocks/multi_rotor_ffn.py
    from .multi_rotor_ffn import MultiRotorFFN
    self.ffn = MultiRotorFFN(algebra, channels, num_rotors=num_rotors)

forward(x, key_padding_mask=None, return_state=False)

Forward pass through the transformer block.

Parameters:

- x (Tensor): Input multivectors [B, L, C, D]. Required.
- key_padding_mask (Tensor): Optional [B, L] bool mask where True = padded. Default: None.
- return_state (bool): If True, returns intermediate entropy/gating states. Default: False.

Returns:

- Tensor: Processed multivectors [B, L, C, D] (and optionally intermediate states).

Source code in layers/blocks/transformer.py
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None,
            return_state: bool = False) -> torch.Tensor:
    """Forward pass through the transformer block.

    Args:
        x: Input multivectors [B, L, C, D].
        key_padding_mask: Optional [B, L] bool mask where True = padded.
        return_state: If True, returns intermediate entropy/gating states.

    Returns:
        Processed multivectors [B, L, C, D] (and optionally intermediate states).
    """
    B, L, C, D = x.shape

    # 1. Attention path
    res = x
    x_n = self.norm1(x.reshape(B*L, C, D)).reshape(B, L, C, D)

    if self.use_entropy_gating and return_state:
        attn_out, H, lambda_dyn = self.attn(x_n, key_padding_mask=key_padding_mask, return_gating=True)
    else:
        attn_out = self.attn(x_n, key_padding_mask=key_padding_mask)
        H, lambda_dyn = None, None

    x = res + attn_out

    # 2. FFN path
    res = x
    x_n = self.norm2(x.reshape(B*L, C, D)).reshape(B, L, C, D)
    f_out = self.ffn(x_n.reshape(B*L, C, D)).reshape(B, L, C, D)
    x = res + f_out

    if return_state:
        return x, H, lambda_dyn
    return x

Adapters

MultivectorEmbedding

Bases: CliffordModule

Token embedding as multivectors.

Each token maps to a [channels, dim] multivector. Initializes content in grade-1 (vector) subspace only - semantic content starts as directed quantities before rotors act on them.
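Concretely, for a 3D algebra (dim = 8) the grade-1 blades sit at indices with exactly one set bit, so only those flat embedding columns receive nonzero init. A plain-Python check mirroring the loop in _init_grade1:

```python
dim = 8  # 2**3 blades for 3 basis vectors
grade1 = [i for i in range(dim) if bin(i).count('1') == 1]
print(grade1)  # [1, 2, 4] -> e1, e2, e3

channels = 2
flat_cols = [ch * dim + idx for ch in range(channels) for idx in grade1]
print(flat_cols)  # [1, 2, 4, 9, 10, 12] - all other columns stay zero
```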

Attributes:

- vocab_size (int): Number of tokens.
- channels (int): Number of multivector channels.
- embedding (nn.Embedding): Underlying embedding table.

Source code in layers/adapters/embedding.py
class MultivectorEmbedding(CliffordModule):
    """Token embedding as multivectors.

    Each token maps to a [channels, dim] multivector. Initializes
    content in grade-1 (vector) subspace only - semantic content
    starts as directed quantities before rotors act on them.

    Attributes:
        vocab_size (int): Number of tokens.
        channels (int): Number of multivector channels.
        embedding (nn.Embedding): Underlying embedding table.
    """

    def __init__(self, algebra: CliffordAlgebra, vocab_size: int, channels: int):
        """Sets up the multivector embedding.

        Args:
            algebra: Clifford algebra instance.
            vocab_size: Vocabulary size.
            channels: Number of multivector channels per token.
        """
        super().__init__(algebra)
        self.vocab_size = vocab_size
        self.channels = channels

        # Single flat embedding: vocab_size -> channels * dim
        self.embedding = nn.Embedding(vocab_size, channels * algebra.dim)
        self._init_grade1()

    def _init_grade1(self):
        """Initializes only grade-1 components; zeros out all others."""
        with torch.no_grad():
            dim = self.algebra.dim
            channels = self.channels

            # Build grade-1 mask (indices with exactly 1 bit set)
            grade1_flat = []
            for i in range(dim):
                if bin(i).count('1') == 1:
                    grade1_flat.append(i)

            # Zero everything
            self.embedding.weight.zero_()

            # Fill grade-1 slots with small normal values
            for ch in range(channels):
                for idx in grade1_flat:
                    flat_idx = ch * dim + idx
                    self.embedding.weight[:, flat_idx].normal_(std=0.02)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Maps token ids to multivector embeddings.

        Args:
            token_ids: Token indices [B, L].

        Returns:
            Multivector embeddings [B, L, channels, dim].
        """
        B, L = token_ids.shape
        flat = self.embedding(token_ids)  # [B, L, channels * dim]
        return flat.reshape(B, L, self.channels, self.algebra.dim)

__init__(algebra, vocab_size, channels)

Sets up the multivector embedding.

Parameters:

- algebra (CliffordAlgebra): Clifford algebra instance. Required.
- vocab_size (int): Vocabulary size. Required.
- channels (int): Number of multivector channels per token. Required.
Source code in layers/adapters/embedding.py
def __init__(self, algebra: CliffordAlgebra, vocab_size: int, channels: int):
    """Sets up the multivector embedding.

    Args:
        algebra: Clifford algebra instance.
        vocab_size: Vocabulary size.
        channels: Number of multivector channels per token.
    """
    super().__init__(algebra)
    self.vocab_size = vocab_size
    self.channels = channels

    # Single flat embedding: vocab_size -> channels * dim
    self.embedding = nn.Embedding(vocab_size, channels * algebra.dim)
    self._init_grade1()

forward(token_ids)

Maps token ids to multivector embeddings.

Parameters:

- token_ids (Tensor): Token indices [B, L]. Required.

Returns:

- Tensor: Multivector embeddings [B, L, channels, dim].

Source code in layers/adapters/embedding.py
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
    """Maps token ids to multivector embeddings.

    Args:
        token_ids: Token indices [B, L].

    Returns:
        Multivector embeddings [B, L, channels, dim].
    """
    B, L = token_ids.shape
    flat = self.embedding(token_ids)  # [B, L, channels * dim]
    return flat.reshape(B, L, self.channels, self.algebra.dim)

MotherEmbedding

Bases: CliffordModule

Embeds local feature groups into a canonical Mother Algebra with Procrustes Alignment.

Uses fixed rotors (R_fixed) to rotate individual channel vectors into a shared reference frame, effectively aligning disparate geometric manifolds.

Source code in layers/adapters/mother.py
class MotherEmbedding(CliffordModule):
    """Embeds local feature groups into a canonical Mother Algebra with Procrustes Alignment.

    Uses fixed rotors (R_fixed) to rotate individual channel vectors into a shared
    reference frame, effectively aligning disparate geometric manifolds.
    """
    def __init__(self, algebra: CliffordAlgebra, input_dim: int, channels: int, U: float = 0.0, V: torch.Tensor = None):
        """Initializes the Mother Embedding.

        Args:
            algebra: Clifford algebra instance.
            input_dim: Dimension of the input features.
            channels: Number of multivector channels.
            U: Geometric uncertainty index for manifold suppression.
            V: Fixed rotor proxy for Procrustes alignment (input_dim x input_dim).
        """
        super().__init__(algebra)
        self.channels = channels

        # Procrustes Alignment Matrix (Fixed Rotor Proxy)
        if V is None:
            V = torch.eye(input_dim)
        self.register_buffer('R_fixed', V)

        # Up-cast to Mother Algebra multivector channels
        self.linear = nn.Linear(input_dim, channels * algebra.dim)
        self.norm = CliffordLayerNorm(algebra, channels)

        # Pre-condition LayerNorm scale with Uncertainty Index
        with torch.no_grad():
            if hasattr(self.norm, 'weight'):
                # Suppress highly uncertain (twisted) manifolds initially
                scale = 1.0 / (1.0 + U)
                self.norm.weight.data.fill_(scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects input into the aligned mother manifold.

        Args:
            x: Input features [B, input_dim].

        Returns:
            Aligned multivectors [B, channels, dim].
        """
        # 1. Apply Geometric Procrustes Alignment
        if self.R_fixed is not None:
            x = x @ self.R_fixed.T

        # 2. Mother Projection
        c = self.linear(x).view(-1, self.channels, self.algebra.dim)
        return self.norm(c)
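The fixed rotor proxy ``V`` is an orthogonal alignment matrix applied before projection. One standard way to construct such a matrix from paired samples of two feature frames, shown here as a sketch with made-up data (the helper name and data are assumptions, not part of the library), is the orthogonal Procrustes solution via SVD:

```python
import torch

def procrustes_rotor_proxy(X, Y):
    # Orthogonal Procrustes: find orthogonal R minimizing ||X @ R.T - Y||_F.
    # X, Y: [N, input_dim] paired feature samples from two frames.
    M = Y.T @ X                      # cross-covariance [d, d]
    U, _, Vh = torch.linalg.svd(M)
    return U @ Vh                    # orthogonal alignment matrix

torch.manual_seed(0)
X = torch.randn(128, 8)
# Build a known rotation in the first two coordinates.
angle = torch.tensor(0.3)
R_true = torch.eye(8)
R_true[0, 0] = torch.cos(angle); R_true[0, 1] = -torch.sin(angle)
R_true[1, 0] = torch.sin(angle); R_true[1, 1] = torch.cos(angle)
Y = X @ R_true.T

R_est = procrustes_rotor_proxy(X, Y)
print(torch.allclose(R_est, R_true, atol=1e-4))  # True
```

The recovered matrix is applied exactly as in ``forward`` above: ``x @ R.T`` maps features from one frame into the shared reference frame.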

__init__(algebra, input_dim, channels, U=0.0, V=None)

Initializes the Mother Embedding.

Parameters:

    Name       Type             Description                                                          Default
    algebra    CliffordAlgebra  Clifford algebra instance.                                           required
    input_dim  int              Dimension of the input features.                                     required
    channels   int              Number of multivector channels.                                      required
    U          float            Geometric uncertainty index for manifold suppression.                0.0
    V          Tensor           Fixed rotor proxy for Procrustes alignment (input_dim x input_dim).  None
Source code in layers/adapters/mother.py
def __init__(self, algebra: CliffordAlgebra, input_dim: int, channels: int, U: float = 0.0, V: torch.Tensor = None):
    """Initializes the Mother Embedding.

    Args:
        algebra: Clifford algebra instance.
        input_dim: Dimension of the input features.
        channels: Number of multivector channels.
        U: Geometric uncertainty index for manifold suppression.
        V: Fixed rotor proxy for Procrustes alignment (input_dim x input_dim).
    """
    super().__init__(algebra)
    self.channels = channels

    # Procrustes Alignment Matrix (Fixed Rotor Proxy)
    if V is None:
        V = torch.eye(input_dim)
    self.register_buffer('R_fixed', V)

    # Up-cast to Mother Algebra multivector channels
    self.linear = nn.Linear(input_dim, channels * algebra.dim)
    self.norm = CliffordLayerNorm(algebra, channels)

    # Pre-condition LayerNorm scale with Uncertainty Index
    with torch.no_grad():
        if hasattr(self.norm, 'weight'):
            # Suppress highly uncertain (twisted) manifolds initially
            scale = 1.0 / (1.0 + U)
            self.norm.weight.data.fill_(scale)

forward(x)

Projects input into the aligned mother manifold.

Parameters:

    Name  Type    Description                     Default
    x     Tensor  Input features [B, input_dim].  required

Returns:

    Type    Description
    Tensor  Aligned multivectors [B, channels, dim].

Source code in layers/adapters/mother.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Projects input into the aligned mother manifold.

    Args:
        x: Input features [B, input_dim].

    Returns:
        Aligned multivectors [B, channels, dim].
    """
    # 1. Apply Geometric Procrustes Alignment
    if self.R_fixed is not None:
        x = x @ self.R_fixed.T

    # 2. Mother Projection
    c = self.linear(x).view(-1, self.channels, self.algebra.dim)
    return self.norm(c)

EntropyGatedAttention

Bases: CliffordModule

Dynamic geometric attention governed by bivector information entropy.

Segments with high bivector entropy (disordered phase states) are "stiffened" or suppressed, allowing only coherent, synchronized states to propagate.

Source code in layers/adapters/mother.py
class EntropyGatedAttention(CliffordModule):
    """Dynamic geometric attention governed by bivector information entropy.

    Segments with high bivector entropy (disordered phase states) are "stiffened" 
    or suppressed, allowing only coherent, synchronized states to propagate.
    """
    def __init__(self, algebra: CliffordAlgebra, channels: int, num_heads: int, eta: float = 1.0, H_base: float = 0.5):
        """Initializes Entropy-Gated Attention.

        Args:
            algebra: Clifford algebra instance.
            channels: Total multivector channels.
            num_heads: Number of attention heads.
            eta: Gating multiplier.
            H_base: Base entropy threshold.
        """
        super().__init__(algebra)
        self.channels = channels
        self.eta = eta
        self.H_base = H_base
        self.base_attention = GeometricProductAttention(algebra, channels, num_heads, causal=False)

        # Cache bivector indices and float mask for compile-friendly gating
        mask = self.algebra.grade_masks[2]
        self.register_buffer('g2_idx', mask.nonzero(as_tuple=True)[0])
        self.register_buffer('_g2_float_mask', mask.float())

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None,
                return_gating: bool = False) -> torch.Tensor:
        """Applies entropy-gated geometric attention.

        Args:
            x: Input multivectors [B, L, C, D].
            key_padding_mask: Optional [B, L] bool mask where True = padded.
            return_gating: If True, returns entropy and gating values.

        Returns:
            Attended multivectors [B, L, C, D], or a tuple
            (out, H, lambda_dyn) if return_gating is True.
        """
        # 1. Calculate Information Entropy of Bivector Energy
        # Summing across multivector components (g2_idx) and across channels (dim 2)
        # x: [B, L, C, D]
        g2_energy = (x[..., self.g2_idx]**2).sum(dim=(-1, -2)) # [B, L]

        # Mask padded positions before entropy calc
        if key_padding_mask is not None:
            g2_energy = g2_energy.masked_fill(key_padding_mask, 0.0)

        # Normalize to probability distribution over sequence
        p = g2_energy / (g2_energy.sum(dim=1, keepdim=True) + 1e-8)

        # Shannon Entropy H per batch [B]
        H = -(p * torch.log(p + 1e-8)).sum(dim=1)

        # 2. Base-Adjusted Gating Function
        lambda_dyn = self.eta * torch.sigmoid(H - self.H_base) # [B]

        # 3. Apply dynamic geometric stiffness
        # Scale the rotational components (bivectors)
        lambda_view = lambda_dyn.view(-1, 1, 1, 1)

        g2_mask = self._g2_float_mask.to(dtype=x.dtype)
        scale = 1.0 + (lambda_view - 1.0) * g2_mask  # [B, 1, 1, D]
        x_gated = x * scale

        out = self.base_attention(x_gated, key_padding_mask=key_padding_mask)

        if return_gating:
            return out, H, lambda_dyn
        return out
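Steps 1 and 2 of the forward pass can be exercised in isolation. A minimal sketch (toy energies; `entropy_gate` is a hypothetical helper mirroring the code above, not a library function):

```python
import torch

def entropy_gate(g2_energy, eta=1.0, H_base=0.5, eps=1e-8):
    # g2_energy: [B, L] bivector energy per sequence position.
    p = g2_energy / (g2_energy.sum(dim=1, keepdim=True) + eps)
    H = -(p * torch.log(p + eps)).sum(dim=1)        # Shannon entropy, [B]
    return eta * torch.sigmoid(H - H_base), H

# Energy concentrated in one position: H near 0.
concentrated = torch.tensor([[10.0, 0.0, 0.0, 0.0]])
# Energy spread evenly over 4 positions: H = log(4) ~ 1.386, the maximum.
uniform = torch.tensor([[2.5, 2.5, 2.5, 2.5]])

lam_c, H_c = entropy_gate(concentrated)
lam_u, H_u = entropy_gate(uniform)
print(H_c.item() < H_u.item(), lam_c.item() < lam_u.item())  # True True
```

Because the gate is a monotone sigmoid of H, a flat energy profile maximizes the per-batch entropy and hence the bivector scale lambda; the scale is then broadcast onto the bivector components only, via the grade-2 float mask.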

__init__(algebra, channels, num_heads, eta=1.0, H_base=0.5)

Initializes Entropy-Gated Attention.

Parameters:

    Name       Type             Description                  Default
    algebra    CliffordAlgebra  Clifford algebra instance.   required
    channels   int              Total multivector channels.  required
    num_heads  int              Number of attention heads.   required
    eta        float            Gating multiplier.           1.0
    H_base     float            Base entropy threshold.      0.5
Source code in layers/adapters/mother.py
def __init__(self, algebra: CliffordAlgebra, channels: int, num_heads: int, eta: float = 1.0, H_base: float = 0.5):
    """Initializes Entropy-Gated Attention.

    Args:
        algebra: Clifford algebra instance.
        channels: Total multivector channels.
        num_heads: Number of attention heads.
        eta: Gating multiplier.
        H_base: Base entropy threshold.
    """
    super().__init__(algebra)
    self.channels = channels
    self.eta = eta
    self.H_base = H_base
    self.base_attention = GeometricProductAttention(algebra, channels, num_heads, causal=False)

    # Cache bivector indices and float mask for compile-friendly gating
    mask = self.algebra.grade_masks[2]
    self.register_buffer('g2_idx', mask.nonzero(as_tuple=True)[0])
    self.register_buffer('_g2_float_mask', mask.float())

forward(x, key_padding_mask=None, return_gating=False)

Applies entropy-gated geometric attention.

Parameters:

    Name              Type    Description                                     Default
    x                 Tensor  Input multivectors [B, L, C, D].                required
    key_padding_mask  Tensor  Optional [B, L] bool mask where True = padded.  None
    return_gating     bool    If True, returns entropy and gating values.     False

Returns:

    Type    Description
    Tensor  Attended multivectors [B, L, C, D], or a tuple (out, H, lambda_dyn) if return_gating is True.

Source code in layers/adapters/mother.py
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None,
            return_gating: bool = False) -> torch.Tensor:
    """Applies entropy-gated geometric attention.

    Args:
        x: Input multivectors [B, L, C, D].
        key_padding_mask: Optional [B, L] bool mask where True = padded.
        return_gating: If True, returns entropy and gating values.

    Returns:
        Attended multivectors [B, L, C, D], or a tuple
        (out, H, lambda_dyn) if return_gating is True.
    """
    # 1. Calculate Information Entropy of Bivector Energy
    # Summing across multivector components (g2_idx) and across channels (dim 2)
    # x: [B, L, C, D]
    g2_energy = (x[..., self.g2_idx]**2).sum(dim=(-1, -2)) # [B, L]

    # Mask padded positions before entropy calc
    if key_padding_mask is not None:
        g2_energy = g2_energy.masked_fill(key_padding_mask, 0.0)

    # Normalize to probability distribution over sequence
    p = g2_energy / (g2_energy.sum(dim=1, keepdim=True) + 1e-8)

    # Shannon Entropy H per batch [B]
    H = -(p * torch.log(p + 1e-8)).sum(dim=1)

    # 2. Base-Adjusted Gating Function
    lambda_dyn = self.eta * torch.sigmoid(H - self.H_base) # [B]

    # 3. Apply dynamic geometric stiffness
    # Scale the rotational components (bivectors)
    lambda_view = lambda_dyn.view(-1, 1, 1, 1)

    g2_mask = self._g2_float_mask.to(dtype=x.dtype)
    scale = 1.0 + (lambda_view - 1.0) * g2_mask  # [B, 1, 1, D]
    x_gated = x * scale

    out = self.base_attention(x_gated, key_padding_mask=key_padding_mask)

    if return_gating:
        return out, H, lambda_dyn
    return out

Optional dependency

CliffordGraphConv requires torch-geometric. Install with uv sync --extra md17.

CliffordGraphConv

Bases: CliffordModule

Geometric Graph Conv. Performs message passing using multivector features.

Aggregates features based on graph topology. H' = Aggregate(H) * W + Bias.

Attributes:

Name Type Description
linear CliffordLinear

The transformation.

Source code in layers/adapters/gnn.py
class CliffordGraphConv(CliffordModule):
    """Geometric Graph Conv. Performs message passing using multivector features.

    Aggregates features based on graph topology.
    H' = Aggregate(H) * W + Bias.

    Attributes:
        linear (CliffordLinear): The transformation.
    """

    def __init__(self, algebra: CliffordAlgebra, in_channels: int, out_channels: int):
        """Sets up the GNN layer.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            in_channels (int): Input features.
            out_channels (int): Output features.
        """
        super().__init__(algebra)
        self.linear = CliffordLinear(algebra, in_channels, out_channels)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Aggregates and transforms node features using geometric operations.

        Args:
            x (torch.Tensor): Node features.
            adj (torch.Tensor): Adjacency matrix.

        Returns:
            torch.Tensor: Updated features.
        """
        # 1. Aggregate
        N, C, D = x.shape
        x_flat = x.view(N, -1)

        # Sparse aggregation
        x_agg_flat = torch.mm(adj, x_flat)
        x_agg = x_agg_flat.view(N, C, D)

        # 2. Transform
        out = self.linear(x_agg)

        return out
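The aggregation step flattens the multivector channels so a single matmul with the adjacency performs message passing for all channels and blade components at once. A toy sketch of that step alone (the row-normalized dense adjacency here is an illustrative assumption, not something the source constructs):

```python
import torch

N, C, D = 3, 2, 4                      # nodes, channels, blade components
x = torch.randn(N, C, D)

# Node 0 is connected to nodes 1 and 2; row-normalize for mean aggregation.
adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])
adj = adj / adj.sum(dim=1, keepdim=True)

# Flatten multivectors, aggregate over neighbors, restore shape.
x_flat = x.view(N, -1)                 # [N, C*D]
x_agg = (adj @ x_flat).view(N, C, D)

# Node 0 now holds the mean of its neighbors' multivectors.
print(torch.allclose(x_agg[0], (x[1] + x[2]) / 2))  # True
```

In the layer itself this aggregate is then passed through ``CliffordLinear``, giving H' = Aggregate(H) * W + Bias.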

__init__(algebra, in_channels, out_channels)

Sets up the GNN layer.

Parameters:

    Name          Type             Description            Default
    algebra       CliffordAlgebra  The algebra instance.  required
    in_channels   int              Input features.        required
    out_channels  int              Output features.       required
Source code in layers/adapters/gnn.py
def __init__(self, algebra: CliffordAlgebra, in_channels: int, out_channels: int):
    """Sets up the GNN layer.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        in_channels (int): Input features.
        out_channels (int): Output features.
    """
    super().__init__(algebra)
    self.linear = CliffordLinear(algebra, in_channels, out_channels)

forward(x, adj)

Aggregates and transforms node features using geometric operations.

Parameters:

    Name  Type    Description        Default
    x     Tensor  Node features.     required
    adj   Tensor  Adjacency matrix.  required

Returns:

    Type    Description
    Tensor  Updated features.

Source code in layers/adapters/gnn.py
def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
    """Aggregates and transforms node features using geometric operations.

    Args:
        x (torch.Tensor): Node features.
        adj (torch.Tensor): Adjacency matrix.

    Returns:
        torch.Tensor: Updated features.
    """
    # 1. Aggregate
    N, C, D = x.shape
    x_flat = x.view(N, -1)

    # Sparse aggregation
    x_agg_flat = torch.mm(adj, x_flat)
    x_agg = x_agg_flat.view(N, C, D)

    # 2. Transform
    out = self.linear(x_agg)

    return out