Skip to content

Functional

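Activation functions and losses. The losses have no learnable parameters; the activations carry small learnable parameters (a per-channel bias in GeometricGELU, per-grade gate weights and biases in GradeSwish).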

Activations

GeometricGELU

Bases: Module

Geometric GELU activation: x' = x * GELU(||x|| + b) / ||x||.

Scales magnitude while preserving direction.

Attributes:

Name Type Description
algebra CliffordAlgebra

The algebra instance.

bias Parameter

Learnable bias added to norm.

Source code in functional/activation.py
class GeometricGELU(nn.Module):
    """Geometric GELU activation: x' = x * GELU(||x|| + b) / ||x||.

    Scales magnitude while preserving direction.

    Attributes:
        algebra (CliffordAlgebra): The algebra instance.
        bias (torch.nn.Parameter): Learnable bias added to norm, one entry
            per channel.
    """

    def __init__(self, algebra, channels: int = 1):
        """Initialize Geometric GELU.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Number of channels.
        """
        super().__init__()
        self.algebra = algebra
        self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply geometric GELU activation.

        Args:
            x (torch.Tensor): Input multivector [..., Dim]. When ``x`` has at
                least two dimensions, dim -2 is treated as the channel axis
                and must broadcast against ``channels``.

        Returns:
            torch.Tensor: Activated multivector with the same shape as ``x``.
        """
        norm = x.norm(dim=-1, keepdim=True)  # [..., 1]

        # Align the per-channel bias with the channel axis (dim -2) of
        # ``norm``. A fixed 3-D view such as (1, C, 1) would prepend a
        # spurious leading axis to inputs with fewer than three dimensions.
        bias = self.bias.view(-1, 1) if x.dim() >= 2 else self.bias

        eps = 1e-6  # keeps the division finite for zero-norm inputs
        scale = F.gelu(norm + bias) / (norm + eps)

        return x * scale

__init__(algebra, channels=1)

Initialize Geometric GELU.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
channels int

Number of channels.

1
Source code in functional/activation.py
def __init__(self, algebra, channels: int = 1):
    """Set up the activation with one zero-initialised bias per channel.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Number of channels (length of the bias vector).
    """
    super().__init__()
    self.algebra = algebra
    zero_bias = torch.zeros(channels)
    self.bias = nn.Parameter(zero_bias)

forward(x)

Apply geometric GELU activation.

Parameters:

Name Type Description Default
x Tensor

Input multivector [..., Dim].

required

Returns:

Type Description
Tensor

torch.Tensor: Activated multivector.

Source code in functional/activation.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply geometric GELU activation.

    Args:
        x (torch.Tensor): Input multivector [..., Dim]. When ``x`` has at
            least two dimensions, dim -2 is treated as the channel axis and
            must broadcast against the bias length.

    Returns:
        torch.Tensor: Activated multivector with the same shape as ``x``.
    """
    norm = x.norm(dim=-1, keepdim=True)  # [..., 1]

    # Align the per-channel bias with the channel axis (dim -2) of ``norm``.
    # A fixed 3-D view such as (1, C, 1) would prepend a spurious leading
    # axis to inputs with fewer than three dimensions.
    bias = self.bias.view(-1, 1) if x.dim() >= 2 else self.bias

    eps = 1e-6  # keeps the division finite for zero-norm inputs
    scale = F.gelu(norm + bias) / (norm + eps)

    return x * scale

GradeSwish

Bases: Module

Per-grade gated activation.

Each grade receives an independent sigmoid gate based on its norm.

Attributes:

Name Type Description
algebra CliffordAlgebra

The algebra instance.

n_grades int

Number of grades.

grade_weights Parameter

Weights for each grade gate.

grade_biases Parameter

Biases for each grade gate.

Source code in functional/activation.py
class GradeSwish(nn.Module):
    """Per-grade gated activation.

    Each grade receives an independent sigmoid gate based on its norm: every
    component of grade ``k`` is multiplied by ``sigmoid(w_k * ||x_k|| + b_k)``,
    where ``x_k`` collects the components of grade ``k``.

    Attributes:
        algebra (CliffordAlgebra): The algebra instance.
        n_grades (int): Number of grades.
        grade_weights (torch.nn.Parameter): Weights for each grade gate.
        grade_biases (torch.nn.Parameter): Biases for each grade gate.
    """

    def __init__(self, algebra, channels: int = 1):
        """Initialize Grade Swish.

        Args:
            algebra (CliffordAlgebra): The algebra instance.
            channels (int): Number of channels. Currently unused; the gate
                parameters are shared across channels.
        """
        super().__init__()
        self.algebra = algebra
        self.n_grades = algebra.n + 1

        self.grade_weights = nn.Parameter(torch.ones(self.n_grades))
        self.grade_biases = nn.Parameter(torch.zeros(self.n_grades))

        self.register_buffer('grade_masks', self._build_masks())
        # Component -> grade lookup used by forward(). Registered here rather
        # than lazily inside forward() so the set of buffers (and state_dict
        # keys) is fixed at construction and follows Module.to(). It is
        # non-persistent because it is derived entirely from the algebra.
        self.register_buffer('_grade_map', self._build_grade_map(),
                             persistent=False)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Earlier versions registered ``_grade_map`` as a persistent buffer on
        # the first forward() call, so older checkpoints may contain it. The
        # map is rebuilt from the algebra; discard the stored copy instead of
        # reporting it as an unexpected key under strict loading.
        state_dict.pop(prefix + '_grade_map', None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _build_masks(self) -> torch.Tensor:
        """Precompute grade masks.

        Returns:
            torch.Tensor: Boolean masks for each grade [n_grades, dim].
        """
        masks = torch.zeros(self.n_grades, self.algebra.dim, dtype=torch.bool)
        for i in range(self.algebra.dim):
            # The grade of basis blade i is the popcount of its index.
            grade = bin(i).count('1')
            masks[grade, i] = True
        return masks

    def _build_grade_map(self) -> torch.Tensor:
        """Precompute per-component grade index for vectorized forward.

        Returns:
            torch.Tensor: Long tensor of grade indices [dim].
        """
        grade_map = torch.zeros(self.algebra.dim, dtype=torch.long)
        for i in range(self.algebra.dim):
            grade_map[i] = bin(i).count('1')
        return grade_map

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply per-grade gating.

        Args:
            x (torch.Tensor): Input multivector [..., Dim].

        Returns:
            torch.Tensor: Activated multivector, same shape as ``x``.
        """
        grade_map = getattr(self, '_grade_map', None)
        if grade_map is None:
            # Modules pickled whole by older versions before their first
            # forward() carry no ``_grade_map`` buffer; build one locally.
            grade_map = self._build_grade_map()
        # No-op when already on x's device; module state is never mutated.
        grade_map = grade_map.to(x.device)

        D = self.algebra.dim
        G = self.n_grades

        # Per-grade norms: square, scatter-add each component into its grade
        # bucket, then sqrt. The clamp keeps sqrt (and its gradient) finite
        # when a grade is entirely zero.
        x_sq = x * x  # [..., D]
        batch_shape = x.shape[:-1]
        grade_idx = grade_map.expand(*batch_shape, D)  # [..., D]

        norm_sq = torch.zeros(*batch_shape, G, device=x.device, dtype=x.dtype)
        norm_sq.scatter_add_(-1, grade_idx, x_sq)  # [..., G]
        norms = torch.sqrt(norm_sq.clamp(min=1e-12))  # [..., G]

        # Gates: sigmoid(w_k * norm_k + b_k) for each grade k.
        gates = torch.sigmoid(
            self.grade_weights * norms + self.grade_biases
        )  # [..., G]

        # Broadcast each grade's gate back onto that grade's components.
        per_component_gate = gates.gather(-1, grade_idx)  # [..., D]

        return x * per_component_gate

__init__(algebra, channels=1)

Initialize Grade Swish.

Parameters:

Name Type Description Default
algebra CliffordAlgebra

The algebra instance.

required
channels int

Number of channels. Currently unused by GradeSwish; the gate parameters are shared across channels.

1
Source code in functional/activation.py
def __init__(self, algebra, channels: int = 1):
    """Initialize Grade Swish.

    Args:
        algebra (CliffordAlgebra): The algebra instance.
        channels (int): Number of channels. Currently unused; the gate
            parameters are shared across channels.
    """
    super().__init__()
    self.algebra = algebra
    self.n_grades = algebra.n + 1

    self.grade_weights = nn.Parameter(torch.ones(self.n_grades))
    self.grade_biases = nn.Parameter(torch.zeros(self.n_grades))

    self.register_buffer('grade_masks', self._build_masks())
    # Component -> grade lookup for forward(). Registering it here rather
    # than lazily inside forward() keeps the set of buffers (and state_dict
    # keys) fixed from construction; non-persistent because it is derived
    # entirely from the algebra.
    self.register_buffer('_grade_map', self._build_grade_map(),
                         persistent=False)

forward(x)

Apply per-grade gating.

Parameters:

Name Type Description Default
x Tensor

Input multivector [..., Dim].

required

Returns:

Type Description
Tensor

torch.Tensor: Activated multivector.

Source code in functional/activation.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply per-grade gating.

    Args:
        x (torch.Tensor): Input multivector [..., Dim].

    Returns:
        torch.Tensor: Activated multivector, same shape as ``x``.
    """
    # Use the precomputed component->grade map when present, otherwise build
    # a local one. Never register buffers or mutate module state here: doing
    # so would change the state_dict layout after the first call.
    grade_map = getattr(self, '_grade_map', None)
    if grade_map is None:
        grade_map = self._build_grade_map()
    grade_map = grade_map.to(x.device)  # no-op when already on x's device

    D = self.algebra.dim
    G = self.n_grades

    # Per-grade norms: square, scatter-add each component into its grade
    # bucket, then sqrt. The clamp keeps sqrt (and its gradient) finite when
    # a grade is entirely zero.
    x_sq = x * x  # [..., D]
    batch_shape = x.shape[:-1]
    grade_idx = grade_map.expand(*batch_shape, D)  # [..., D]

    norm_sq = torch.zeros(*batch_shape, G, device=x.device, dtype=x.dtype)
    norm_sq.scatter_add_(-1, grade_idx, x_sq)  # [..., G]
    norms = torch.sqrt(norm_sq.clamp(min=1e-12))  # [..., G]

    # Gates: sigmoid(w_k * norm_k + b_k) for each grade k.
    gates = torch.sigmoid(
        self.grade_weights * norms + self.grade_biases
    )  # [..., G]

    # Broadcast each grade's gate back onto that grade's components.
    per_component_gate = gates.gather(-1, grade_idx)  # [..., D]

    return x * per_component_gate

Losses

ChamferDistance

Bases: Module

Symmetric Chamfer distance between two point clouds.

CD(P, Q) = (1/|P|) sum_p min_q ||p-q||^2 + (1/|Q|) sum_q min_p ||q-p||^2

Standard metric for 3D point cloud reconstruction and generation.

Source code in functional/loss.py
class ChamferDistance(nn.Module):
    """Symmetric Chamfer distance between two point clouds.

    CD(P, Q) = (1/|P|) sum_p min_q ||p-q||^2 + (1/|Q|) sum_q min_p ||q-p||^2

    Standard metric for 3D point cloud reconstruction and generation.
    """

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the batch-averaged Chamfer distance.

        Args:
            pred: Predicted point cloud [B, M, 3].
            target: Target point cloud [B, N, 3].

        Returns:
            Chamfer distance (scalar), averaged over the batch.
        """
        # Squared distance between every predicted and every target point.
        pairwise_sq = (pred.unsqueeze(2) - target.unsqueeze(1)).pow(2).sum(dim=-1)

        # Nearest target for each predicted point, averaged over predictions.
        pred_to_target = pairwise_sq.min(dim=2).values.mean(dim=1)
        # Nearest prediction for each target point, averaged over targets.
        target_to_pred = pairwise_sq.min(dim=1).values.mean(dim=1)

        per_sample = pred_to_target + target_to_pred
        return per_sample.mean()

forward(pred, target)

Compute Chamfer distance.

Parameters:

Name Type Description Default
pred Tensor

Predicted point cloud [B, M, 3].

required
target Tensor

Target point cloud [B, N, 3].

required

Returns:

Type Description
Tensor

Chamfer distance (scalar).

Source code in functional/loss.py
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the batch-averaged Chamfer distance.

    Args:
        pred: Predicted point cloud [B, M, 3].
        target: Target point cloud [B, N, 3].

    Returns:
        Chamfer distance (scalar), averaged over the batch.
    """
    # Squared distance between every predicted and every target point.
    pairwise_sq = (pred.unsqueeze(2) - target.unsqueeze(1)).pow(2).sum(dim=-1)

    # Nearest target for each predicted point, averaged over predictions.
    pred_to_target = pairwise_sq.min(dim=2).values.mean(dim=1)
    # Nearest prediction for each target point, averaged over targets.
    target_to_pred = pairwise_sq.min(dim=1).values.mean(dim=1)

    per_sample = pred_to_target + target_to_pred
    return per_sample.mean()

ConservativeLoss

Bases: Module

Enforces F = -grad(E) conservative force constraint.

Physics: forces should be the negative gradient of energy with respect to atomic positions. Used in molecular dynamics tasks.

Source code in functional/loss.py
class ConservativeLoss(nn.Module):
    """Enforces F = -grad(E) conservative force constraint.

    Physics: forces should be the negative gradient of energy
    with respect to atomic positions. Used in molecular dynamics tasks.
    """

    def forward(self, energy: torch.Tensor, force_pred: torch.Tensor,
                pos: torch.Tensor) -> torch.Tensor:
        """Penalise disagreement between predicted forces and -dE/dpos.

        Args:
            energy: Predicted energy (scalar, requires grad graph).
            force_pred: Predicted forces [N, 3].
            pos: Atom positions [N, 3] (must have requires_grad=True).

        Returns:
            MSE between predicted forces and -grad(E).
        """
        # create_graph=True makes the derivative itself differentiable, so
        # this loss can be backpropagated into whatever produced ``energy``.
        (energy_grad,) = torch.autograd.grad(
            energy.sum(), pos, create_graph=True, retain_graph=True
        )
        force_from_energy = -energy_grad
        return F.mse_loss(force_pred, force_from_energy)

forward(energy, force_pred, pos)

Compute conservative force loss.

Parameters:

Name Type Description Default
energy Tensor

Predicted energy (scalar, requires grad graph).

required
force_pred Tensor

Predicted forces [N, 3].

required
pos Tensor

Atom positions [N, 3] (must have requires_grad=True).

required

Returns:

Type Description
Tensor

MSE between predicted forces and -grad(E).

Source code in functional/loss.py
def forward(self, energy: torch.Tensor, force_pred: torch.Tensor,
            pos: torch.Tensor) -> torch.Tensor:
    """Penalise disagreement between predicted forces and -dE/dpos.

    Args:
        energy: Predicted energy (scalar, requires grad graph).
        force_pred: Predicted forces [N, 3].
        pos: Atom positions [N, 3] (must have requires_grad=True).

    Returns:
        MSE between predicted forces and -grad(E).
    """
    # create_graph=True makes the derivative itself differentiable, so this
    # loss can be backpropagated into whatever produced ``energy``.
    (energy_grad,) = torch.autograd.grad(
        energy.sum(), pos, create_graph=True, retain_graph=True
    )
    force_from_energy = -energy_grad
    return F.mse_loss(force_pred, force_from_energy)

PhysicsInformedLoss

Bases: Module

Physics-informed loss combining MSE with conservation penalty.

Enforces that global weighted mean of each variable is approximately conserved between forecast and target. Used in weather forecasting.

Source code in functional/loss.py
class PhysicsInformedLoss(nn.Module):
    """Physics-informed loss combining MSE with conservation penalty.

    Enforces that global weighted mean of each variable is approximately
    conserved between forecast and target. Used in weather forecasting.
    """

    def __init__(self, physics_weight: float = 0.1):
        """Initialize the loss.

        Args:
            physics_weight: Multiplier applied to the conservation penalty.
        """
        super().__init__()
        self.physics_weight = physics_weight

    def forward(self, forecast: torch.Tensor, target: torch.Tensor,
                lat_weights: torch.Tensor = None) -> torch.Tensor:
        """Compute physics-informed loss.

        Args:
            forecast: Predicted state [B, H, W, C].
            target: Target state [B, H, W, C].
            lat_weights: Latitude area weights [H]. Only used when
                ``forecast`` is 4-D; otherwise an unweighted mean over every
                dimension between the first and the last is used.

        Returns:
            Combined MSE + conservation penalty.
        """
        mse_loss = F.mse_loss(forecast, target)

        if lat_weights is not None and forecast.dim() == 4:
            # Match the forecast's device and dtype so the weights do not
            # silently upcast the loss (e.g. float64 weights, float32 data).
            w = lat_weights.view(1, -1, 1, 1).to(device=forecast.device,
                                                 dtype=forecast.dtype)
            # The numerator sums over both H and W, so the normaliser is the
            # total weight over the full H x W grid. Dividing by w.sum()
            # alone would inflate the "mean" by a factor of W.
            total_weight = w.sum() * forecast.shape[2]
            forecast_mean = (forecast * w).sum(dim=[1, 2]) / total_weight
            target_mean = (target * w).sum(dim=[1, 2]) / total_weight
        else:
            forecast_mean = forecast.mean(dim=list(range(1, forecast.dim() - 1)))
            target_mean = target.mean(dim=list(range(1, target.dim() - 1)))

        conservation_loss = F.mse_loss(forecast_mean, target_mean)
        return mse_loss + self.physics_weight * conservation_loss

forward(forecast, target, lat_weights=None)

Compute physics-informed loss.

Parameters:

Name Type Description Default
forecast Tensor

Predicted state [B, H, W, C].

required
target Tensor

Target state [B, H, W, C].

required
lat_weights Tensor

Latitude area weights [H].

None

Returns:

Type Description
Tensor

Combined MSE + conservation penalty.

Source code in functional/loss.py
def forward(self, forecast: torch.Tensor, target: torch.Tensor,
            lat_weights: torch.Tensor = None) -> torch.Tensor:
    """Compute physics-informed loss.

    Args:
        forecast: Predicted state [B, H, W, C].
        target: Target state [B, H, W, C].
        lat_weights: Latitude area weights [H]. Only used when ``forecast``
            is 4-D; otherwise an unweighted mean over every dimension
            between the first and the last is used.

    Returns:
        Combined MSE + conservation penalty.
    """
    mse_loss = F.mse_loss(forecast, target)

    if lat_weights is not None and forecast.dim() == 4:
        # Match the forecast's device and dtype so the weights do not
        # silently upcast the loss (e.g. float64 weights, float32 data).
        w = lat_weights.view(1, -1, 1, 1).to(device=forecast.device,
                                             dtype=forecast.dtype)
        # The numerator sums over both H and W, so the normaliser is the
        # total weight over the full H x W grid. Dividing by w.sum() alone
        # would inflate the "mean" by a factor of W.
        total_weight = w.sum() * forecast.shape[2]
        forecast_mean = (forecast * w).sum(dim=[1, 2]) / total_weight
        target_mean = (target * w).sum(dim=[1, 2]) / total_weight
    else:
        forecast_mean = forecast.mean(dim=list(range(1, forecast.dim() - 1)))
        target_mean = target.mean(dim=list(range(1, target.dim() - 1)))

    conservation_loss = F.mse_loss(forecast_mean, target_mean)
    return mse_loss + self.physics_weight * conservation_loss

GeometricMSELoss

Bases: Module

Geometric MSE. Euclidean distance in embedding space.

Standard MSE on coefficients.

Source code in functional/loss.py
class GeometricMSELoss(nn.Module):
    """Geometric MSE. Euclidean distance in embedding space.

    Plain mean-squared error over multivector coefficients; the stored
    ``algebra`` is not used by the computation.
    """

    def __init__(self, algebra=None):
        """Store the optional algebra handle; the loss has no parameters."""
        super().__init__()
        self.algebra = algebra

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean of squared coefficient differences."""
        return F.mse_loss(input=pred, target=target, reduction='mean')

__init__(algebra=None)

Initialize the geometric MSE loss.

Source code in functional/loss.py
def __init__(self, algebra=None):
    """Store the optional algebra handle; the loss has no parameters."""
    super().__init__()
    self.algebra = algebra

forward(pred, target)

MSE.

Source code in functional/loss.py
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean of squared coefficient differences."""
    return F.mse_loss(input=pred, target=target, reduction='mean')

SubspaceLoss

Bases: Module

Subspace Loss. Enforces grade constraints.

Penalizes energy in forbidden grades.

Source code in functional/loss.py
class SubspaceLoss(nn.Module):
    """Subspace Loss. Enforces grade constraints.

    Penalizes energy in forbidden grades. The penalised coefficients are given
    either negatively (``target_indices``: the only coefficients allowed to be
    non-zero) or positively (``exclude_indices``: the coefficients to
    penalise). ``target_indices`` takes precedence when both are passed.
    """

    def __init__(self, algebra, target_indices: list = None, exclude_indices: list = None):
        """Build the boolean mask of penalised coefficients.

        Raises:
            ValueError: If neither ``target_indices`` nor ``exclude_indices``
                is given.
        """
        super().__init__()
        self.algebra = algebra

        if target_indices is None and exclude_indices is None:
            raise ValueError("Must provide target_indices or exclude_indices")

        if target_indices is not None:
            # Penalise everything except the allowed coefficients.
            forbidden = torch.ones(algebra.dim, dtype=torch.bool)
            forbidden[target_indices] = False
        else:
            # Penalise only the listed coefficients.
            forbidden = torch.zeros(algebra.dim, dtype=torch.bool)
            forbidden[exclude_indices] = True

        self.register_buffer('penalty_mask', forbidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Squared energy in penalised coefficients, averaged over leading dims."""
        mask = self.penalty_mask.to(x.device)
        forbidden_part = x[..., mask]
        return forbidden_part.pow(2).sum(dim=-1).mean()

__init__(algebra, target_indices=None, exclude_indices=None)

Initialize grade constraint penalties.

Source code in functional/loss.py
def __init__(self, algebra, target_indices: list = None, exclude_indices: list = None):
    """Build the boolean mask of penalised coefficients.

    ``target_indices`` lists the only coefficients allowed to be non-zero and
    takes precedence; otherwise ``exclude_indices`` lists those to penalise.

    Raises:
        ValueError: If neither index list is given.
    """
    super().__init__()
    self.algebra = algebra

    if target_indices is None and exclude_indices is None:
        raise ValueError("Must provide target_indices or exclude_indices")

    if target_indices is not None:
        # Penalise everything except the allowed coefficients.
        forbidden = torch.ones(algebra.dim, dtype=torch.bool)
        forbidden[target_indices] = False
    else:
        # Penalise only the listed coefficients.
        forbidden = torch.zeros(algebra.dim, dtype=torch.bool)
        forbidden[exclude_indices] = True

    self.register_buffer('penalty_mask', forbidden)

forward(x)

Penalizes deviations.

Source code in functional/loss.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Squared energy in penalised coefficients, averaged over leading dims."""
    mask = self.penalty_mask.to(x.device)
    forbidden_part = x[..., mask]
    return forbidden_part.pow(2).sum(dim=-1).mean()

IsometryLoss

Bases: Module

Isometry loss enforcing metric norm preservation.

Ensures transformations preserve the metric norm.

Source code in functional/loss.py
class IsometryLoss(nn.Module):
    """Isometry loss enforcing metric norm preservation.

    Ensures transformations preserve the metric norm. Each squared
    coefficient is weighted by the corresponding entry of ``metric_diag``
    before summing over the last dimension.
    """

    def __init__(self, algebra):
        """Initialize isometry loss with metric diagonal."""
        super().__init__()
        self.algebra = algebra
        # A buffer rather than a plain attribute, so Module.to()/cuda() move
        # it together with the module. Non-persistent so the state_dict
        # layout is unchanged.
        self.register_buffer('metric_diag', self._compute_metric_diagonal(),
                             persistent=False)

    def _compute_metric_diagonal(self):
        """Return the scalar coefficient of each basis element squared, [dim].

        Presumably ``geometric_product`` multiplies row-wise, so row ``i`` of
        the result is ``e_i * e_i``, and column 0 is the scalar component —
        TODO confirm against the algebra's basis ordering.
        """
        basis = torch.eye(self.algebra.dim, device=self.algebra.device)
        sq = self.algebra.geometric_product(basis, basis)
        diag = sq[:, 0]
        return diag

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compare metric-weighted squared norms of ``pred`` and ``target``.

        Returns:
            MSE between the per-sample weighted norms.
        """
        # No-op once the buffer lives on pred's device (e.g. after Module.to()).
        metric_diag = self.metric_diag.to(pred.device)
        pred_sq = (pred ** 2) * metric_diag
        target_sq = (target ** 2) * metric_diag

        pred_norm = pred_sq.sum(dim=-1)
        target_norm = target_sq.sum(dim=-1)

        return F.mse_loss(pred_norm, target_norm)

__init__(algebra)

Initialize isometry loss with metric diagonal.

Source code in functional/loss.py
def __init__(self, algebra):
    """Initialize isometry loss with metric diagonal."""
    super().__init__()
    self.algebra = algebra
    # A buffer rather than a plain attribute, so Module.to()/cuda() move it
    # together with the module. Non-persistent so the state_dict layout is
    # unchanged.
    self.register_buffer('metric_diag', self._compute_metric_diagonal(),
                         persistent=False)

forward(pred, target)

Compares norms.

Source code in functional/loss.py
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compare metric-weighted squared norms of ``pred`` and ``target``.

    Returns:
        MSE between the per-sample weighted norms.
    """
    weights = self.metric_diag.to(pred.device)
    pred_norm = torch.sum(pred ** 2 * weights, dim=-1)
    target_norm = torch.sum(target ** 2 * weights, dim=-1)
    return F.mse_loss(pred_norm, target_norm)

BivectorRegularization

Bases: Module

Bivector regularization enforcing grade-2 purity.

Penalizes energy outside the target grade (default: grade 2).

Source code in functional/loss.py
class BivectorRegularization(nn.Module):
    """Bivector regularization enforcing grade-2 purity.

    Penalizes energy outside the target grade (default: grade 2).
    """

    def __init__(self, algebra, grade=2):
        """Remember the algebra and which grade is left unpenalised.

        Args:
            algebra: Algebra providing ``grade_projection``.
            grade: Grade whose energy is not penalised (default 2).
        """
        super().__init__()
        self.algebra, self.grade = algebra, grade

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Squared energy outside ``self.grade``, averaged over leading dims."""
        off_grade = x - self.algebra.grade_projection(x, self.grade)
        energy = off_grade.pow(2).sum(dim=-1)
        return energy.mean()

__init__(algebra, grade=2)

Initialize bivector regularization.

Source code in functional/loss.py
def __init__(self, algebra, grade=2):
    """Remember the algebra and which grade is left unpenalised.

    Args:
        algebra: Algebra providing ``grade_projection``.
        grade: Grade whose energy is not penalised (default 2).
    """
    super().__init__()
    self.algebra, self.grade = algebra, grade

forward(x)

Penalizes non-bivector parts.

Source code in functional/loss.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Squared energy outside ``self.grade``, averaged over leading dims."""
    off_grade = x - self.algebra.grade_projection(x, self.grade)
    energy = off_grade.pow(2).sum(dim=-1)
    return energy.mean()