mHC: Manifold-Constrained Hyper-Connections
Overview
mHC (Manifold-Constrained Hyper-Connections) stabilizes deep network training by constraining residual mixing matrices to be doubly stochastic. It provides:
- Stable Training: Lower gradient norm variance via doubly stochastic constraints
- Multiple Streams: Hyper-Connections with learnable mixing across residual streams
- Sinkhorn Projection: Log-space Sinkhorn-Knopp algorithm for doubly stochastic projection
- GPT Integration: Pattern for wrapping attention and MLP layers
Two components:
- HyperConnections Module: Core PyTorch module with H_res, H_pre, H_post matrices
- Sinkhorn-Knopp: Log-space projection to doubly stochastic manifold
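For reference, the doubly stochastic constraint mentioned above can be written out explicitly. The notation here is an assumption for illustration: H stands for the projected H_res matrix, n for num_streams, and A, B for the mixing matrices of two successive layers. The closure and norm facts are standard properties of doubly stochastic matrices, and they are one way to read the stability claim: composing many such mixing steps neither amplifies nor shrinks the total signal across streams.

```latex
H \in \mathbb{R}^{n \times n} \text{ is doubly stochastic} \iff
H_{ij} \ge 0, \quad \sum_{j} H_{ij} = 1 \ \ \forall i, \quad \sum_{i} H_{ij} = 1 \ \ \forall j.

% Closure under products, with \mathbf{1} the all-ones vector:
(AB)\mathbf{1} = A(B\mathbf{1}) = A\mathbf{1} = \mathbf{1}, \qquad
\mathbf{1}^{\top}(AB) = (\mathbf{1}^{\top}A)B = \mathbf{1}^{\top}B = \mathbf{1}^{\top}.

% Spectral norm bound (maximum column sum and maximum row sum are both 1):
\|H\|_2 \le \sqrt{\|H\|_1 \, \|H\|_\infty} = 1.
```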
Quick Reference
| Topic | Reference |
|-------|-----------|
| Core Concepts & Math | Core Concepts |
| Sinkhorn Algorithm | Sinkhorn-Knopp |
| HyperConnections Module | Module Implementation |
| GPT Integration | GPT Integration |
| Common Pitfalls | Pitfalls |
Installation
# Required packages
pip install torch einops numpy
Minimal Example
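import torch
import torch.nn as nn
from einops import rearrange, einsum


def sinkhorn_knopp(logits, num_iters=20, tau=0.05):
    """Project logits onto (approximately) doubly stochastic matrices in log space."""
    log_alpha = logits / tau  # smaller tau sharpens the result
    for _ in range(num_iters):
        # Alternate row and column normalization; logsumexp keeps it stable.
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)
    return torch.exp(log_alpha)


class HyperConnections(nn.Module):
    def __init__(self, num_streams, dim, branch=None, layer_idx=0):
        # dim is accepted but not used by this minimal version.
        super().__init__()
        self.num_streams = num_streams
        self.branch = branch
        # Initialize H_res biased toward identity. Off-diagonal logits are
        # slightly negative (not -inf) so they still receive gradient.
        init_h_res = torch.full((num_streams, num_streams), -0.1)
        init_h_res.fill_diagonal_(0.0)
        self.H_res_logits = nn.Parameter(init_h_res)
        # H_pre/H_post for depth connections. H_pre favors one stream per
        # layer; H_post starts uniform.
        init_h_pre = torch.full((1, num_streams), -0.1)
        init_h_pre[0, layer_idx % num_streams] = 0.0
        self.H_pre_logits = nn.Parameter(init_h_pre)
        self.H_post_logits = nn.Parameter(torch.zeros(1, num_streams))

    def forward(self, x):
        # x has shape (batch * num_streams, tokens, dim).
        s = self.num_streams
        x = rearrange(x, "(b s) t d -> b t s d", s=s)
        # Width connection: mix residual streams with a doubly stochastic matrix.
        h_res = sinkhorn_knopp(self.H_res_logits)
        x_mixed = einsum(h_res, x, "s t, b n s d -> b n t d")
        # Depth connection, input side: weighted sum of streams feeds the branch.
        h_pre = self.H_pre_logits.softmax(dim=-1)
        branch_in = einsum(h_pre, x, "v s, b n s d -> b n v d").squeeze(-2)
        # Compare against None: container modules such as an empty
        # nn.Sequential define __len__ and can evaluate as falsy.
        branch_out = self.branch(branch_in) if self.branch is not None else branch_in
        # Depth connection, output side: distribute the branch output to streams.
        h_post = self.H_post_logits.softmax(dim=-1)
        depth_out = einsum(branch_out, h_post, "b t d, v s -> b t s d")
        output = x_mixed + depth_out
        return rearrange(output, "b t s d -> (b s) t d")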
Common Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum, repeat, reduce
When to Use What
| Scenario | Approach |
|----------|----------|
| Standard residual connection | No mHC needed |
| Deep networks (>12 layers) with stability issues | Use mHC with num_streams=4 |
| GPT/Transformer training | Wrap both attention and MLP with HyperConnections |
| Custom Sinkhorn iterations | Adjust num_iters (20 default) and tau (0.05 default) |
| Memory-constrained training | Reduce num_streams or batch size |
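When adjusting `num_iters` and `tau`, it can help to measure what the projection actually produces. The sketch below reuses the `sinkhorn_knopp` function from the Minimal Example. The helper `marginal_error` and the random test logits are assumptions added for illustration, not part of mHC.

```python
import torch


def sinkhorn_knopp(logits, num_iters=20, tau=0.05):
    # Same log-space projection as in the Minimal Example.
    log_alpha = logits / tau
    for _ in range(num_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)
    return torch.exp(log_alpha)


def marginal_error(h):
    """Hypothetical helper: largest deviation of any row or column sum from 1."""
    row_err = (h.sum(dim=-1) - 1).abs().max().item()
    col_err = (h.sum(dim=-2) - 1).abs().max().item()
    return row_err, col_err


# H_res initialization from the Minimal Example, with 4 streams.
init_logits = torch.full((4, 4), -0.1)
init_logits.fill_diagonal_(0.0)

# Effect of tau: a larger tau flattens the matrix toward uniform mixing.
diag_by_tau = {}
for tau in (0.05, 0.5):
    h = sinkhorn_knopp(init_logits, tau=tau)
    diag_by_tau[tau] = h[0, 0].item()
    print(f"tau={tau}: diagonal weight={diag_by_tau[tau]:.3f}")

# Effect of num_iters on arbitrary (here random) logits.
torch.manual_seed(0)
rand_logits = torch.randn(4, 4)
errors_by_iters = {}
for num_iters in (1, 5, 20):
    h = sinkhorn_knopp(rand_logits, num_iters=num_iters)
    errors_by_iters[num_iters] = marginal_error(h)
    row_err, col_err = errors_by_iters[num_iters]
    print(f"num_iters={num_iters}: row_err={row_err:.2e} col_err={col_err:.2e}")
```

With the initialization above and 4 streams, the diagonal weight comes out at about 0.71 for tau=0.05 and about 0.29 for tau=0.5, so tau controls how strongly the initial mixing favors the identity. Because the column step runs last in each iteration, column sums are exact up to floating point, and the row error is the quantity to watch when lowering `num_iters`.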
External Resources
- mHC Paper: https://arxiv.org/abs/2512.24880
- Hyper-Connections: https://arxiv.org/abs/2409.19606
- Sinkhorn's Theorem: https://en.wikipedia.org/wiki/Sinkhorn%27s_theorem