Skip to content

Layers

Base

CliffordModule

Bases: Module

Base module for Clifford algebra layers.

Manages the algebra configuration.

Source code in layers/primitives/base.py
class CliffordModule(nn.Module):
    """Common base class shared by all Clifford algebra layers.

    Records the algebra signature (p, q, r) next to a reference to the algebra
    object, so the algebra can be rebuilt on demand if the reference is absent.
    """

    def __init__(self, algebra: CliffordAlgebra):
        """Record the algebra signature and keep a reference to the instance.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
        """
        super().__init__()
        # The signature integers are all that is needed to rebuild the algebra.
        self.p = algebra.p
        self.q = algebra.q
        self.r = algebra.r
        self._algebra = algebra  # transient reference

    @property
    def algebra(self) -> CliffordAlgebra:
        """The algebra instance, rebuilt from (p, q, r) when no reference is held."""
        if self._algebra is not None:
            return self._algebra
        # Device preference: first parameter, then first buffer, then CPU.
        device = next((param.device for param in self.parameters()), None)
        if device is None:
            device = next((buf.device for buf in self.buffers()), 'cpu')
        self._algebra = CliffordAlgebra(self.p, self.q, self.r, device=str(device))
        return self._algebra

    def forward(self, x):
        """Compute the layer output; concrete subclasses must override this."""
        raise NotImplementedError

algebra property

Return the algebra instance, reconstructing if necessary.

__init__(algebra)

Sets up the module.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
Source code in layers/primitives/base.py
def __init__(self, algebra: CliffordAlgebra):
    """Initialise the base module from an algebra instance.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
    """
    super().__init__()
    # Copy the (p, q, r) signature so the algebra can be rebuilt later if needed.
    for attr_name in ('p', 'q', 'r'):
        setattr(self, attr_name, getattr(algebra, attr_name))
    self._algebra = algebra  # transient reference

forward(x)

Performs the forward pass computation.

Source code in layers/primitives/base.py
def forward(self, x):
    """Abstract forward hook; concrete layers must provide their own."""
    raise NotImplementedError()

Primitives

RotorLayer

Bases: CliffordModule

Learnable rotor layer for sandwich-product transformation.

Learns R = exp(-B/2) and applies the isometry x' = RxR~. Preserves origin, lengths, and angles.

Attributes:

Name Type Description
channels int

Number of rotors.

bivector_weights Parameter

Learnable B coefficients.

use_decomposition bool

If True, use power iteration decomposition.

decomp_k int

Number of simple components for decomposition.

Source code in layers/primitives/rotor.py
class RotorLayer(CliffordModule):
    """Learnable rotor layer for sandwich-product transformation.

    Learns R = exp(-B/2) and applies the isometry x' = RxR~.
    Preserves origin, lengths, and angles.

    Eval-mode rotor cache: rotors are reused across forward passes only when
    the module is in eval mode AND autograd is disabled (``torch.no_grad`` /
    ``torch.inference_mode``), so the cache never pins an autograd graph.
    Cached rotors are recomputed when the input's device or dtype differs from
    theirs, and the cache is cleared by ``train()``/``eval()``,
    ``reset_parameters()`` and ``prune_bivectors()``.

    Attributes:
        channels (int): Number of rotors.
        bivector_weights (nn.Parameter): Learnable B coefficients.
        use_decomposition (bool): If True, use power iteration decomposition.
        decomp_k (int, optional): Number of simple components for decomposition.
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        use_decomposition: bool = False,
        decomp_k: int = None
    ):
        """Initialize the rotor layer.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Number of features.
            use_decomposition (bool): If True, use bivector decomposition.
                Reference: Pence et al. (2025), arXiv:2507.11688v1
            decomp_k (int, optional): Number of simple components for decomposition.
        """
        super().__init__(algebra)
        self.channels = channels
        self.use_decomposition = use_decomposition
        self.decomp_k = decomp_k

        # Use algebra's precomputed grade masks for bivector indices
        bv_mask = algebra.grade_masks[2]
        self.register_buffer('bivector_indices', bv_mask.nonzero(as_tuple=False).squeeze(-1))
        self.num_bivectors = len(self.bivector_indices)

        # One row of bivector coefficients per channel; filled by reset_parameters().
        self.bivector_weights = nn.Parameter(torch.Tensor(channels, self.num_bivectors))

        # Rotor cache for eval mode
        self._cached_R = None
        self._cached_R_rev = None

        self.reset_parameters()

    def _invalidate_rotor_cache(self):
        """Drop cached rotors so the next forward pass recomputes them."""
        self._cached_R = None
        self._cached_R_rev = None

    def reset_parameters(self):
        """Initialize with near-identity rotations and drop any cached rotors."""
        nn.init.normal_(self.bivector_weights, std=0.01)
        # Cached rotors were built from the previous weights.
        self._invalidate_rotor_cache()

    def _compute_rotors(self, device, dtype):
        """Compute R and R~ from bivector weights.

        Args:
            device (torch.device): Device on which the dense bivector B is built.
            dtype (torch.dtype): Dtype of B; the weights are cast to it.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Rotors R (one per channel) and
            their reverses R~.
        """
        B = torch.zeros(self.channels, self.algebra.dim, device=device, dtype=dtype)
        indices = self.bivector_indices.unsqueeze(0).expand(self.channels, -1)
        # scatter_ requires src.dtype == B.dtype; .to() is a no-op when they match.
        B.scatter_(1, indices, self.bivector_weights.to(dtype=dtype))

        if self.use_decomposition:
            R = self.algebra.exp_decomposed(
                -0.5 * B, use_decomposition=True, k=self.decomp_k
            )
        else:
            R = self.algebra.exp(-0.5 * B)

        R_rev = self.algebra.reverse(R)
        return R, R_rev

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sandwich product x' = RxR~.

        Reuses cached rotors in eval mode under disabled autograd, as long as
        they match the device and dtype of ``x``.

        Args:
            x (torch.Tensor): Input [Batch, Channels, Dim].

        Returns:
            torch.Tensor: Rotated input.
        """
        from core.validation import check_multivector, check_channels
        check_multivector(x, self.algebra, "RotorLayer input")
        check_channels(x, self.channels, "RotorLayer input")

        self.algebra.ensure_device(x.device)

        # Only cache when no autograd graph can be captured (eval + grad off);
        # otherwise a cached R would pin a graph that backward() later frees.
        use_cache = not self.training and not torch.is_grad_enabled()
        cached_R = self._cached_R
        if (
            use_cache
            and cached_R is not None
            and cached_R.device == x.device
            and cached_R.dtype == x.dtype
        ):
            R, R_rev = cached_R, self._cached_R_rev
        else:
            R, R_rev = self._compute_rotors(x.device, x.dtype)
            if use_cache:
                self._cached_R = R
                self._cached_R_rev = R_rev

        Rx = self.algebra.geometric_product(R.unsqueeze(0), x)
        return self.algebra.geometric_product(Rx, R_rev.unsqueeze(0))

    def train(self, mode: bool = True):
        """Set training mode and drop any cached rotors.

        The cache is cleared on every switch, including ``eval()``. Weights
        loaded or edited while the module is already in eval mode are then
        picked up by the next forward pass.
        """
        self._invalidate_rotor_cache()
        return super().train(mode)

    def prune_bivectors(self, threshold: float = 1e-4) -> int:
        """Zero out bivector weights whose magnitude is below the threshold.

        Non-finite NaN weights also fail the ``>= threshold`` test and are set
        to zero. Cached rotors are dropped because they no longer match.

        Args:
            threshold (float): Cutoff magnitude.

        Returns:
            int: Number of pruned parameters (entries now forced to zero).
        """
        with torch.no_grad():
            keep = torch.abs(self.bivector_weights) >= threshold
            drop = ~keep
            num_pruned = int(drop.sum().item())
            # masked_fill_ really zeroes entries (0 * NaN would stay NaN)
            # and does not depend on the parameter dtype.
            self.bivector_weights.masked_fill_(drop, 0.0)
        self._invalidate_rotor_cache()
        return num_pruned

    def sparsity_loss(self) -> torch.Tensor:
        """Compute L1 sparsity regularization on bivector weights."""
        return torch.linalg.vector_norm(self.bivector_weights, ord=1)

__init__(algebra, channels, use_decomposition=False, decomp_k=None)

Initialize the rotor layer.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
channels int

Number of features.

required
use_decomposition bool

If True, use bivector decomposition. Reference: Pence et al. (2025), arXiv:2507.11688v1

False
decomp_k int

Number of simple components for decomposition.

None
Source code in layers/primitives/rotor.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    channels: int,
    use_decomposition: bool = False,
    decomp_k: int = None
):
    """Build a bank of ``channels`` learnable rotors.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Number of features.
        use_decomposition (bool): If True, use bivector decomposition.
            Reference: Pence et al. (2025), arXiv:2507.11688v1
        decomp_k (int, optional): Number of simple components for decomposition.
    """
    super().__init__(algebra)
    self.channels = channels
    self.use_decomposition = use_decomposition
    self.decomp_k = decomp_k

    # Flat indices of the grade-2 blades, read from the algebra's grade mask.
    grade2_idx = algebra.grade_masks[2].nonzero(as_tuple=False).squeeze(-1)
    self.register_buffer('bivector_indices', grade2_idx)
    self.num_bivectors = len(grade2_idx)

    # Coefficients of B, one row per channel (values set by reset_parameters).
    self.bivector_weights = nn.Parameter(torch.Tensor(channels, self.num_bivectors))

    # Eval-mode rotor cache starts empty.
    self._cached_R = self._cached_R_rev = None

    self.reset_parameters()

reset_parameters()

Initialize with near-identity rotations.

Source code in layers/primitives/rotor.py
def reset_parameters(self):
    """Initialize with near-identity rotations and drop any cached rotors.

    Bivector weights are drawn from N(0, 0.01^2), so exp(-B/2) starts close
    to the identity rotor. The eval-mode cache is cleared because it was
    computed from the previous weights.
    """
    nn.init.normal_(self.bivector_weights, std=0.01)
    self._cached_R = None
    self._cached_R_rev = None

forward(x)

Apply the sandwich product x' = RxR~.

Caches rotors during eval mode for faster inference.

Parameters:

Name Type Description Default
x Tensor

Input [Batch, Channels, Dim].

required

Returns:

Type Description
Tensor

torch.Tensor: Rotated input.

Source code in layers/primitives/rotor.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the sandwich product x' = RxR~.

    Reuses cached rotors in eval mode under disabled autograd, as long as
    they match the device and dtype of ``x``.

    Args:
        x (torch.Tensor): Input [Batch, Channels, Dim].

    Returns:
        torch.Tensor: Rotated input.
    """
    from core.validation import check_multivector, check_channels
    check_multivector(x, self.algebra, "RotorLayer input")
    check_channels(x, self.channels, "RotorLayer input")

    self.algebra.ensure_device(x.device)

    # Only cache when no autograd graph can be captured (eval + grad off);
    # otherwise a cached R would pin a graph that backward() later frees.
    use_cache = not self.training and not torch.is_grad_enabled()
    cached_R = self._cached_R
    if (
        use_cache
        and cached_R is not None
        and cached_R.device == x.device
        and cached_R.dtype == x.dtype
    ):
        R, R_rev = cached_R, self._cached_R_rev
    else:
        R, R_rev = self._compute_rotors(x.device, x.dtype)
        if use_cache:
            self._cached_R = R
            self._cached_R_rev = R_rev

    Rx = self.algebra.geometric_product(R.unsqueeze(0), x)
    return self.algebra.geometric_product(Rx, R_rev.unsqueeze(0))

train(mode=True)

Override to invalidate rotor cache when switching to train mode.

Source code in layers/primitives/rotor.py
def train(self, mode: bool = True):
    """Set training mode and drop any cached rotors.

    The cache is cleared on every switch, including ``eval()``. Weights
    loaded or edited while the module is already in eval mode are then
    picked up by the next forward pass.
    """
    self._cached_R = None
    self._cached_R_rev = None
    return super().train(mode)

prune_bivectors(threshold=0.0001)

Zero out bivector weights below the threshold.

Parameters:

Name Type Description Default
threshold float

Cutoff magnitude.

0.0001

Returns:

Name Type Description
int int

Number of pruned parameters.

Source code in layers/primitives/rotor.py
def prune_bivectors(self, threshold: float = 1e-4) -> int:
    """Zero out bivector weights whose magnitude is below the threshold.

    NaN weights also fail the ``>= threshold`` test and are set to zero.
    Cached rotors are dropped because they no longer match the weights.

    Args:
        threshold (float): Cutoff magnitude.

    Returns:
        int: Number of pruned parameters (entries now forced to zero).
    """
    with torch.no_grad():
        keep = torch.abs(self.bivector_weights) >= threshold
        drop = ~keep
        num_pruned = int(drop.sum().item())
        # masked_fill_ really zeroes entries (0 * NaN would stay NaN)
        # and does not depend on the parameter dtype.
        self.bivector_weights.masked_fill_(drop, 0.0)
    self._cached_R = None
    self._cached_R_rev = None
    return num_pruned

sparsity_loss()

Compute L1 sparsity regularization on bivector weights.

Source code in layers/primitives/rotor.py
def sparsity_loss(self) -> torch.Tensor:
    """Return the L1 norm of all bivector weights as a scalar regularizer."""
    flat_l1 = torch.linalg.vector_norm(self.bivector_weights, ord=1)
    return flat_l1

MultiRotorLayer

Bases: CliffordModule

Multi-rotor layer with weighted superposition: x' = sum_k w_k R_k x R~_k.

Replaces rigid single-rotor rotations with a flexible superposition.

Attributes:

Name Type Description
channels int

Input features.

num_rotors int

Number of overlapping rotors.

use_decomposition bool

If True, use power iteration decomposition.

decomp_k int | None

Number of simple components for decomposition.

rotor_bivectors Parameter

Bivector coefficients [num_rotors, num_bv]

weights Parameter

Mixing weights [channels, num_rotors]

Source code in layers/primitives/multi_rotor.py
class MultiRotorLayer(CliffordModule):
    """Multi-rotor layer with weighted superposition: x' = sum_k w_k R_k x R~_k.

    Replaces rigid single-rotor rotations with a flexible superposition.

    Eval-mode rotor cache: rotors are reused across forward passes only when
    the module is in eval mode AND autograd is disabled (``torch.no_grad`` /
    ``torch.inference_mode``), so the cache never pins an autograd graph.
    Cached rotors are recomputed when the input's device or dtype differs from
    theirs, and the cache is cleared by ``train()``/``eval()`` and
    ``reset_parameters()``.

    Attributes:
        channels (int): Input features.
        num_rotors (int): Number of overlapping rotors.
        use_decomposition (bool): If True, use power iteration decomposition.
        decomp_k (int | None): Number of simple components for decomposition.
        rotor_bivectors (torch.nn.Parameter): Bivector coefficients [num_rotors, num_bv]
        weights (torch.nn.Parameter): Mixing weights [channels, num_rotors]
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        num_rotors: int = 8,
        use_decomposition: bool = False,
        decomp_k: int = None
    ):
        """Initialize Multi-Rotor Layer.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Input features.
            num_rotors (int): Parallel heads.
            use_decomposition (bool): If True, use bivector decomposition.
                Reference: Pence et al. (2025), arXiv:2507.11688v1
            decomp_k (int, optional): Number of simple components for decomposition.
        """
        super().__init__(algebra)
        self.channels = channels
        self.num_rotors = num_rotors
        self.use_decomposition = use_decomposition
        self.decomp_k = decomp_k

        # Use algebra's precomputed grade masks for bivector indices
        bv_mask = algebra.grade_masks[2]
        self.register_buffer('bivector_indices', bv_mask.nonzero(as_tuple=False).squeeze(-1))
        self.num_bivectors = len(self.bivector_indices)

        # Overlapping rotors
        self.rotor_bivectors = nn.Parameter(torch.Tensor(num_rotors, self.num_bivectors))

        # Mixing weights
        self.weights = nn.Parameter(torch.Tensor(channels, num_rotors))

        # Rotor cache for eval mode
        self._cached_R = None
        self._cached_R_rev = None

        self.reset_parameters()

    def _invalidate_rotor_cache(self):
        """Drop cached rotors so the next forward pass recomputes them."""
        self._cached_R = None
        self._cached_R_rev = None

    def reset_parameters(self):
        """Initialize with small rotations and Xavier-uniform mixing weights.

        Rotor bivectors ~ N(0, 0.01^2) (near-identity rotors). Cached rotors
        are dropped because they were built from the previous bivectors.
        """
        nn.init.normal_(self.rotor_bivectors, std=0.01)
        nn.init.xavier_uniform_(self.weights)
        self._invalidate_rotor_cache()

    def _compute_rotors(self, device, dtype):
        """Compute R and R~ from bivector weights.

        Args:
            device (torch.device): Target device
            dtype (torch.dtype): Target data type; the bivectors are cast to it.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Rotor and reversed rotor multivectors
        """
        B = torch.zeros(self.num_rotors, self.algebra.dim, device=device, dtype=dtype)
        indices = self.bivector_indices.unsqueeze(0).expand(self.num_rotors, -1)
        # scatter_ requires src.dtype == B.dtype; .to() is a no-op when they match.
        B.scatter_(1, indices, self.rotor_bivectors.to(dtype=dtype))

        if self.use_decomposition:
            R = self.algebra.exp_decomposed(
                -0.5 * B, use_decomposition=True, k=self.decomp_k
            )
        else:
            R = self.algebra.exp(-0.5 * B)  # [K, D]
        R_rev = self.algebra.reverse(R)
        return R, R_rev

    def forward(self, x: torch.Tensor, return_invariants: bool = False) -> torch.Tensor:
        """Apply weighted multi-rotor transformation.

        Reuses cached rotors in eval mode under disabled autograd, as long as
        they match the device and dtype of ``x``.

        Args:
            x (torch.Tensor): Input [Batch, Channels, Dim].
            return_invariants (bool): If True, returns grade norms.

        Returns:
            torch.Tensor: Transformed output [Batch, Channels, Dim]. If
            ``return_invariants`` is True, returns
            ``self.algebra.get_grade_norms`` of that output instead.
        """
        from core.validation import check_multivector, check_channels
        check_multivector(x, self.algebra, "MultiRotorLayer input")
        check_channels(x, self.channels, "MultiRotorLayer input")

        self.algebra.ensure_device(x.device)

        # Only cache when no autograd graph can be captured (eval + grad off);
        # otherwise a cached R would pin a graph that backward() later frees.
        use_cache = not self.training and not torch.is_grad_enabled()
        cached_R = self._cached_R
        if (
            use_cache
            and cached_R is not None
            and cached_R.device == x.device
            and cached_R.dtype == x.dtype
        ):
            R, R_rev = cached_R, self._cached_R_rev
        else:
            R, R_rev = self._compute_rotors(x.device, x.dtype)
            if use_cache:
                self._cached_R = R
                self._cached_R_rev = R_rev

        # Sandwich product: x [B, C, 1, D] against all K rotors [1, 1, K, D].
        x_expanded = x.unsqueeze(2)
        R_expanded = R.view(1, 1, self.num_rotors, -1)
        R_rev_expanded = R_rev.view(1, 1, self.num_rotors, -1)

        Rx = self.algebra.geometric_product(R_expanded, x_expanded)
        rotated_x = self.algebra.geometric_product(Rx, R_rev_expanded)

        # Superposition over rotors; weights cast so einsum dtypes agree.
        out = torch.einsum('ck,bckd->bcd', self.weights.to(rotated_x.dtype), rotated_x)

        if return_invariants:
            return self.algebra.get_grade_norms(out)

        return out

    def train(self, mode: bool = True):
        """Set training mode and drop any cached rotors.

        The cache is cleared on every switch, including ``eval()``. Weights
        loaded or edited while the module is already in eval mode are then
        picked up by the next forward pass.

        Args:
            mode (bool): Whether to set to training mode.
        """
        self._invalidate_rotor_cache()
        return super().train(mode)

    def sparsity_loss(self) -> torch.Tensor:
        """Computes the L1 sparsity loss for rotor bivectors and weights.

        Returns:
            torch.Tensor: Scalar sparsity loss.
        """
        return torch.norm(self.rotor_bivectors, p=1) + torch.norm(self.weights, p=1)

__init__(algebra, channels, num_rotors=8, use_decomposition=False, decomp_k=None)

Initialize Multi-Rotor Layer.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
channels int

Input features.

required
num_rotors int

Parallel heads.

8
use_decomposition bool

If True, use bivector decomposition. Reference: Pence et al. (2025), arXiv:2507.11688v1

False
decomp_k int

Number of simple components for decomposition.

None
Source code in layers/primitives/multi_rotor.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    channels: int,
    num_rotors: int = 8,
    use_decomposition: bool = False,
    decomp_k: int = None
):
    """Create ``num_rotors`` shared rotors plus per-channel mixing weights.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Input features.
        num_rotors (int): Parallel heads.
        use_decomposition (bool): If True, use bivector decomposition.
            Reference: Pence et al. (2025), arXiv:2507.11688v1
        decomp_k (int, optional): Number of simple components for decomposition.
    """
    super().__init__(algebra)
    self.channels = channels
    self.num_rotors = num_rotors
    self.use_decomposition = use_decomposition
    self.decomp_k = decomp_k

    # Positions of the grade-2 (bivector) blades in a full multivector.
    grade2_idx = algebra.grade_masks[2].nonzero(as_tuple=False).squeeze(-1)
    self.register_buffer('bivector_indices', grade2_idx)
    self.num_bivectors = len(grade2_idx)

    # Bivector generator of each rotor: [num_rotors, num_bivectors].
    self.rotor_bivectors = nn.Parameter(torch.Tensor(num_rotors, self.num_bivectors))
    # How strongly each channel uses each rotor: [channels, num_rotors].
    self.weights = nn.Parameter(torch.Tensor(channels, num_rotors))

    # Eval-mode rotor cache starts empty.
    self._cached_R = self._cached_R_rev = None

    self.reset_parameters()

reset_parameters()

Initialize with small (near-identity) rotations and Xavier-uniform mixing weights.

Source code in layers/primitives/multi_rotor.py
def reset_parameters(self):
    """Initialize with small rotations and Xavier-uniform mixing weights.

    Rotor bivectors ~ N(0, 0.01^2), so every rotor starts near the identity.
    Any eval-mode rotor cache is dropped because it was computed from the
    previous bivectors.
    """
    nn.init.normal_(self.rotor_bivectors, std=0.01)
    nn.init.xavier_uniform_(self.weights)
    self._cached_R = None
    self._cached_R_rev = None

forward(x, return_invariants=False)

Apply weighted multi-rotor transformation.

Caches rotors during eval mode for faster inference.

Parameters:

Name Type Description Default
x Tensor

Input [Batch, Channels, Dim].

required
return_invariants bool

If True, returns grade norms.

False

Returns:

Type Description
Tensor

torch.Tensor: Transformed output [Batch, Channels, Dim].

Source code in layers/primitives/multi_rotor.py
def forward(self, x: torch.Tensor, return_invariants: bool = False) -> torch.Tensor:
    """Apply weighted multi-rotor transformation.

    Reuses cached rotors in eval mode under disabled autograd, as long as
    they match the device and dtype of ``x``.

    Args:
        x (torch.Tensor): Input [Batch, Channels, Dim].
        return_invariants (bool): If True, returns grade norms.

    Returns:
        torch.Tensor: Transformed output [Batch, Channels, Dim]. If
        ``return_invariants`` is True, returns
        ``self.algebra.get_grade_norms`` of that output instead.
    """
    from core.validation import check_multivector, check_channels
    check_multivector(x, self.algebra, "MultiRotorLayer input")
    check_channels(x, self.channels, "MultiRotorLayer input")

    self.algebra.ensure_device(x.device)

    # Only cache when no autograd graph can be captured (eval + grad off);
    # otherwise a cached R would pin a graph that backward() later frees.
    use_cache = not self.training and not torch.is_grad_enabled()
    cached_R = self._cached_R
    if (
        use_cache
        and cached_R is not None
        and cached_R.device == x.device
        and cached_R.dtype == x.dtype
    ):
        R, R_rev = cached_R, self._cached_R_rev
    else:
        R, R_rev = self._compute_rotors(x.device, x.dtype)
        if use_cache:
            self._cached_R = R
            self._cached_R_rev = R_rev

    # Sandwich product: x [B, C, 1, D] against all K rotors [1, 1, K, D].
    x_expanded = x.unsqueeze(2)
    R_expanded = R.view(1, 1, self.num_rotors, -1)
    R_rev_expanded = R_rev.view(1, 1, self.num_rotors, -1)

    Rx = self.algebra.geometric_product(R_expanded, x_expanded)
    rotated_x = self.algebra.geometric_product(Rx, R_rev_expanded)

    # Superposition over rotors; weights cast so einsum dtypes agree.
    out = torch.einsum('ck,bckd->bcd', self.weights.to(rotated_x.dtype), rotated_x)

    if return_invariants:
        return self.algebra.get_grade_norms(out)

    return out

train(mode=True)

Override to invalidate rotor cache when switching to train mode.

Parameters:

Name Type Description Default
mode bool

Whether to set to training mode.

True
Source code in layers/primitives/multi_rotor.py
def train(self, mode: bool = True):
    """Set training mode and drop any cached rotors.

    The cache is cleared on every switch, including ``eval()``. Weights
    loaded or edited while the module is already in eval mode are then
    picked up by the next forward pass.

    Args:
        mode (bool): Whether to set to training mode.
    """
    self._cached_R = None
    self._cached_R_rev = None
    return super().train(mode)

sparsity_loss()

Computes the L1 sparsity loss for rotor bivectors and weights.

Returns:

Type Description
Tensor

torch.Tensor: Scalar sparsity loss.

Source code in layers/primitives/multi_rotor.py
def sparsity_loss(self) -> torch.Tensor:
    """L1 penalty summed over both the rotor bivectors and the mixing weights.

    Returns:
        torch.Tensor: Scalar sparsity loss.
    """
    bivector_l1 = torch.linalg.vector_norm(self.rotor_bivectors, ord=1)
    mixing_l1 = torch.linalg.vector_norm(self.weights, ord=1)
    return bivector_l1 + mixing_l1

CliffordLinear

Bases: CliffordModule

Fully connected layer with optional rotor-based backend.

Can use either:

- Traditional scalar weight matrix (default, backward compatible)
- Rotor-based transformation (new, parameter efficient via RotorGadget)

The traditional backend uses O(in_channels x out_channels) parameters, while the rotor backend uses O(num_rotor_pairs x n(n-1)/2) parameters where n is the number of basis vectors.

Attributes:

Name Type Description
in_channels int

Input features.

out_channels int

Output features.

backend str

'traditional' or 'rotor'

weight Parameter | None

Weights [Out, In] (traditional backend only).

bias Parameter | None

Bias multivector [Out, Dim] (traditional backend only).

gadget Module | None

Rotor transformation (rotor backend only).

Source code in layers/primitives/linear.py
class CliffordLinear(CliffordModule):
    """Channel-mixing linear layer for multivector features.

    Two interchangeable backends are available:

    - ``'traditional'`` (default, backward compatible): a dense scalar weight
      matrix mixes the input channels, followed by a multivector bias.
    - ``'rotor'``: a parameter-efficient rotor-based transformation
      implemented by ``RotorGadget``.

    The traditional backend uses O(in_channels x out_channels) parameters,
    while the rotor backend uses O(num_rotor_pairs x n(n-1)/2) parameters
    where n is the number of basis vectors.

    Attributes:
        in_channels (int): Input features.
        out_channels (int): Output features.
        backend (str): 'traditional' or 'rotor'
        weight (torch.nn.Parameter | None): Weights [Out, In] (traditional backend only).
        bias (torch.nn.Parameter | None): Bias multivector [Out, Dim] (traditional backend only).
        gadget (nn.Module | None): Rotor transformation (rotor backend only).
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        in_channels: int,
        out_channels: int,
        backend: Literal['traditional', 'rotor'] = 'traditional',
        num_rotor_pairs: int = 4,
        use_decomposition: bool = False,
        decomp_k: int = 10,
        aggregation: Literal['mean', 'sum', 'learned'] = 'mean',
        shuffle: Literal['none', 'fixed', 'random'] = 'none',
    ):
        """Build the layer with the requested backend.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            in_channels (int): Input size.
            out_channels (int): Output size.
            backend (str): 'traditional' for standard linear layer,
                'rotor' for rotor-based transformation
            num_rotor_pairs (int): Number of rotor pairs (rotor backend only)
            use_decomposition (bool): Use bivector decomposition (rotor backend only)
            decomp_k (int): Decomposition iterations (rotor backend only)
            aggregation (str): Aggregation method (rotor backend only)
            shuffle (str): Input channel shuffle strategy (rotor backend only):
                - 'none': No shuffle (default)
                - 'fixed': Fixed random permutation
                - 'random': Random permutation each forward pass

        Raises:
            ValueError: If ``backend`` is neither 'traditional' nor 'rotor'.
        """
        super().__init__(algebra)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.backend = backend

        if backend == 'rotor':
            from .rotor_gadget import RotorGadget
            self.gadget = RotorGadget(
                algebra=algebra,
                in_channels=in_channels,
                out_channels=out_channels,
                num_rotor_pairs=num_rotor_pairs,
                use_decomposition=use_decomposition,
                decomp_k=decomp_k,
                aggregation=aggregation,
                shuffle=shuffle,
                bias=True,  # the gadget owns the bias in rotor mode
            )
            # No dense parameters in rotor mode.
            self.weight = None
            self.bias = None
        elif backend == 'traditional':
            self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))
            self.bias = nn.Parameter(torch.Tensor(out_channels, algebra.dim))
            self.reset_parameters()
            self.gadget = None
        else:
            raise ValueError(
                f"Unknown backend: {backend}. Must be 'traditional' or 'rotor'."
            )

    def reset_parameters(self):
        """Xavier-uniform weights and zero bias; does nothing for the rotor backend."""
        if self.backend != 'traditional':
            return
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix input channels into output channels.

        Args:
            x (torch.Tensor): Input [Batch, In, Dim].

        Returns:
            torch.Tensor: Output [Batch, Out, Dim].
        """
        from core.validation import check_multivector, check_channels
        check_multivector(x, self.algebra, "CliffordLinear input")
        check_channels(x, self.in_channels, "CliffordLinear input")

        if self.backend != 'traditional':
            # Rotor mode: the gadget performs the whole transformation.
            return self.gadget(x)

        # [Out, In] x [Batch, In, Dim] -> [Batch, Out, Dim], then broadcast the bias.
        mixed = torch.einsum('oi,bid->bod', self.weight, x)
        return mixed + self.bias.unsqueeze(0)

    def extra_repr(self) -> str:
        """Summarise channel sizes and backend for ``repr()``.

        Returns:
            str: Layer parameters description
        """
        backend_label = 'traditional' if self.backend == 'traditional' else 'rotor'
        return f"in_channels={self.in_channels}, out_channels={self.out_channels}, backend={backend_label}"

__init__(algebra, in_channels, out_channels, backend='traditional', num_rotor_pairs=4, use_decomposition=False, decomp_k=10, aggregation='mean', shuffle='none')

Initialize Clifford Linear.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
in_channels int

Input size.

required
out_channels int

Output size.

required
backend str

'traditional' for standard linear layer, 'rotor' for rotor-based transformation

'traditional'
num_rotor_pairs int

Number of rotor pairs (rotor backend only)

4
use_decomposition bool

Use bivector decomposition (rotor backend only)

False
decomp_k int

Decomposition iterations (rotor backend only)

10
aggregation str

Aggregation method (rotor backend only)

'mean'
shuffle str

Input channel shuffle strategy (rotor backend only):

- 'none': No shuffle (default)
- 'fixed': Fixed random permutation
- 'random': Random permutation each forward pass

'none'
Source code in layers/primitives/linear.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    in_channels: int,
    out_channels: int,
    backend: Literal['traditional', 'rotor'] = 'traditional',
    num_rotor_pairs: int = 4,
    use_decomposition: bool = False,
    decomp_k: int = 10,
    aggregation: Literal['mean', 'sum', 'learned'] = 'mean',
    shuffle: Literal['none', 'fixed', 'random'] = 'none',
):
    """Build the layer with the requested backend.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        in_channels (int): Input size.
        out_channels (int): Output size.
        backend (str): 'traditional' for standard linear layer,
            'rotor' for rotor-based transformation
        num_rotor_pairs (int): Number of rotor pairs (rotor backend only)
        use_decomposition (bool): Use bivector decomposition (rotor backend only)
        decomp_k (int): Decomposition iterations (rotor backend only)
        aggregation (str): Aggregation method (rotor backend only)
        shuffle (str): Input channel shuffle strategy (rotor backend only):
            - 'none': No shuffle (default)
            - 'fixed': Fixed random permutation
            - 'random': Random permutation each forward pass

    Raises:
        ValueError: If ``backend`` is neither 'traditional' nor 'rotor'.
    """
    super().__init__(algebra)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.backend = backend

    if backend == 'rotor':
        from .rotor_gadget import RotorGadget
        self.gadget = RotorGadget(
            algebra=algebra,
            in_channels=in_channels,
            out_channels=out_channels,
            num_rotor_pairs=num_rotor_pairs,
            use_decomposition=use_decomposition,
            decomp_k=decomp_k,
            aggregation=aggregation,
            shuffle=shuffle,
            bias=True,  # the gadget owns the bias in rotor mode
        )
        # No dense parameters in rotor mode.
        self.weight = None
        self.bias = None
    elif backend == 'traditional':
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))
        self.bias = nn.Parameter(torch.Tensor(out_channels, algebra.dim))
        self.reset_parameters()
        self.gadget = None
    else:
        raise ValueError(
            f"Unknown backend: {backend}. Must be 'traditional' or 'rotor'."
        )

reset_parameters()

Initialize weights with Xavier uniform and zero bias.

Source code in layers/primitives/linear.py
def reset_parameters(self):
    """Xavier-uniform weights and zero bias; does nothing for the rotor backend."""
    if self.backend != 'traditional':
        return
    nn.init.xavier_uniform_(self.weight)
    nn.init.zeros_(self.bias)

forward(x)

Apply channel-mixing linear transformation.

Parameters:

Name Type Description Default
x Tensor

Input [Batch, In, Dim].

required

Returns:

Type Description
Tensor

torch.Tensor: Output [Batch, Out, Dim].

Source code in layers/primitives/linear.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Mix input channels into output channels.

    Args:
        x (torch.Tensor): Multivector batch [Batch, In, Dim].

    Returns:
        torch.Tensor: Multivector batch [Batch, Out, Dim].
    """
    from core.validation import check_multivector, check_channels
    check_multivector(x, self.algebra, "CliffordLinear input")
    check_channels(x, self.in_channels, "CliffordLinear input")

    if self.backend != 'traditional':
        # Rotor backend: the RotorGadget does the whole transformation.
        return self.gadget(x)

    # Dense backend: each output channel is a weighted sum of the input
    # channels, applied identically to every blade coefficient.
    # weight [Out, In] with x [B, In, D] -> [B, Out, D]
    mixed = torch.einsum('oi,bid->bod', self.weight, x)
    return mixed + self.bias.unsqueeze(0)

extra_repr()

String representation for debugging.

Returns:

Name Type Description
str str

Layer parameters description

Source code in layers/primitives/linear.py
def extra_repr(self) -> str:
    """Summarize the layer configuration for ``repr``.

    Returns:
        str: Channel counts plus the active backend name.
    """
    backend_name = 'traditional' if self.backend == 'traditional' else 'rotor'
    return (
        f"in_channels={self.in_channels}, "
        f"out_channels={self.out_channels}, "
        f"backend={backend_name}"
    )

RotorGadget

Bases: CliffordModule

Rotor-based linear transformation (Generalized Rotor Gadget).

Replaces standard linear layers with parameter-efficient rotor-sandwich transformations. Instead of using O(in_channels x out_channels) parameters, this uses O(num_rotor_pairs x n(n-1)/2) parameters where n is the number of basis vectors in the Clifford algebra.

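Architecture
  1. Partition input channels into blocks.
  2. For each rotor pair (i, j), apply the rotor sandwich $r_{ij}\, x_i\, \tilde{s}_{ij}$ to its channel block $x_i$, where $\tilde{s}_{ij}$ is the reverse of $s_{ij}$.
  3. Pool/aggregate the results to the output channels.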

The transformation is $\psi(x) = r\, x\, \tilde{s}$, where $r$ and $s$ are rotors (exponentials of bivectors) and $\tilde{s}$ denotes the reverse of $s$.

Attributes:

Name Type Description
algebra CliffordAlgebra

CliffordAlgebra instance

in_channels

Number of input channels

out_channels

Number of output channels

num_rotor_pairs

Number of rotor pairs to use

use_decomposition

Whether to use bivector decomposition

aggregation

Aggregation method ('mean', 'sum', or 'learned')

Source code in layers/primitives/rotor_gadget.py
class RotorGadget(CliffordModule):
    """Rotor-based linear transformation (Generalized Rotor Gadget).

    Replaces standard linear layers with parameter-efficient rotor-sandwich
    transformations. Instead of using O(in_channels x out_channels) parameters,
    this uses O(num_rotor_pairs x n(n-1)/2) parameters where n is the number
    of basis vectors in the Clifford algebra.

    Architecture:
        1. Partition input channels into ``num_rotor_pairs`` contiguous blocks
           (the last block also takes any remainder channels, so every input
           channel is transformed).
        2. For each rotor pair i, apply the sandwich r_i . x_i . reverse(s_i)
           to its channel block x_i.
        3. Pool/aggregate results to output channels.

    The transformation is: psi(x) = r.x.reverse(s) where r, s are rotors
    (bivector exponentials).

    Attributes:
        algebra: CliffordAlgebra instance
        in_channels: Number of input channels
        out_channels: Number of output channels
        num_rotor_pairs: Number of rotor pairs to use
        use_decomposition: Whether to use bivector decomposition
        aggregation: Aggregation method ('mean', 'sum', or 'learned')
        in_indices: Per-rotor-pair (start, end) input channel spans
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        in_channels: int,
        out_channels: int,
        num_rotor_pairs: int = 4,
        use_decomposition: bool = False,
        decomp_k: int = 10,
        aggregation: Literal['mean', 'sum', 'learned'] = 'mean',
        shuffle: Literal['none', 'fixed', 'random'] = 'none',
        bias: bool = False,
    ):
        """Initialize rotor gadget layer.

        Args:
            algebra: CliffordAlgebra instance
            in_channels: Number of input channels
            out_channels: Number of output channels
            num_rotor_pairs: Number of rotor pairs (higher = more expressive)
            use_decomposition: Use bivector decomposition for efficiency
            decomp_k: Number of iterations for decomposition (if enabled)
            aggregation: How to pool rotor outputs ('mean', 'sum', 'learned')
            shuffle: Input channel shuffle strategy:
                - 'none': No shuffle, sequential block assignment (default)
                - 'fixed': Random permutation at initialization (fixed during training)
                - 'random': Random permutation each forward pass (regularization)
            bias: Whether to include bias term (applied after transformation)

        Raises:
            ValueError: If ``num_rotor_pairs`` is less than 1, if
                ``aggregation`` or ``shuffle`` is not a supported value, or if
                the algebra has no bivectors.
        """
        # Validate up front: a non-positive pair count would otherwise surface
        # as a ZeroDivisionError during routing setup, and unknown strings
        # would silently fall back to 'mean' aggregation / no shuffle.
        if num_rotor_pairs < 1:
            raise ValueError(
                f"num_rotor_pairs must be at least 1, got {num_rotor_pairs}."
            )
        if aggregation not in ('mean', 'sum', 'learned'):
            raise ValueError(
                f"Unknown aggregation: {aggregation!r}. "
                "Must be 'mean', 'sum' or 'learned'."
            )
        if shuffle not in ('none', 'fixed', 'random'):
            raise ValueError(
                f"Unknown shuffle: {shuffle!r}. "
                "Must be 'none', 'fixed' or 'random'."
            )

        super().__init__(algebra)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_rotor_pairs = num_rotor_pairs
        self.use_decomposition = use_decomposition
        self.decomp_k = decomp_k
        self.aggregation = aggregation
        self.shuffle = shuffle

        # Use algebra's precomputed grade masks for bivector indices
        if algebra.num_grades > 2:
            bv_mask = algebra.grade_masks[2]
            self.register_buffer('bivector_indices', bv_mask.nonzero(as_tuple=False).squeeze(-1))
        else:
            self.register_buffer('bivector_indices', torch.tensor([], dtype=torch.long, device=algebra.device))
        self.num_bivectors = len(self.bivector_indices)

        if self.num_bivectors == 0:
            raise ValueError(
                "Algebra has no bivectors. RotorGadget requires "
                "at least one bivector for rotation."
            )

        # Rotor parameters: bivector coefficients for exponential map
        # Left rotors: [num_rotor_pairs, num_bivectors]
        self.bivector_left = nn.Parameter(
            torch.randn(num_rotor_pairs, self.num_bivectors) * 0.1
        )
        # Right rotors: [num_rotor_pairs, num_bivectors]
        self.bivector_right = nn.Parameter(
            torch.randn(num_rotor_pairs, self.num_bivectors) * 0.1
        )

        # Channel routing: block diagonal partitioning (paper style)
        # Each rotor pair processes a subset of input channels
        self._setup_channel_routing()

        # Aggregation weights (if learned)
        if aggregation == 'learned':
            # Learned weights for combining rotor outputs
            self.agg_weights = nn.Parameter(
                torch.ones(num_rotor_pairs, out_channels) / num_rotor_pairs
            )
        else:
            self.register_buffer('agg_weights', None)

        # Optional bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels, algebra.dim))
        else:
            self.register_buffer('bias', None)

        # Rotor cache for eval mode
        self._cached_rotors = None

    @staticmethod
    def _partition_channels(total: int, num_blocks: int) -> list:
        """Split ``range(total)`` into ``num_blocks`` contiguous (start, end) spans.

        Blocks have width ``max(1, total // num_blocks)``. The final block is
        extended to ``total`` so that remainder channels (when ``total`` is not
        divisible by ``num_blocks``) are still covered. Spans that would start
        past the end are clamped to the empty span ``(total, total)``, never a
        negative-width one.
        """
        width = max(1, total // num_blocks)
        spans = []
        for i in range(num_blocks):
            start = min(i * width, total)
            end = total if i == num_blocks - 1 else min((i + 1) * width, total)
            spans.append((start, end))
        return spans

    def _setup_channel_routing(self):
        """Set up block diagonal channel routing with optional shuffle.

        Partitions input and output channels into blocks, where each rotor
        pair operates on a specific block. Optionally shuffles input channels
        before routing for regularization.
        """
        self.in_indices = self._partition_channels(self.in_channels, self.num_rotor_pairs)
        self.out_indices = self._partition_channels(self.out_channels, self.num_rotor_pairs)

        # Set up channel shuffle permutation
        if self.shuffle == 'fixed':
            # Create fixed random permutation at initialization
            perm = torch.randperm(self.in_channels)
            self.register_buffer('channel_permutation', perm)
        else:
            # 'random' draws a fresh permutation each forward pass and 'none'
            # needs no permutation, so no fixed buffer is stored for either.
            self.register_buffer('channel_permutation', None)

    def _bivector_to_multivector(self, bivector_coeffs: torch.Tensor) -> torch.Tensor:
        """Convert bivector coefficients to full multivector via vectorized scatter.

        Args:
            bivector_coeffs: Tensor of shape [..., num_bivectors]

        Returns:
            Multivector tensor of shape [..., algebra.dim]
        """
        batch_shape = bivector_coeffs.shape[:-1]
        mv = torch.zeros(*batch_shape, self.algebra.dim,
                         device=bivector_coeffs.device, dtype=bivector_coeffs.dtype)
        # Expand indices to match batch shape for scatter_
        idx = self.bivector_indices.expand(*batch_shape, -1)
        mv.scatter_(-1, idx, bivector_coeffs)
        return mv

    def _compute_rotors(self):
        """Compute rotor multivectors from bivector parameters.

        Returns:
            Tuple of (left_rotors, right_rotors_reversed) where each is
            a tensor of shape [num_rotor_pairs, algebra.dim]
        """
        # Convert bivector parameters to multivectors
        B_left = self._bivector_to_multivector(self.bivector_left)  # [pairs, dim]
        B_right = self._bivector_to_multivector(self.bivector_right)  # [pairs, dim]

        # Compute rotors via exponential map: R = exp(-0.5 * B)
        if self.use_decomposition:
            # Use decomposed exponential (more efficient)
            R_left = self.algebra.exp_decomposed(
                -0.5 * B_left,
                use_decomposition=True,
                k=self.decomp_k
            )  # [pairs, dim]
            R_right = self.algebra.exp_decomposed(
                -0.5 * B_right,
                use_decomposition=True,
                k=self.decomp_k
            )  # [pairs, dim]
        else:
            # Standard exponential
            R_left = self.algebra.exp(-0.5 * B_left)  # [pairs, dim]
            R_right = self.algebra.exp(-0.5 * B_right)  # [pairs, dim]

        # Compute reverse of right rotors for sandwich product
        R_right_rev = self.algebra.reverse(R_right)  # [pairs, dim]

        return R_left, R_right_rev

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply rotor-based transformation.

        Uses batched geometric products - all rotor pairs are applied in
        parallel via a single pair of GP calls.

        Args:
            x: Input tensor of shape [Batch, In_Channels, Dim]

        Returns:
            Output tensor of shape [Batch, Out_Channels, Dim]
        """
        from core.validation import check_multivector, check_channels
        check_multivector(x, self.algebra, "RotorGadget input")
        check_channels(x, self.in_channels, "RotorGadget input")

        self.algebra.ensure_device(x.device)

        # Apply input channel shuffle if enabled
        if self.shuffle == 'fixed':
            x = x[:, self.channel_permutation, :]
        elif self.shuffle == 'random':
            perm = torch.randperm(self.in_channels, device=x.device)
            x = x[:, perm, :]

        # Compute rotors (cached in eval mode)
        if not self.training and self._cached_rotors is not None:
            R_left, R_right_rev = self._cached_rotors
        else:
            R_left, R_right_rev = self._compute_rotors()
            if not self.training:
                self._cached_rotors = (R_left, R_right_rev)

        # Vectorized sandwich: build per-channel rotor tensors [1, in_channels, dim]
        # where each channel gets the rotor of the pair whose block contains it.
        D = self.algebra.dim
        R_left_expanded = torch.zeros(1, self.in_channels, D,
                                      device=x.device, dtype=x.dtype)
        R_right_expanded = torch.zeros(1, self.in_channels, D,
                                       device=x.device, dtype=x.dtype)

        for i, (in_start, in_end) in enumerate(self.in_indices):
            if in_end > in_start:
                R_left_expanded[0, in_start:in_end] = R_left[i]
                R_right_expanded[0, in_start:in_end] = R_right_rev[i]

        # Two batched GPs instead of 2*K sequential GPs
        temp = self.algebra.geometric_product(R_left_expanded, x)
        concat_out = self.algebra.geometric_product(temp, R_right_expanded)

        # Map to output channels
        out = self._aggregate_to_output_channels(concat_out)

        if self.bias is not None:
            out = out + self.bias.unsqueeze(0)

        return out

    def _aggregate_learned(self, x: torch.Tensor) -> torch.Tensor:
        """Weighted per-block aggregation using ``agg_weights`` [pairs, out_channels].

        Each non-empty input block is averaged over its channels and broadcast
        to every output channel with that pair's learned weights; the
        contributions of all pairs are summed.
        """
        outputs = []
        for i, (in_start, in_end) in enumerate(self.in_indices):
            # Empty blocks (more pairs than channels) contribute nothing;
            # averaging an empty slice would produce NaN.
            if in_end <= in_start:
                continue
            x_i_mean = x[:, in_start:in_end, :].mean(dim=1, keepdim=True)  # [B, 1, dim]
            outputs.append(x_i_mean * self.agg_weights[i:i + 1, :, None])  # [B, out_ch, dim]
        return torch.stack(outputs, dim=0).sum(dim=0)  # [B, out_ch, dim]

    def _aggregate_to_output_channels(self, x: torch.Tensor) -> torch.Tensor:
        """Aggregate rotor pair outputs to match output channel count.

        For 'sum'/'mean' with more input than output channels, consecutive
        groups of ``in // out`` channels are pooled; any trailing channels
        beyond ``fold * out_channels`` are not included. With fewer input than
        output channels, the input channels are tiled.

        Args:
            x: Concatenated outputs from rotor pairs [B, total_channels, dim]

        Returns:
            Aggregated output [B, out_channels, dim]
        """
        if self.aggregation == 'learned':
            return self._aggregate_learned(x)

        batch_size, n_in = x.shape[0], x.shape[1]
        if n_in == self.out_channels:
            return x
        if n_in > self.out_channels:
            # Pool down by summing / averaging groups of `fold` channels
            fold = n_in // self.out_channels
            grouped = x[:, :fold * self.out_channels, :].reshape(
                batch_size, self.out_channels, fold, self.algebra.dim
            )
            return grouped.sum(dim=2) if self.aggregation == 'sum' else grouped.mean(dim=2)
        # Expand by tiling
        repeats = (self.out_channels + n_in - 1) // n_in
        return x.repeat(1, repeats, 1)[:, :self.out_channels, :]

    def train(self, mode: bool = True):
        """Override to invalidate rotor cache when switching to train mode."""
        if mode:
            self._cached_rotors = None
        return super().train(mode)

    def _load_from_state_dict(self, *args, **kwargs):
        """Invalidate the eval-mode rotor cache whenever parameters are loaded.

        Without this, loading a checkpoint into a module that is already in
        eval mode would keep serving rotors computed from the old parameters.
        """
        self._cached_rotors = None
        super()._load_from_state_dict(*args, **kwargs)

    def extra_repr(self) -> str:
        """String representation for debugging."""
        return (
            f"in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, "
            f"num_rotor_pairs={self.num_rotor_pairs}, "
            f"aggregation={self.aggregation}, "
            f"shuffle={self.shuffle}, "
            f"use_decomposition={self.use_decomposition}, "
            f"bias={self.bias is not None}"
        )

__init__(algebra, in_channels, out_channels, num_rotor_pairs=4, use_decomposition=False, decomp_k=10, aggregation='mean', shuffle='none', bias=False)

Initialize rotor gadget layer.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

CliffordAlgebra instance

required
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
num_rotor_pairs int

Number of rotor pairs (higher = more expressive)

4
use_decomposition bool

Use bivector decomposition for efficiency

False
decomp_k int

Number of iterations for decomposition (if enabled)

10
aggregation Literal['mean', 'sum', 'learned']

How to pool rotor outputs ('mean', 'sum', 'learned')

'mean'
shuffle Literal['none', 'fixed', 'random']

Input channel shuffle strategy: - 'none': No shuffle, sequential block assignment (default) - 'fixed': Random permutation at initialization (fixed during training) - 'random': Random permutation each forward pass (regularization)

'none'
bias bool

Whether to include bias term (applied after transformation)

False
Source code in layers/primitives/rotor_gadget.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    in_channels: int,
    out_channels: int,
    num_rotor_pairs: int = 4,
    use_decomposition: bool = False,
    decomp_k: int = 10,
    aggregation: Literal['mean', 'sum', 'learned'] = 'mean',
    shuffle: Literal['none', 'fixed', 'random'] = 'none',
    bias: bool = False,
):
    """Initialize rotor gadget layer.

    Args:
        algebra: CliffordAlgebra instance
        in_channels: Number of input channels
        out_channels: Number of output channels
        num_rotor_pairs: Number of rotor pairs (higher = more expressive)
        use_decomposition: Use bivector decomposition for efficiency
        decomp_k: Number of iterations for decomposition (if enabled)
        aggregation: How to pool rotor outputs ('mean', 'sum', 'learned')
        shuffle: Input channel shuffle strategy:
            - 'none': No shuffle, sequential block assignment (default)
            - 'fixed': Random permutation at initialization (fixed during training)
            - 'random': Random permutation each forward pass (regularization)
        bias: Whether to include bias term (applied after transformation)

    Raises:
        ValueError: If ``num_rotor_pairs`` is less than 1, if ``aggregation``
            or ``shuffle`` is not a supported value, or if the algebra has
            no bivectors.
    """
    # Validate up front: a non-positive pair count would otherwise surface
    # as a ZeroDivisionError during routing setup, and unknown strings
    # would silently fall back to 'mean' aggregation / no shuffle.
    if num_rotor_pairs < 1:
        raise ValueError(
            f"num_rotor_pairs must be at least 1, got {num_rotor_pairs}."
        )
    if aggregation not in ('mean', 'sum', 'learned'):
        raise ValueError(
            f"Unknown aggregation: {aggregation!r}. "
            "Must be 'mean', 'sum' or 'learned'."
        )
    if shuffle not in ('none', 'fixed', 'random'):
        raise ValueError(
            f"Unknown shuffle: {shuffle!r}. "
            "Must be 'none', 'fixed' or 'random'."
        )

    super().__init__(algebra)

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.num_rotor_pairs = num_rotor_pairs
    self.use_decomposition = use_decomposition
    self.decomp_k = decomp_k
    self.aggregation = aggregation
    self.shuffle = shuffle

    # Use algebra's precomputed grade masks for bivector indices
    if algebra.num_grades > 2:
        bv_mask = algebra.grade_masks[2]
        self.register_buffer('bivector_indices', bv_mask.nonzero(as_tuple=False).squeeze(-1))
    else:
        self.register_buffer('bivector_indices', torch.tensor([], dtype=torch.long, device=algebra.device))
    self.num_bivectors = len(self.bivector_indices)

    if self.num_bivectors == 0:
        raise ValueError(
            "Algebra has no bivectors. RotorGadget requires "
            "at least one bivector for rotation."
        )

    # Rotor parameters: bivector coefficients for exponential map
    # Left rotors: [num_rotor_pairs, num_bivectors]
    self.bivector_left = nn.Parameter(
        torch.randn(num_rotor_pairs, self.num_bivectors) * 0.1
    )
    # Right rotors: [num_rotor_pairs, num_bivectors]
    self.bivector_right = nn.Parameter(
        torch.randn(num_rotor_pairs, self.num_bivectors) * 0.1
    )

    # Channel routing: block diagonal partitioning (paper style)
    # Each rotor pair processes a subset of input channels
    self._setup_channel_routing()

    # Aggregation weights (if learned)
    if aggregation == 'learned':
        # Learned weights for combining rotor outputs
        self.agg_weights = nn.Parameter(
            torch.ones(num_rotor_pairs, out_channels) / num_rotor_pairs
        )
    else:
        self.register_buffer('agg_weights', None)

    # Optional bias
    if bias:
        self.bias = nn.Parameter(torch.zeros(out_channels, algebra.dim))
    else:
        self.register_buffer('bias', None)

    # Rotor cache for eval mode
    self._cached_rotors = None

forward(x)

Apply rotor-based transformation.

Uses batched geometric products - all rotor pairs are applied in parallel via a single pair of GP calls.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape [Batch, In_Channels, Dim]

required

Returns:

Type Description
Tensor

Output tensor of shape [Batch, Out_Channels, Dim]

Source code in layers/primitives/rotor_gadget.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the rotor sandwich over all channel blocks, then aggregate.

    Every rotor pair is applied at once: per-channel rotor tensors are
    assembled first so that only two geometric-product calls are needed.

    Args:
        x: Input tensor of shape [Batch, In_Channels, Dim]

    Returns:
        Output tensor of shape [Batch, Out_Channels, Dim]
    """
    from core.validation import check_multivector, check_channels
    check_multivector(x, self.algebra, "RotorGadget input")
    check_channels(x, self.in_channels, "RotorGadget input")

    self.algebra.ensure_device(x.device)

    # Optional input-channel shuffle before block routing.
    if self.shuffle == 'fixed':
        x = x[:, self.channel_permutation, :]
    elif self.shuffle == 'random':
        x = x[:, torch.randperm(self.in_channels, device=x.device), :]

    # In eval mode rotors are computed once and reused; in training mode
    # they are always recomputed from the current parameters.
    if self.training or self._cached_rotors is None:
        rotors = self._compute_rotors()
        if not self.training:
            self._cached_rotors = rotors
    else:
        rotors = self._cached_rotors
    left, right_rev = rotors

    # Per-channel rotor tensors [1, in_channels, dim]: each channel carries
    # the rotor of the pair that owns its block.
    dim = self.algebra.dim
    left_per_channel = x.new_zeros(1, self.in_channels, dim)
    right_per_channel = x.new_zeros(1, self.in_channels, dim)
    for pair, (start, stop) in enumerate(self.in_indices):
        if stop <= start:
            continue
        left_per_channel[0, start:stop] = left[pair]
        right_per_channel[0, start:stop] = right_rev[pair]

    # Sandwich r . x . reverse(s) as two batched geometric products.
    sandwiched = self.algebra.geometric_product(
        self.algebra.geometric_product(left_per_channel, x),
        right_per_channel,
    )

    result = self._aggregate_to_output_channels(sandwiched)
    return result if self.bias is None else result + self.bias.unsqueeze(0)

train(mode=True)

Override to invalidate rotor cache when switching to train mode.

Source code in layers/primitives/rotor_gadget.py
def train(self, mode: bool = True):
    """Set the training mode, discarding cached rotors when training resumes.

    ``forward`` caches rotors while in eval mode. Once training is enabled
    the parameters can change, so the cache is cleared before delegating
    to ``nn.Module.train``.

    Returns:
        Whatever ``nn.Module.train`` returns (the module itself).
    """
    entering_training = mode
    if entering_training:
        self._cached_rotors = None
    return super().train(mode)

extra_repr()

String representation for debugging.

Source code in layers/primitives/rotor_gadget.py
def extra_repr(self) -> str:
    """Summarize the gadget's configuration for ``repr``."""
    settings = (
        ("in_channels", self.in_channels),
        ("out_channels", self.out_channels),
        ("num_rotor_pairs", self.num_rotor_pairs),
        ("aggregation", self.aggregation),
        ("shuffle", self.shuffle),
        ("use_decomposition", self.use_decomposition),
        ("bias", self.bias is not None),
    )
    return ", ".join(f"{key}={value}" for key, value in settings)

CliffordLayerNorm

Bases: CliffordModule

Geometric LayerNorm that preserves direction and recovers scale.

Normalizes the multivector to unit norm (preserving geometric direction), then injects the original log-magnitude into the scalar (grade-0) part via a learnable gate.

Attributes:

Name Type Description
weight Parameter

Per-channel direction scale [C].

bias Parameter

Per-channel scalar bias [C].

norm_scale Parameter

Per-channel gate for log-magnitude injection into grade-0. Initialized to zero so the layer starts identical to the old (scale-discarding) behaviour.

Source code in layers/primitives/normalization.py
class CliffordLayerNorm(CliffordModule):
    """Geometric LayerNorm: unit-norm direction plus a learned scale hint.

    Each channel's multivector is divided by its Euclidean coefficient norm,
    scaled per channel, and shifted in its scalar (index-0) coefficient by a
    per-channel bias. When ``recover`` is enabled, ``log1p`` of the original
    norm is also added to the scalar coefficient through a learnable gate.

    Attributes:
        weight (nn.Parameter): Per-channel direction scale [C].
        bias (nn.Parameter): Per-channel scalar bias [C].
        norm_scale (nn.Parameter): Per-channel gate for log-magnitude
            injection into grade-0.  Initialized to zero so the layer
            starts identical to the old (scale-discarding) behaviour.
    """

    def __init__(self, algebra: CliffordAlgebra, channels: int, eps: float = 1e-6, recover: bool = True):
        """Create the per-channel affine parameters and the optional gate.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Number of channels (features).
            eps (float): Added to the norm before dividing, for stability.
            recover (bool): Whether to inject the original scale into the
                scalar part.
        """
        super().__init__(algebra)
        self.eps, self.recover = eps, recover

        # Identity direction scale and zero scalar bias at initialization.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))

        # Zero-initialized gate: the layer starts out scale-discarding. When
        # recovery is off, the attribute is still registered (as None).
        if not recover:
            self.register_buffer('norm_scale', None)
        else:
            self.norm_scale = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize each channel to unit norm and rebuild the scalar part.

        Args:
            x (torch.Tensor): Input [Batch, Channels, Dim].

        Returns:
            torch.Tensor: Normalized tensor of the same shape; ``x`` itself
            is not modified.
        """
        magnitude = x.norm(dim=-1, keepdim=True)  # [B, C, 1]
        unit = x / (magnitude + self.eps)
        scaled = unit * self.weight.view(1, -1, 1)

        # New scalar coefficient: scaled scalar + bias (+ gated log-magnitude).
        scalar_part = scaled[..., 0] + self.bias.view(1, -1)
        if self.recover:
            # log1p keeps the injected magnitude bounded and smooth for gradients.
            log_magnitude = torch.log1p(magnitude.squeeze(-1))  # [B, C]
            scalar_part = scalar_part + self.norm_scale.view(1, -1) * log_magnitude

        # Reassemble: scalar coefficient followed by the untouched higher blades.
        return torch.cat((scalar_part.unsqueeze(-1), scaled[..., 1:]), dim=-1)

__init__(algebra, channels, eps=1e-06, recover=True)

Sets up normalization.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
channels int

Features.

required
eps float

Stability term.

1e-06
recover bool

Whether to inject original scale into the scalar part.

True
Source code in layers/primitives/normalization.py
def __init__(self, algebra: CliffordAlgebra, channels: int, eps: float = 1e-6, recover: bool = True):
    """Create the per-channel affine parameters and the optional gate.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Number of channels (features).
        eps (float): Added to the norm before dividing, for stability.
        recover (bool): Whether to inject the original scale into the
            scalar part.
    """
    super().__init__(algebra)
    self.eps, self.recover = eps, recover

    # Identity direction scale and zero scalar bias at initialization.
    self.weight = nn.Parameter(torch.ones(channels))
    self.bias = nn.Parameter(torch.zeros(channels))

    # Zero-initialized gate: the layer starts out scale-discarding. When
    # recovery is off, the attribute is still registered (as None).
    if not recover:
        self.register_buffer('norm_scale', None)
    else:
        self.norm_scale = nn.Parameter(torch.zeros(channels))

forward(x)

Normalizes energy, preserves direction, optionally recovers scale in grade-0.

Parameters:

Name Type Description Default
x Tensor

Input [Batch, Channels, Dim].

required

Returns:

Type Description
Tensor

torch.Tensor: Normalized input.

Source code in layers/primitives/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Normalize each channel to unit norm and rebuild the scalar part.

    The direction is scaled per channel by ``weight``. The scalar (index-0)
    coefficient then receives the per-channel ``bias`` and, when ``recover``
    is set, ``norm_scale * log1p(||x||)``.

    Args:
        x (torch.Tensor): Input [Batch, Channels, Dim].

    Returns:
        torch.Tensor: Normalized tensor of the same shape; ``x`` itself is
        not modified.
    """
    magnitude = x.norm(dim=-1, keepdim=True)  # [B, C, 1]
    unit = x / (magnitude + self.eps)
    scaled = unit * self.weight.view(1, -1, 1)

    # New scalar coefficient: scaled scalar + bias (+ gated log-magnitude).
    scalar_part = scaled[..., 0] + self.bias.view(1, -1)
    if self.recover:
        # log1p keeps the injected magnitude bounded and smooth for gradients.
        log_magnitude = torch.log1p(magnitude.squeeze(-1))  # [B, C]
        scalar_part = scalar_part + self.norm_scale.view(1, -1) * log_magnitude

    # Reassemble: scalar coefficient followed by the untouched higher blades.
    return torch.cat((scalar_part.unsqueeze(-1), scaled[..., 1:]), dim=-1)

BladeSelector

Bases: CliffordModule

Blade Selector. Filters insignificant components.

Learns a soft per-channel gate for each basis blade (multivector component), suppressing less relevant ones.

Attributes:

Name Type Description
weights Parameter

Gate logits [Channels, Dim]; a sigmoid maps them to soft gates in (0, 1).

Source code in layers/primitives/projection.py
class BladeSelector(CliffordModule):
    """Blade Selector. Filters insignificant components.

    Learns a soft gate per (channel, basis blade) pair, so individual
    multivector components can be suppressed independently in each channel.

    Attributes:
        weights (nn.Parameter): Gate logits [Channels, Dim]; the forward pass
            applies a sigmoid, so the effective gates lie in (0, 1).
    """

    def __init__(self, algebra: CliffordAlgebra, channels: int):
        """Sets up the selector.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Input features.
        """
        super().__init__(algebra)

        # Uninitialized storage; filled by reset_parameters() below.
        self.weights = nn.Parameter(torch.empty(channels, algebra.dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Initializes the gate logits to one.

        This is not an exact pass-through: the forward pass applies a sigmoid
        and sigmoid(1) ~= 0.731, so every component initially passes at
        about 73% of its magnitude.
        """
        nn.init.ones_(self.weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gates each blade coefficient of each channel.

        Args:
            x (torch.Tensor): Input [Batch, Channels, Dim].

        Returns:
            torch.Tensor: Filtered input, same shape as ``x``.
        """
        # Sigmoid maps logits to gates in (0, 1); broadcast over the batch.
        w = torch.sigmoid(self.weights).unsqueeze(0)
        return x * w

__init__(algebra, channels)

Sets up the selector.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
channels int

Input features.

required
Source code in layers/primitives/projection.py
def __init__(self, algebra: CliffordAlgebra, channels: int):
    """Allocate one gate logit per (channel, blade) pair.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Number of input channels (features).
    """
    super().__init__(algebra)

    num_blades = algebra.dim
    # Storage is uninitialized here; reset_parameters() fills it.
    self.weights = nn.Parameter(torch.empty(channels, num_blades))
    self.reset_parameters()

reset_parameters()

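Initializes the gate logits to one. This is not an exact pass-through: sigmoid(1) ≈ 0.731, so every component initially passes at about 73% of its magnitude.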

Source code in layers/primitives/projection.py
def reset_parameters(self):
    """Initializes the gate logits to one.

    This is not an exact pass-through: the forward pass applies a sigmoid and
    sigmoid(1) ~= 0.731, so every component initially passes at about 73% of
    its magnitude.
    """
    nn.init.ones_(self.weights)

forward(x)

Gates the grades.

Parameters:

Name Type Description Default
x Tensor

Input [Batch, Channels, Dim].

required

Returns:

Type Description
Tensor

torch.Tensor: Filtered input.

Source code in layers/primitives/projection.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Scales every (channel, blade) coefficient by a learned sigmoid gate.

    Args:
        x (torch.Tensor): Input [Batch, Channels, Dim].

    Returns:
        torch.Tensor: ``x`` multiplied elementwise by ``sigmoid(weights)``,
        broadcast over the leading batch axis.
    """
    gate = self.weights.sigmoid()[None]  # [1, Channels, Dim]
    return x * gate

Blocks

GeometricProductAttention

Bases: CliffordModule

Multi-head attention using geometric product scoring.

Standard attention: $\mathrm{score}(Q, K) = \langle Q, K \rangle / \sqrt{d}$ (scalar only)

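GA attention:

$P = Q_c \, \widetilde{K}_c$ (geometric product of each head-channel of $Q$ with the reverse of the matching channel of $K$), and

$\mathrm{score}(Q, K) = \dfrac{\langle P \rangle_0 + \lambda \, \lVert \langle P \rangle_2 \rVert_F}{\sqrt{H_c \cdot \mathrm{dim}}}$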

The grade-0 (scalar) part measures alignment (like a dot product). The grade-2 (bivector) part measures relative orientation, which a scalar dot product alone does not capture.

Memory: naive [B, H, L, L, H_c, D] is too large. We chunk over L_q in blocks of BLOCK_SIZE to bound peak VRAM.

Attributes:

Name Type Description
num_heads int

Number of attention heads.

head_channels int

Channels per head.

causal bool

If True, apply autoregressive causal mask.

bivector_weight float

lambda_ - weight of bivector score component.

Source code in layers/blocks/attention.py
class GeometricProductAttention(CliffordModule):
    """Multi-head attention using geometric product scoring.

    Standard attention: score(Q, K) = <Q, K> / sqrt(d)  (scalar only)

    GA attention:
        product = Q_c * reverse(K_c)    (geometric product per head-channel)
        score   = (<product>_0 + lambda_ * ||<product>_2||_F) / sqrt(H_c * dim)

    The grade-0 (scalar) part measures alignment (like a dot product).
    The grade-2 (bivector) part measures relative orientation, which a scalar
    dot product alone does not capture.

    Memory: naive [B, H, L, L, H_c, D] is too large. We chunk over L_q
    in blocks of _BLOCK_SIZE to bound peak VRAM.

    Attributes:
        num_heads (int): Number of attention heads.
        head_channels (int): Channels per head.
        causal (bool): If True, apply autoregressive causal mask.
        bivector_weight (float): lambda_ - weight of bivector score component.
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        num_heads: int,
        causal: bool = True,
        bivector_weight: float = 0.5,
        dropout: float = 0.0,
    ):
        """Sets up geometric product attention.

        Args:
            algebra: Clifford algebra instance.
            channels: Total number of multivector channels; must be divisible
                by ``num_heads``.
            num_heads: Number of attention heads.
            causal: Apply causal mask for autoregressive generation.
            bivector_weight: lambda_ weight on bivector score component.
            dropout: Dropout rate on attention weights (no dropout layer is
                created unless ``dropout > 0``).
        """
        super().__init__(algebra)
        assert channels % num_heads == 0, \
            f"channels ({channels}) must be divisible by num_heads ({num_heads})"

        self.channels = channels
        self.num_heads = num_heads
        self.head_channels = channels // num_heads
        self.causal = causal
        self.bivector_weight = bivector_weight

        # Q, K, V projections operate on [B*L, channels, dim]
        self.q_proj = CliffordLinear(algebra, channels, channels)
        self.k_proj = CliffordLinear(algebra, channels, channels)
        self.v_proj = CliffordLinear(algebra, channels, channels)
        self.out_proj = CliffordLinear(algebra, channels, channels)

        self.attn_dropout = nn.Dropout(dropout) if dropout > 0.0 else None

        # Precompute bilinear score tables (replaces pairwise geometric product)
        self._precompute_score_tables()

    def _precompute_score_tables(self):
        """Precomputes lookup tables for efficient attention scoring.

        Replaces the O(L**2) full pairwise geometric product with direct bilinear
        forms for grade-0 and grade-2 components of Q * reverse(K):

        Grade-0:  <Q * rev(K)>_0 = Sum_a Q[a] * K[a] * metric_rev[a]
                  -> simple weighted dot product, no pairwise expansion needed.

        Grade-2:  <Q * rev(K)>_r = Sum_a Q[a] * K[a^r] * g2_sign[r, a]
                  -> precompute K_g2 once, then batched matmul.

        Registers buffers ``_metric_rev`` [D], ``_g2_b_idx`` [n_g2, D] (long)
        and ``_g2_sign`` [n_g2, D], and sets ``self.n_g2``.
        """
        alg = self.algebra
        D = alg.dim

        # Grade-0 metric: metric_rev[a] = gp_signs[a, 0] * rev_signs[a]
        # gp_signs[a, 0] is the sign when A[a] * B[a] contributes to output blade 0
        metric_rev = alg.gp_signs[:, 0].float() * alg.rev_signs.float()
        self.register_buffer('_metric_rev', metric_rev)  # [D]

        # Grade-2 tables: for each grade-2 blade r, for each A-blade a:
        #   B-blade  = a XOR r
        #   sign     = rev_sign[a^r] * gp_signs[a, r]
        g2_blades = [i for i in range(D) if bin(i).count('1') == 2]
        n_g2 = len(g2_blades)
        self.n_g2 = n_g2

        if n_g2 > 0:
            a_idx = torch.arange(D, device=alg.device)
            r_vals = torch.tensor(g2_blades, dtype=torch.long, device=alg.device)  # [n_g2]

            # b_idx[r, a] = a XOR r_vals[r]
            b_idx = a_idx.unsqueeze(0) ^ r_vals.unsqueeze(1)  # [n_g2, D]

            # rev_sign at the B-blade position
            rev_b = alg.rev_signs.float()[b_idx]  # [n_g2, D]

            # gp_signs[a, r_val]: sign when A[a] pairs with B[a^r] to give output r
            # alg.gp_signs[:, r_vals] -> [D, n_g2]; transpose -> [n_g2, D]
            gp_ar = alg.gp_signs[:, r_vals].float().T  # [n_g2, D]

            g2_sign = rev_b * gp_ar  # [n_g2, D]
        else:
            b_idx = torch.zeros(0, D, dtype=torch.long, device=alg.device)
            g2_sign = torch.zeros(0, D, device=alg.device)

        self.register_buffer('_g2_b_idx', b_idx)   # [n_g2, D] long
        self.register_buffer('_g2_sign', g2_sign)  # [n_g2, D] float

    def _compute_score(
        self,
        q_head: torch.Tensor,
        k_head: torch.Tensor,
        k_g2: torch.Tensor,
    ) -> torch.Tensor:
        """Computes GA attention score using precomputed bilinear form tables.

        Avoids the O(B.H.Lq.Lk.Hc.D.BLOCK) memory of the full pairwise
        geometric product. Instead:

          Grade-0: score_g0 = Q_weighted @ K^T  (weighted dot product, peak ~1 MB)
          Grade-2: batched matmul via precomputed k_g2            (peak ~4 MB)

        The bivector norm uses a gradient-safe square root: where the squared
        norm is exactly zero the score is 0 (as before) and its gradient is 0
        instead of NaN.

        Args:
            q_head: Query block [B, H, Lq, Hc, D]
            k_head: Keys        [B, H, Lk, Hc, D]
            k_g2:   Precomputed [B, H, Lk, Hc, n_g2, D]
                    k_g2[b,h,j,c,r,d] = K[b,h,j,c, d^r] * g2_sign[r, d]

        Returns:
            scores: [B, H, Lq, Lk]
        """
        B, H, Lq, Hc, D = q_head.shape
        Lk = k_head.shape[2]
        n_g2 = self.n_g2

        # == Grade-0 score ====================================================
        # <Q * rev(K)>_0 = Sum_c Sum_d  Q[c,d] * K[c,d] * metric_rev[d]
        # Implemented as a batched matrix multiply: [B,H,Lq,Hc*D] @ [B,H,Hc*D,Lk]
        q_weighted = q_head * self._metric_rev          # [B, H, Lq, Hc, D]
        q_flat = q_weighted.reshape(B, H, Lq, Hc * D)  # [B, H, Lq, Hc*D]
        k_flat = k_head.reshape(B, H, Lk, Hc * D)      # [B, H, Lk, Hc*D]
        score_g0 = torch.matmul(q_flat, k_flat.transpose(-2, -1))  # [B, H, Lq, Lk]

        # == Grade-2 score ====================================================
        # ||<Q * rev(K)>_2||_F = sqrt(Sum_c Sum_r (Sum_d Q[c,d]*k_g2[j,c,r,d])^2)
        # Batched matmul merging (B, H, Hc) into one batch dimension:
        #   q_2d:     [B*H*Hc, Lq, D]
        #   k_g2_2d:  [B*H*Hc, Lk*n_g2, D]   (Lk and n_g2 merged, n_g2 varies fast)
        #   comp:     [B*H*Hc, Lq, Lk*n_g2]
        if n_g2 > 0:
            q_2d = q_head.permute(0, 1, 3, 2, 4).reshape(B * H * Hc, Lq, D)
            # k_g2: [B, H, Lk, Hc, n_g2, D] -> permute to [B, H, Hc, Lk, n_g2, D]
            k_g2_t = k_g2.permute(0, 1, 3, 2, 4, 5)
            k_g2_2d = k_g2_t.reshape(B * H * Hc, Lk * n_g2, D)
            # [B*H*Hc, Lq, D] @ [B*H*Hc, D, Lk*n_g2] -> [B*H*Hc, Lq, Lk*n_g2]
            comp = torch.bmm(q_2d, k_g2_2d.transpose(-2, -1))
            # Sum squared components over n_g2, then sum over Hc -> [B, H, Lq, Lk]
            comp_sq = comp.reshape(B * H * Hc, Lq, Lk, n_g2).pow(2).sum(-1)  # [B*H*Hc, Lq, Lk]
            score_g2_sq = comp_sq.reshape(B, H, Hc, Lq, Lk).sum(2)           # [B, H, Lq, Lk]

            # d/ds sqrt(s) is infinite at s == 0, so an exactly-zero bivector part
            # (e.g. an all-zero query or key) would give inf * 0 = NaN gradients.
            # Double `where`: sqrt only ever sees a non-zero argument, and the
            # forward value at zero stays exactly 0. `!= 0` (not `> 0`) lets NaN
            # inputs still propagate as NaN.
            nonzero = score_g2_sq != 0
            safe_sq = torch.where(nonzero, score_g2_sq, torch.ones_like(score_g2_sq))
            score_g2 = torch.where(nonzero, safe_sq.sqrt(), torch.zeros_like(score_g2_sq))
        else:
            score_g2 = torch.zeros_like(score_g0)

        # Combined score
        scale = math.sqrt(self.head_channels * self.algebra.dim)
        return (score_g0 + self.bivector_weight * score_g2) / scale

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """Computes geometric product attention.

        Args:
            x: Input multivectors [B, L, C, D].
            key_padding_mask: Optional [B, L] bool mask where True = padded (ignored).

        Returns:
            Output multivectors [B, L, C, D]. A query position for which every
            key is masked (e.g. a fully padded sequence, or left padding under
            the causal mask) receives all-zero attention weights instead of NaN.
        """
        B, L, C, D = x.shape

        # Project Q, K, V (CliffordLinear expects [B, C, D])
        x_flat = x.reshape(B * L, C, D)
        Q = self.q_proj(x_flat).reshape(B, L, C, D)
        K = self.k_proj(x_flat).reshape(B, L, C, D)
        V = self.v_proj(x_flat).reshape(B, L, C, D)

        H = self.num_heads
        Hc = self.head_channels

        # Reshape to [B, H, L, Hc, D]
        Q = Q.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)  # [B, H, L, Hc, D]
        K = K.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)
        V = V.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)

        # Build causal mask once [L, L]
        if self.causal:
            causal_mask = torch.triu(
                torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1
            )  # True = masked (future)
        else:
            causal_mask = None

        # Precompute K_g2 once for all query blocks - much cheaper than recomputing
        # k_g2[b,h,j,c,r,d] = K[b,h,j,c, d^r_val] * g2_sign[r, d]
        # Shape: [B, H, L, Hc, n_g2, D]
        K_g2 = K[..., self._g2_b_idx] * self._g2_sign  # [B, H, L, Hc, n_g2, D]

        # Chunked attention over query positions to bound memory
        output_chunks = []
        for q_start in range(0, L, _BLOCK_SIZE):
            q_end = min(q_start + _BLOCK_SIZE, L)

            Q_block = Q[:, :, q_start:q_end]  # [B, H, Lq, Hc, D]

            # Compute scores: [B, H, Lq, L]
            scores = self._compute_score(Q_block, K, K_g2)

            # Merge causal and key-padding masks (True = may not attend).
            block_mask = None
            if causal_mask is not None:
                block_mask = causal_mask[q_start:q_end, :].unsqueeze(0).unsqueeze(0)  # [1, 1, Lq, L]
            if key_padding_mask is not None:
                pad_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, L]
                block_mask = pad_mask if block_mask is None else block_mask | pad_mask

            if block_mask is not None:
                scores = scores.masked_fill(block_mask, float('-inf'))

            # Softmax + dropout
            attn_weights = F.softmax(scores, dim=-1)  # [B, H, Lq, L]
            if block_mask is not None:
                # A row whose keys are all masked is entirely -inf and softmax
                # returns NaN for it; that NaN would reach the output and can
                # spread further downstream (0 * NaN = NaN). Zero such rows.
                fully_masked = block_mask.all(dim=-1, keepdim=True)
                attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
            if self.attn_dropout is not None:
                attn_weights = self.attn_dropout(attn_weights)

            # Aggregate values: sum_k attn[b,h,i,k] * V[b,h,k,Hc,D]
            # attn_weights: [B, H, Lq, L]
            # V:            [B, H, L,  Hc, D]
            # out:          [B, H, Lq, Hc, D]
            out_block = torch.einsum('bhij,bhjcd->bhicd', attn_weights, V)
            output_chunks.append(out_block)

        # Reassemble: [B, H, L, Hc, D]
        output = torch.cat(output_chunks, dim=2)

        # Merge heads back: [B, L, C, D]
        output = output.permute(0, 2, 1, 3, 4).reshape(B, L, C, D)

        # Output projection
        output = self.out_proj(output.reshape(B * L, C, D)).reshape(B, L, C, D)

        return output

__init__(algebra, channels, num_heads, causal=True, bivector_weight=0.5, dropout=0.0)

Sets up geometric product attention.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

Clifford algebra instance.

required
channels int

Total number of multivector channels.

required
num_heads int

Number of attention heads.

required
causal bool

Apply causal mask for autoregressive generation.

True
bivector_weight float

lambda_ weight on bivector score component.

0.5
dropout float

Dropout rate on attention weights.

0.0
Source code in layers/blocks/attention.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    channels: int,
    num_heads: int,
    causal: bool = True,
    bivector_weight: float = 0.5,
    dropout: float = 0.0,
):
    """Builds the Q/K/V/output projections and the score lookup tables.

    Args:
        algebra: Clifford algebra instance.
        channels: Total number of multivector channels; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        causal: Apply causal mask for autoregressive generation.
        bivector_weight: lambda_ weight on bivector score component.
        dropout: Dropout rate on attention weights (no dropout layer is
            created unless ``dropout > 0``).
    """
    super().__init__(algebra)
    assert channels % num_heads == 0, \
        f"channels ({channels}) must be divisible by num_heads ({num_heads})"

    self.channels = channels
    self.num_heads = num_heads
    self.head_channels = channels // num_heads
    self.causal = causal
    self.bivector_weight = bivector_weight

    # Four projections over [B*L, channels, dim]. The generator yields them in
    # the order q, k, v, out, which keeps construction and registration order
    # identical to creating them one by one.
    self.q_proj, self.k_proj, self.v_proj, self.out_proj = (
        CliffordLinear(algebra, channels, channels) for _ in range(4)
    )

    self.attn_dropout = nn.Dropout(dropout) if dropout > 0.0 else None

    # Bilinear grade-0 / grade-2 tables used instead of a pairwise geometric product.
    self._precompute_score_tables()

forward(x, key_padding_mask=None)

Computes geometric product attention.

Parameters:

Name Type Description Default
x Tensor

Input multivectors [B, L, C, D].

required
key_padding_mask Tensor

Optional [B, L] bool mask where True = padded (ignored).

None

Returns:

Type Description
Tensor

Output multivectors [B, L, C, D].

Source code in layers/blocks/attention.py
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
    """Computes geometric product attention.

    Args:
        x: Input multivectors [B, L, C, D].
        key_padding_mask: Optional [B, L] bool mask where True = padded (ignored).

    Returns:
        Output multivectors [B, L, C, D]. A query position for which every
        key is masked (e.g. a fully padded sequence, or left padding under
        the causal mask) receives all-zero attention weights instead of NaN.
    """
    B, L, C, D = x.shape

    # Project Q, K, V (CliffordLinear expects [B, C, D])
    x_flat = x.reshape(B * L, C, D)
    Q = self.q_proj(x_flat).reshape(B, L, C, D)
    K = self.k_proj(x_flat).reshape(B, L, C, D)
    V = self.v_proj(x_flat).reshape(B, L, C, D)

    H = self.num_heads
    Hc = self.head_channels

    # Reshape to [B, H, L, Hc, D]
    Q = Q.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)  # [B, H, L, Hc, D]
    K = K.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)
    V = V.reshape(B, L, H, Hc, D).permute(0, 2, 1, 3, 4)

    # Build causal mask once [L, L]
    if self.causal:
        causal_mask = torch.triu(
            torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1
        )  # True = masked (future)
    else:
        causal_mask = None

    # Precompute K_g2 once for all query blocks - much cheaper than recomputing
    # k_g2[b,h,j,c,r,d] = K[b,h,j,c, d^r_val] * g2_sign[r, d]
    # Shape: [B, H, L, Hc, n_g2, D]
    K_g2 = K[..., self._g2_b_idx] * self._g2_sign  # [B, H, L, Hc, n_g2, D]

    # Chunked attention over query positions to bound memory
    output_chunks = []
    for q_start in range(0, L, _BLOCK_SIZE):
        q_end = min(q_start + _BLOCK_SIZE, L)

        Q_block = Q[:, :, q_start:q_end]  # [B, H, Lq, Hc, D]

        # Compute scores: [B, H, Lq, L]
        scores = self._compute_score(Q_block, K, K_g2)

        # Merge causal and key-padding masks (True = may not attend).
        block_mask = None
        if causal_mask is not None:
            block_mask = causal_mask[q_start:q_end, :].unsqueeze(0).unsqueeze(0)  # [1, 1, Lq, L]
        if key_padding_mask is not None:
            pad_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, L]
            block_mask = pad_mask if block_mask is None else block_mask | pad_mask

        if block_mask is not None:
            scores = scores.masked_fill(block_mask, float('-inf'))

        # Softmax + dropout
        attn_weights = F.softmax(scores, dim=-1)  # [B, H, Lq, L]
        if block_mask is not None:
            # A row whose keys are all masked is entirely -inf and softmax
            # returns NaN for it; that NaN would reach the output and can
            # spread further downstream (0 * NaN = NaN). Zero such rows.
            fully_masked = block_mask.all(dim=-1, keepdim=True)
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        if self.attn_dropout is not None:
            attn_weights = self.attn_dropout(attn_weights)

        # Aggregate values: sum_k attn[b,h,i,k] * V[b,h,k,Hc,D]
        # attn_weights: [B, H, Lq, L]
        # V:            [B, H, L,  Hc, D]
        # out:          [B, H, Lq, Hc, D]
        out_block = torch.einsum('bhij,bhjcd->bhicd', attn_weights, V)
        output_chunks.append(out_block)

    # Reassemble: [B, H, L, Hc, D]
    output = torch.cat(output_chunks, dim=2)

    # Merge heads back: [B, L, C, D]
    output = output.permute(0, 2, 1, 3, 4).reshape(B, L, C, D)

    # Output projection
    output = self.out_proj(output.reshape(B * L, C, D)).reshape(B, L, C, D)

    return output

MultiRotorFFN

Bases: CliffordModule

Embedded Geometric Toolbox - Feed-Forward Network via rotor superposition.

Standard transformers use: Linear -> GELU -> Linear. This replaces that with:

CliffordLinear(expand) -> CliffordLayerNorm
    -> MultiRotorLayer(K rotors) -> GeometricGELU
    -> CliffordLinear(contract) -> BladeSelector

The expand step lifts x into a wider toolbox subspace of ffn_mult × channels channels. MultiRotorLayer applies K parallel rotors, each exploring a different rotation plane; this rotor superposition IS the nonlinearity, not just a scalar gate. The contract step projects back to the original channel count.

Designed as a standalone module so it can be reused in other tasks (md17, pdbbind, etc.) beyond the language model.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
channels int

Input/output channel count.

required
ffn_mult int

Expansion factor (ffn_channels = channels * ffn_mult).

4
num_rotors int

Number of parallel rotors K in the toolbox.

8
use_decomposition bool

Use power-iteration bivector decomposition.

False
decomp_k int

Number of simple components for decomposition.

None
use_rotor_backend bool

Use RotorGadget backend for CliffordLinear.

False

Input/Output shape: [B, C, D] where D = algebra.dim.

Source code in layers/blocks/multi_rotor_ffn.py
class MultiRotorFFN(CliffordModule):
    """Feed-forward network built from a superposition of rotors.

    Takes the place of the usual transformer ``Linear -> GELU -> Linear``
    with the pipeline

        CliffordLinear(expand) -> CliffordLayerNorm
            -> MultiRotorLayer(K rotors) -> GeometricGELU
            -> CliffordLinear(contract) -> BladeSelector

    ``expand`` widens the input to ``channels * ffn_mult`` channels,
    ``MultiRotorLayer`` applies ``num_rotors`` parallel rotors in that wider
    space (the rotor superposition is the nonlinearity, not just a scalar
    gate), and ``contract`` maps back to ``channels`` before a per-blade gate.

    Kept standalone so tasks other than the language model (md17, pdbbind,
    etc.) can reuse it.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Input/output channel count.
        ffn_mult (int): Expansion factor (ffn_channels = channels * ffn_mult).
        num_rotors (int): Number of parallel rotors K in the toolbox.
        use_decomposition (bool): Use power-iteration bivector decomposition.
        decomp_k (int, optional): Number of simple components for decomposition.
        use_rotor_backend (bool): Use RotorGadget backend for CliffordLinear.

    Input/Output shape: ``[B, C, D]`` where D = algebra.dim.
    """

    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        ffn_mult: int = 4,
        num_rotors: int = 8,
        use_decomposition: bool = False,
        decomp_k: int = None,
        use_rotor_backend: bool = False,
    ):
        """Creates the six pipeline stages.

        Stages are constructed in pipeline order. Keep it that way: it fixes
        parameter registration order and, if the submodules draw random
        initial values (not visible here), the order of those draws.
        """
        super().__init__(algebra)
        self.channels = channels
        hidden = channels * ffn_mult
        linear_backend = 'rotor' if use_rotor_backend else 'traditional'

        self.expand = CliffordLinear(algebra, channels, hidden, backend=linear_backend)
        self.norm = CliffordLayerNorm(algebra, hidden)
        self.toolbox = MultiRotorLayer(
            algebra,
            hidden,
            num_rotors,
            use_decomposition=use_decomposition,
            decomp_k=decomp_k,
        )
        self.act = GeometricGELU(algebra, channels=hidden)
        self.contract = CliffordLinear(algebra, hidden, channels, backend=linear_backend)
        self.gate = BladeSelector(algebra, channels)

    def forward(self, x) -> torch.Tensor:
        """Runs ``x`` through expand, norm, rotor toolbox, activation, contract and gate.

        Args:
            x (torch.Tensor): Input ``[B, C, D]``.

        Returns:
            torch.Tensor: Output ``[B, C, D]``.
        """
        # [B, C, D] -> [B, C*ffn_mult, D] (expand .. act) -> [B, C, D] (contract, gate)
        stages = (self.expand, self.norm, self.toolbox, self.act, self.contract, self.gate)
        out = x
        for stage in stages:
            out = stage(out)
        return out

forward(x)

Applies the geometric toolbox FFN.

Parameters:

Name Type Description Default
x Tensor

Input [B, C, D].

required

Returns:

Type Description
Tensor

torch.Tensor: Output [B, C, D].

Source code in layers/blocks/multi_rotor_ffn.py
def forward(self, x) -> torch.Tensor:
    """Runs ``x`` through expand, norm, rotor toolbox, activation, contract and gate.

    Args:
        x (torch.Tensor): Input ``[B, C, D]``.

    Returns:
        torch.Tensor: Output ``[B, C, D]``.
    """
    # [B, C, D] -> [B, C*ffn_mult, D] (expand .. act) -> [B, C, D] (contract, gate)
    stages = (self.expand, self.norm, self.toolbox, self.act, self.contract, self.gate)
    out = x
    for stage in stages:
        out = stage(out)
    return out

GeometricTransformerBlock

Bases: CliffordModule

Modular Geometric Transformer block.

Architecture:

1. Pre-norm
2. Geometric Attention (Standard or Entropy-Gated)
3. Residual connection
4. Pre-norm
5. Multi-Rotor FFN
6. Residual connection

Source code in layers/blocks/transformer.py
class GeometricTransformerBlock(CliffordModule):
    """Pre-norm geometric transformer block.

    Data flow for ``x`` of shape [B, L, C, D]:
    1. Pre-norm (CliffordLayerNorm)
    2. Geometric attention: GeometricProductAttention (non-causal), or
       EntropyGatedAttention when ``use_entropy_gating`` is set
    3. Residual connection
    4. Pre-norm (CliffordLayerNorm)
    5. Multi-Rotor FFN, applied on the flattened [B*L, C, D] view
    6. Residual connection
    """
    def __init__(
        self,
        algebra: CliffordAlgebra,
        channels: int,
        num_heads: int = 4,
        num_rotors: int = 8,
        dropout: float = 0.1,
        use_entropy_gating: bool = False,
        eta: float = 1.5,
        H_base: float = 0.5
    ):
        """Builds the norms, attention sub-layer and FFN.

        Args:
            algebra: Clifford algebra instance.
            channels: Total multivector channels.
            num_heads: Number of attention heads.
            num_rotors: Number of rotors in the FFN.
            dropout: Dropout rate; only forwarded to GeometricProductAttention
                (not to EntropyGatedAttention or the FFN).
            use_entropy_gating: If True, uses EntropyGatedAttention.
            eta: Gating multiplier for entropy attention.
            H_base: Base entropy threshold.
        """
        super().__init__(algebra)
        self.use_entropy_gating = use_entropy_gating
        self.norm1 = CliffordLayerNorm(algebra, channels)

        self.attn = (
            EntropyGatedAttention(algebra, channels, num_heads, eta=eta, H_base=H_base)
            if use_entropy_gating
            else GeometricProductAttention(algebra, channels, num_heads, causal=False, dropout=dropout)
        )

        self.norm2 = CliffordLayerNorm(algebra, channels)

        # Function-level import kept as in the original; presumably it avoids a
        # circular import - confirm before hoisting it to module level.
        from .multi_rotor_ffn import MultiRotorFFN
        self.ffn = MultiRotorFFN(algebra, channels, num_rotors=num_rotors)

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None,
                return_state: bool = False) -> torch.Tensor:
        """Applies the attention and FFN sub-layers, each with pre-norm and residual.

        Args:
            x: Input multivectors [B, L, C, D].
            key_padding_mask: Optional [B, L] bool mask where True = padded.
            return_state: If True, returns intermediate entropy/gating states.

        Returns:
            Processed multivectors [B, L, C, D]. With ``return_state`` the result
            is a tuple ``(x, H, lambda_dyn)``; ``H`` and ``lambda_dyn`` are None
            unless entropy gating is enabled.
        """
        B, L, C, D = x.shape
        tokens = B * L

        # Attention sub-layer: pre-norm on the flattened view, then residual add.
        normed = self.norm1(x.reshape(tokens, C, D)).reshape(B, L, C, D)
        want_gating = self.use_entropy_gating and return_state
        if want_gating:
            attn_out, H, lambda_dyn = self.attn(
                normed, key_padding_mask=key_padding_mask, return_gating=True
            )
        else:
            attn_out = self.attn(normed, key_padding_mask=key_padding_mask)
            H = lambda_dyn = None
        x = x + attn_out

        # FFN sub-layer: norm and FFN both run on the [B*L, C, D] view.
        ffn_out = self.ffn(self.norm2(x.reshape(tokens, C, D))).reshape(B, L, C, D)
        x = x + ffn_out

        if return_state:
            return x, H, lambda_dyn
        return x

__init__(algebra, channels, num_heads=4, num_rotors=8, dropout=0.1, use_entropy_gating=False, eta=1.5, H_base=0.5)

Initializes the Geometric Transformer Block.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

Clifford algebra instance.

required
channels int

Total multivector channels.

required
num_heads int

Number of attention heads.

4
num_rotors int

Number of rotors in the FFN.

8
dropout float

Dropout rate.

0.1
use_entropy_gating bool

If True, uses EntropyGatedAttention.

False
eta float

Gating multiplier for entropy attention.

1.5
H_base float

Base entropy threshold.

0.5
Source code in layers/blocks/transformer.py
def __init__(
    self,
    algebra: CliffordAlgebra,
    channels: int,
    num_heads: int = 4,
    num_rotors: int = 8,
    dropout: float = 0.1,
    use_entropy_gating: bool = False,
    eta: float = 1.5,
    H_base: float = 0.5
):
    """Builds the norms, attention sub-layer and FFN.

    Args:
        algebra: Clifford algebra instance.
        channels: Total multivector channels.
        num_heads: Number of attention heads.
        num_rotors: Number of rotors in the FFN.
        dropout: Dropout rate; only forwarded to GeometricProductAttention
            (not to EntropyGatedAttention or the FFN).
        use_entropy_gating: If True, uses EntropyGatedAttention.
        eta: Gating multiplier for entropy attention.
        H_base: Base entropy threshold.
    """
    super().__init__(algebra)
    self.use_entropy_gating = use_entropy_gating
    self.norm1 = CliffordLayerNorm(algebra, channels)

    self.attn = (
        EntropyGatedAttention(algebra, channels, num_heads, eta=eta, H_base=H_base)
        if use_entropy_gating
        else GeometricProductAttention(algebra, channels, num_heads, causal=False, dropout=dropout)
    )

    self.norm2 = CliffordLayerNorm(algebra, channels)

    # Function-level import kept as in the original; presumably it avoids a
    # circular import - confirm before hoisting it to module level.
    from .multi_rotor_ffn import MultiRotorFFN
    self.ffn = MultiRotorFFN(algebra, channels, num_rotors=num_rotors)

forward(x, key_padding_mask=None, return_state=False)

Forward pass through the transformer block.

Parameters:

Name Type Description Default
x Tensor

Input multivectors [B, L, C, D].

required
key_padding_mask Tensor

Optional [B, L] bool mask where True = padded.

None
return_state bool

If True, returns intermediate entropy/gating states.

False

Returns:

Type Description
Tensor

Processed multivectors [B, L, C, D] (and optionally intermediate states).

Source code in layers/blocks/transformer.py
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None,
            return_state: bool = False) -> torch.Tensor:
    """Applies the attention and FFN sub-layers, each with pre-norm and residual.

    Args:
        x: Input multivectors [B, L, C, D].
        key_padding_mask: Optional [B, L] bool mask where True = padded.
        return_state: If True, returns intermediate entropy/gating states.

    Returns:
        Processed multivectors [B, L, C, D]. With ``return_state`` the result
        is a tuple ``(x, H, lambda_dyn)``; ``H`` and ``lambda_dyn`` are None
        unless entropy gating is enabled.
    """
    B, L, C, D = x.shape
    tokens = B * L

    # Attention sub-layer: pre-norm on the flattened view, then residual add.
    normed = self.norm1(x.reshape(tokens, C, D)).reshape(B, L, C, D)
    want_gating = self.use_entropy_gating and return_state
    if want_gating:
        attn_out, H, lambda_dyn = self.attn(
            normed, key_padding_mask=key_padding_mask, return_gating=True
        )
    else:
        attn_out = self.attn(normed, key_padding_mask=key_padding_mask)
        H = lambda_dyn = None
    x = x + attn_out

    # FFN sub-layer: norm and FFN both run on the [B*L, C, D] view.
    ffn_out = self.ffn(self.norm2(x.reshape(tokens, C, D))).reshape(B, L, C, D)
    x = x + ffn_out

    if return_state:
        return x, H, lambda_dyn
    return x

Adapters

MultivectorEmbedding

Bases: CliffordModule

Token embedding as multivectors.

Each token maps to a [channels, dim] multivector. Initializes content in grade-1 (vector) subspace only - semantic content starts as directed quantities before rotors act on them.

Attributes:

Name Type Description
vocab_size int

Number of tokens.

channels int

Number of multivector channels.

embedding Embedding

Underlying embedding table.

Source code in layers/adapters/embedding.py
class MultivectorEmbedding(CliffordModule):
    """Token embedding as multivectors.

    Each token maps to a [channels, dim] multivector. Only grade-1 (vector)
    coefficients start non-zero; every other blade is initialized to zero, so
    semantic content starts as directed quantities before rotors act on it.

    Attributes:
        vocab_size (int): Number of tokens.
        channels (int): Number of multivector channels.
        embedding (nn.Embedding): Underlying embedding table of width
            ``channels * dim``.
    """

    def __init__(self, algebra: CliffordAlgebra, vocab_size: int, channels: int):
        """Sets up the multivector embedding.

        Args:
            algebra: Clifford algebra instance.
            vocab_size: Vocabulary size.
            channels: Number of multivector channels per token.
        """
        super().__init__(algebra)
        self.vocab_size = vocab_size
        self.channels = channels

        # Single flat embedding: vocab_size -> channels * dim
        self.embedding = nn.Embedding(vocab_size, channels * algebra.dim)
        self._init_grade1()

    def _init_grade1(self):
        """Zeros the table, then fills only grade-1 slots with normal(std=0.02) samples."""
        with torch.no_grad():
            dim = self.algebra.dim
            channels = self.channels

            # Grade-1 blades are the basis indices with exactly one bit set.
            grade1_flat = [i for i in range(dim) if bin(i).count('1') == 1]

            # Zero everything
            self.embedding.weight.zero_()

            # Fill column by column (channel-major, then blade) so a fixed seed
            # reproduces the same initial table as before.
            for ch in range(channels):
                for idx in grade1_flat:
                    self.embedding.weight[:, ch * dim + idx].normal_(std=0.02)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Maps token ids to multivector embeddings.

        Args:
            token_ids: Token indices of any shape ``[*batch]``; typically [B, L].

        Returns:
            Multivector embeddings ``[*batch, channels, dim]``, e.g.
            [B, L, channels, dim] for [B, L] input.
        """
        flat = self.embedding(token_ids)  # [*batch, channels * dim]
        # Split the flat row into (channels, dim) while keeping every leading axis.
        return flat.reshape(*token_ids.shape, self.channels, self.algebra.dim)

__init__(algebra, vocab_size, channels)

Sets up the multivector embedding.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

Clifford algebra instance.

required
vocab_size int

Vocabulary size.

required
channels int

Number of multivector channels per token.

required
Source code in layers/adapters/embedding.py
def __init__(self, algebra: CliffordAlgebra, vocab_size: int, channels: int):
    """Sets up the multivector embedding.

    Builds a single flat ``nn.Embedding`` whose rows hold ``channels``
    multivectors of ``algebra.dim`` components each, then runs the grade-1
    initialisation.

    Args:
        algebra: Clifford algebra instance.
        vocab_size: Vocabulary size.
        channels: Number of multivector channels per token.
    """
    super().__init__(algebra)
    self.vocab_size, self.channels = vocab_size, channels

    # All multivector channels of one token are laid out side by side in one row.
    row_width = channels * algebra.dim
    self.embedding = nn.Embedding(vocab_size, row_width)
    self._init_grade1()

forward(token_ids)

Maps token ids to multivector embeddings.

Parameters:

Name Type Description Default
token_ids Tensor

Token indices [B, L].

required

Returns:

Type Description
Tensor

Multivector embeddings [B, L, channels, dim].

Source code in layers/adapters/embedding.py
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
    """Maps token ids to multivector embeddings.

    Args:
        token_ids: Integer token indices of any shape, typically [B, L].

    Returns:
        Multivector embeddings of shape ``[*token_ids.shape, channels, dim]``,
        i.e. [B, L, channels, dim] for a [B, L] input.
    """
    flat = self.embedding(token_ids)  # [..., channels * dim]
    # Split only the trailing flat axis into (channels, dim) and keep every
    # leading axis. This lets unbatched [L] or extra-batched inputs work too.
    return flat.reshape(*token_ids.shape, self.channels, self.algebra.dim)

MotherEmbedding

Bases: CliffordModule

Embeds local feature groups into a canonical Mother Algebra with Procrustes Alignment.

Uses fixed rotors (R_fixed) to rotate individual channel vectors into a shared reference frame, effectively aligning disparate geometric manifolds.

Source code in layers/adapters/mother.py
class MotherEmbedding(CliffordModule):
    """Embeds local feature groups into a canonical Mother Algebra with Procrustes Alignment.

    Uses fixed rotors (R_fixed) to rotate individual channel vectors into a shared
    reference frame, effectively aligning disparate geometric manifolds.
    """
    def __init__(self, algebra: CliffordAlgebra, input_dim: int, channels: int, U: float = 0.0, V: torch.Tensor = None):
        """Initializes the Mother Embedding.

        Args:
            algebra: Clifford algebra instance.
            input_dim: Dimension of the input features.
            channels: Number of multivector channels.
            U: Geometric uncertainty index for manifold suppression. If the
                norm layer has a non-None ``weight``, it is filled with
                ``1 / (1 + U)``.
            V: Fixed rotor proxy for Procrustes alignment (input_dim x input_dim).
                Defaults to the identity. It is registered as the buffer
                ``R_fixed``, so it is saved in the state dict but not trained.
        """
        super().__init__(algebra)
        self.channels = channels

        # Procrustes Alignment Matrix (Fixed Rotor Proxy)
        if V is None:
            V = torch.eye(input_dim)
        self.register_buffer('R_fixed', V)

        # Up-cast to Mother Algebra multivector channels
        self.linear = nn.Linear(input_dim, channels * algebra.dim)
        self.norm = CliffordLayerNorm(algebra, channels)

        # Pre-condition the LayerNorm scale with the Uncertainty Index.
        # nn.Module exposes parameters registered as None through attribute
        # access, so hasattr() would be True for them. Check the value instead.
        weight = getattr(self.norm, 'weight', None)
        if weight is not None:
            with torch.no_grad():
                # Suppress highly uncertain (twisted) manifolds initially
                weight.fill_(1.0 / (1.0 + U))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects input into the aligned mother manifold.

        Args:
            x: Input features [B, input_dim].

        Returns:
            Aligned multivectors [B, channels, dim].
        """
        # 1. Apply Geometric Procrustes Alignment (skipped if R_fixed was reset to None)
        if self.R_fixed is not None:
            x = x @ self.R_fixed.T

        # 2. Mother Projection; view(-1, ...) folds all leading axes into one.
        c = self.linear(x).view(-1, self.channels, self.algebra.dim)
        return self.norm(c)

__init__(algebra, input_dim, channels, U=0.0, V=None)

Initializes the Mother Embedding.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

Clifford algebra instance.

required
input_dim int

Dimension of the input features.

required
channels int

Number of multivector channels.

required
U float

Geometric uncertainty index for manifold suppression.

0.0
V Tensor

Fixed rotor proxy for Procrustes alignment (input_dim x input_dim).

None
Source code in layers/adapters/mother.py
def __init__(self, algebra: CliffordAlgebra, input_dim: int, channels: int, U: float = 0.0, V: torch.Tensor = None):
    """Initializes the Mother Embedding.

    Args:
        algebra: Clifford algebra instance.
        input_dim: Dimension of the input features.
        channels: Number of multivector channels.
        U: Geometric uncertainty index for manifold suppression. If the norm
            layer has a non-None ``weight``, it is filled with ``1 / (1 + U)``.
        V: Fixed rotor proxy for Procrustes alignment (input_dim x input_dim).
            Defaults to the identity. It is registered as the buffer
            ``R_fixed``, so it is saved in the state dict but not trained.
    """
    super().__init__(algebra)
    self.channels = channels

    # Procrustes Alignment Matrix (Fixed Rotor Proxy)
    if V is None:
        V = torch.eye(input_dim)
    self.register_buffer('R_fixed', V)

    # Up-cast to Mother Algebra multivector channels
    self.linear = nn.Linear(input_dim, channels * algebra.dim)
    self.norm = CliffordLayerNorm(algebra, channels)

    # Pre-condition the LayerNorm scale with the Uncertainty Index.
    # nn.Module exposes parameters registered as None through attribute
    # access, so hasattr() would be True for them. Check the value instead.
    weight = getattr(self.norm, 'weight', None)
    if weight is not None:
        with torch.no_grad():
            # Suppress highly uncertain (twisted) manifolds initially
            weight.fill_(1.0 / (1.0 + U))

forward(x)

Projects input into the aligned mother manifold.

Parameters:

Name Type Description Default
x Tensor

Input features [B, input_dim].

required

Returns:

Type Description
Tensor

Aligned multivectors [B, channels, dim].

Source code in layers/adapters/mother.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Projects input into the aligned mother manifold.

    The method runs three steps:

    1. Rotates ``x`` by the fixed alignment matrix, unless ``R_fixed`` is None.
    2. Lifts the result to ``channels`` multivectors with the linear layer.
    3. Normalises the result.

    Args:
        x: Input features [B, input_dim].

    Returns:
        Aligned multivectors [B, channels, dim].
    """
    rotor_proxy = self.R_fixed
    aligned = x if rotor_proxy is None else x @ rotor_proxy.T

    lifted = self.linear(aligned)
    # Fold every leading axis into one and split the features into (channels, dim).
    mv_shape = (-1, self.channels, self.algebra.dim)
    return self.norm(lifted.view(*mv_shape))

EntropyGatedAttention

Bases: CliffordModule

Dynamic geometric attention governed by bivector information entropy.

Segments with high bivector entropy (disordered phase states) are "stiffened" or suppressed, allowing only coherent, synchronized states to propagate.

Source code in layers/adapters/mother.py
class EntropyGatedAttention(CliffordModule):
    """Dynamic geometric attention governed by bivector information entropy.

    Segments with high bivector entropy (disordered phase states) are "stiffened"
    or suppressed, allowing only coherent, synchronized states to propagate.

    NOTE(review): as implemented, the gate ``eta * sigmoid(H - H_base)`` grows
    with the entropy ``H``. High-entropy sequences therefore keep more of their
    bivector magnitude, and low-entropy ones are scaled down more. This seems
    opposite to the description above; confirm which direction is intended.
    """
    def __init__(self, algebra: CliffordAlgebra, channels: int, num_heads: int, eta: float = 1.0, H_base: float = 0.5):
        """Initializes Entropy-Gated Attention.

        Args:
            algebra: Clifford algebra instance.
            channels: Total multivector channels.
            num_heads: Number of attention heads.
            eta: Gating multiplier (upper bound of the gate value).
            H_base: Base entropy threshold; the gate is ``eta / 2`` when
                ``H == H_base``.
        """
        super().__init__(algebra)
        self.channels = channels
        self.eta = eta
        self.H_base = H_base
        self.base_attention = GeometricProductAttention(algebra, channels, num_heads, causal=False)

        # Cache bivector indices to avoid recomputing; registered as a buffer so
        # they move with the module across devices.
        mask = self.algebra.grade_masks[2]
        self.register_buffer('g2_idx', mask.nonzero(as_tuple=True)[0])

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None,
                return_gating: bool = False) -> torch.Tensor:
        """Applies entropy-gated geometric attention.

        The method runs four steps:

        1. Computes each position's bivector energy.
        2. Normalises the energies over the sequence and takes their Shannon
           entropy ``H`` per sample.
        3. Scales every bivector component by ``lambda_dyn = eta * sigmoid(H - H_base)``.
        4. Runs the base attention on the result.

        Args:
            x: Input multivectors [B, L, C, D].
            key_padding_mask: Optional [B, L] bool mask where True = padded.
                Padded positions contribute zero energy to the entropy, and the
                mask is forwarded to the base attention.
            return_gating: If True, returns entropy and gating values.

        Returns:
            Attended multivectors [B, L, C, D]. If ``return_gating`` is True,
            a tuple ``(out, H, lambda_dyn)`` is returned instead, where ``H`` is
            the per-sample entropy [B] and ``lambda_dyn`` the per-sample
            bivector scale [B].
        """
        # 1. Calculate Information Entropy of Bivector Energy
        # Summing across multivector components (g2_idx) and across channels (dim 2)
        # x: [B, L, C, D]
        g2_energy = (x[..., self.g2_idx]**2).sum(dim=(-1, -2)) # [B, L]

        # Mask padded positions before entropy calc
        if key_padding_mask is not None:
            g2_energy = g2_energy.masked_fill(key_padding_mask, 0.0)

        # Normalize to probability distribution over sequence
        p = g2_energy / (g2_energy.sum(dim=1, keepdim=True) + 1e-8)

        # Shannon Entropy H per batch [B]; the epsilon keeps log finite at p == 0
        H = -(p * torch.log(p + 1e-8)).sum(dim=1)

        # 2. Base-Adjusted Gating Function
        lambda_dyn = self.eta * torch.sigmoid(H - self.H_base) # [B]

        # 3. Apply dynamic geometric stiffness
        # Scale the rotational components (bivectors)
        lambda_view = lambda_dyn.view(-1, 1, 1, 1)

        # Clone so the caller's tensor is not modified in place.
        x_gated = x.clone()
        x_gated[..., self.g2_idx] *= lambda_view

        out = self.base_attention(x_gated, key_padding_mask=key_padding_mask)

        if return_gating:
            return out, H, lambda_dyn
        return out

__init__(algebra, channels, num_heads, eta=1.0, H_base=0.5)

Initializes Entropy-Gated Attention.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

Clifford algebra instance.

required
channels int

Total multivector channels.

required
num_heads int

Number of attention heads.

required
eta float

Gating multiplier.

1.0
H_base float

Base entropy threshold.

0.5
Source code in layers/adapters/mother.py
def __init__(self, algebra: CliffordAlgebra, channels: int, num_heads: int, eta: float = 1.0, H_base: float = 0.5):
    """Initializes Entropy-Gated Attention.

    Stores the gating hyper-parameters and builds a non-causal
    ``GeometricProductAttention`` as the underlying attention. It also caches
    the bivector component indices as the ``g2_idx`` buffer.

    Args:
        algebra: Clifford algebra instance.
        channels: Total multivector channels.
        num_heads: Number of attention heads.
        eta: Gating multiplier.
        H_base: Base entropy threshold.
    """
    super().__init__(algebra)
    self.channels = channels
    self.eta, self.H_base = eta, H_base
    self.base_attention = GeometricProductAttention(algebra, channels, num_heads, causal=False)

    # Buffers follow the module across devices, so the indices are computed once.
    bivector_positions = self.algebra.grade_masks[2].nonzero(as_tuple=True)[0]
    self.register_buffer('g2_idx', bivector_positions)

forward(x, key_padding_mask=None, return_gating=False)

Applies entropy-gated geometric attention.

Parameters:

Name Type Description Default
x Tensor

Input multivectors [B, L, C, D].

required
key_padding_mask Tensor

Optional [B, L] bool mask where True = padded.

None
return_gating bool

If True, returns entropy and gating values.

False

Returns:

Type Description
Tensor

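Attended multivectors [B, L, C, D]. If return_gating is True, a tuple (out, H, lambda_dyn) is returned instead, where H is the per-sample bivector entropy [B] and lambda_dyn is the per-sample gate [B].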

Source code in layers/adapters/mother.py
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None,
            return_gating: bool = False) -> torch.Tensor:
    """Applies entropy-gated geometric attention.

    The method runs four steps:

    1. Computes each position's bivector energy.
    2. Normalises the energies over the sequence and takes their Shannon
       entropy ``H`` per sample.
    3. Scales every bivector component by ``lambda_dyn = eta * sigmoid(H - H_base)``.
    4. Runs the base attention on the result.

    NOTE(review): the gate grows with ``H``, so high-entropy sequences keep
    more bivector magnitude than low-entropy ones. Confirm this is intended.

    Args:
        x: Input multivectors [B, L, C, D].
        key_padding_mask: Optional [B, L] bool mask where True = padded.
            Padded positions contribute zero energy to the entropy, and the
            mask is forwarded to the base attention.
        return_gating: If True, returns entropy and gating values.

    Returns:
        Attended multivectors [B, L, C, D]. If ``return_gating`` is True,
        a tuple ``(out, H, lambda_dyn)`` is returned instead, where ``H`` is
        the per-sample entropy [B] and ``lambda_dyn`` the per-sample
        bivector scale [B].
    """
    # 1. Calculate Information Entropy of Bivector Energy
    # Summing across multivector components (g2_idx) and across channels (dim 2)
    # x: [B, L, C, D]
    g2_energy = (x[..., self.g2_idx]**2).sum(dim=(-1, -2)) # [B, L]

    # Mask padded positions before entropy calc
    if key_padding_mask is not None:
        g2_energy = g2_energy.masked_fill(key_padding_mask, 0.0)

    # Normalize to probability distribution over sequence
    p = g2_energy / (g2_energy.sum(dim=1, keepdim=True) + 1e-8)

    # Shannon Entropy H per batch [B]; the epsilon keeps log finite at p == 0
    H = -(p * torch.log(p + 1e-8)).sum(dim=1)

    # 2. Base-Adjusted Gating Function
    lambda_dyn = self.eta * torch.sigmoid(H - self.H_base) # [B]

    # 3. Apply dynamic geometric stiffness
    # Scale the rotational components (bivectors)
    lambda_view = lambda_dyn.view(-1, 1, 1, 1)

    # Clone so the caller's tensor is not modified in place.
    x_gated = x.clone()
    x_gated[..., self.g2_idx] *= lambda_view

    out = self.base_attention(x_gated, key_padding_mask=key_padding_mask)

    if return_gating:
        return out, H, lambda_dyn
    return out

Optional dependency

CliffordGraphConv requires torch-geometric. Install it with `uv sync --extra md17`.

CliffordGraphConv

Bases: CliffordModule

Geometric graph convolution that performs message passing over multivector node features.

Aggregates features based on graph topology. H' = Aggregate(H) * W + Bias.

Attributes:

Name Type Description
linear CliffordLinear

The transformation.

Source code in layers/adapters/gnn.py
class CliffordGraphConv(CliffordModule):
    """Geometric Graph Conv. Performs message passing using multivector features.

    Aggregates features based on graph topology.
    H' = Aggregate(H) * W + Bias.

    Attributes:
        linear (CliffordLinear): The transformation.
    """

    def __init__(self, algebra: CliffordAlgebra, in_channels: int, out_channels: int):
        """Sets up the GNN layer.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            in_channels (int): Input features.
            out_channels (int): Output features.
        """
        super().__init__(algebra)
        self.linear = CliffordLinear(algebra, in_channels, out_channels)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Aggregates and transforms node features using geometric operations.

        Args:
            x (torch.Tensor): Node features [N, C, D]. The tensor may be
                non-contiguous.
            adj (torch.Tensor): Dense or sparse (COO) adjacency matrix [M, N],
                usually square. Output node ``i`` receives
                ``sum_j adj[i, j] * x[j]``.

        Returns:
            torch.Tensor: ``self.linear`` applied to the aggregated features
            [M, C, D].
        """
        # 1. Aggregate: flatten each node's (C, D) block so a single matmul
        # sums the neighbour multivectors. reshape (unlike view) also accepts
        # non-contiguous x.
        N, C, D = x.shape
        x_flat = x.reshape(N, C * D)

        # torch.mm accepts a sparse COO or a dense adjacency on the left.
        x_agg_flat = torch.mm(adj, x_flat)
        # The row count comes from the product, so rectangular adjacencies also work.
        x_agg = x_agg_flat.view(x_agg_flat.shape[0], C, D)

        # 2. Transform
        return self.linear(x_agg)

__init__(algebra, in_channels, out_channels)

Sets up the GNN layer.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
in_channels int

Input features.

required
out_channels int

Output features.

required
Source code in layers/adapters/gnn.py
def __init__(self, algebra: CliffordAlgebra, in_channels: int, out_channels: int):
    """Sets up the GNN layer.

    Registers the algebra via the base module and creates the
    ``CliffordLinear`` that maps ``in_channels`` to ``out_channels``.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        in_channels (int): Input features.
        out_channels (int): Output features.
    """
    super().__init__(algebra)
    # forward() applies this to the aggregated node features.
    transform = CliffordLinear(algebra, in_channels, out_channels)
    self.linear = transform

forward(x, adj)

Aggregates and transforms node features using geometric operations.

Parameters:

Name Type Description Default
x Tensor

Node features.

required
adj Tensor

Adjacency matrix.

required

Returns:

Type Description
Tensor

Updated features.

Source code in layers/adapters/gnn.py
def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
    """Aggregates and transforms node features using geometric operations.

    Args:
        x (torch.Tensor): Node features [N, C, D]. The tensor may be
            non-contiguous.
        adj (torch.Tensor): Dense or sparse (COO) adjacency matrix [M, N],
            usually square. Output node ``i`` receives
            ``sum_j adj[i, j] * x[j]``.

    Returns:
        torch.Tensor: ``self.linear`` applied to the aggregated features
        [M, C, D].
    """
    # 1. Aggregate: flatten each node's (C, D) block so a single matmul sums
    # the neighbour multivectors. reshape (unlike view) also accepts
    # non-contiguous x.
    N, C, D = x.shape
    x_flat = x.reshape(N, C * D)

    # torch.mm accepts a sparse COO or a dense adjacency on the left.
    x_agg_flat = torch.mm(adj, x_flat)
    # The row count comes from the product, so rectangular adjacencies also work.
    x_agg = x_agg_flat.view(x_agg_flat.shape[0], C, D)

    # 2. Transform
    return self.linear(x_agg)