The Problem
Standard self-attention computes pairwise interactions between every token and every other token. That is O(n²) in sequence length. At 1 million tokens, full attention is simply infeasible. The usual responses — sparse attention, linear attention, sliding windows, state-space models — all approximate or restrict which tokens can interact.
I want to ask a different question. Instead of "which token pairs should interact?", ask: "what is the geometric summary of the path that tokens trace through embedding space?" That summary — the log-signature — is a fixed-size object, independent of sequence length. Attention over summaries instead of raw tokens breaks the O(n²) barrier.
What Is a Log-Signature?
Treat the sequence of token embeddings X₀, X₁, …, Xₙ ∈ ℝᵈ as a piecewise-linear path in ℝᵈ. The signature of this path is an infinite series of iterated integrals that encodes its shape completely, including the order in which dimensions are traversed. The log-signature is its logarithm in the tensor algebra: it lives in the free Lie algebra over ℝᵈ and is strictly more compact.
A Concrete Example (d = 2, k = 2)
Four token embeddings in ℝ²: X₀ = (0,0), X₁ = (1,2), X₂ = (3,1), X₃ = (4,4). The increments are Δ₁ = (1,2), Δ₂ = (2,−1), Δ₃ = (1,3), and the depth-2 antisymmetric coefficient sums over ordered pairs of increments:
L¹² = Σ_{i<j} (Δᵢ¹Δⱼ² − Δᵢ²Δⱼ¹) = (1·(−1) − 2·2) + (1·3 − 2·1) + (2·3 − (−1)·1) = −5 + 1 + 7 = 3.
This single depth-2 number, L¹² = 3 > 0, tells us the path tended to move in dimension 1 before dimension 2. Reverse the token order and you get −3. The log-signature is order-sensitive in a compact, algebraically principled way — which is exactly what a positional encoding should be.
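The arithmetic above is easy to check numerically. A minimal sketch in plain NumPy (the sum-over-increment-pairs convention matches the JAX code later in the post; the function name is mine):

```python
import numpy as np

def depth2_antisym(X):
    """Sum over ordered increment pairs i < j of d_i^1 d_j^2 - d_i^2 d_j^1."""
    d = np.diff(X, axis=0)          # increments of the piecewise-linear path
    n = len(d)
    return sum(d[i, 0] * d[j, 1] - d[i, 1] * d[j, 0]
               for i in range(n) for j in range(i + 1, n))

X = np.array([[0, 0], [1, 2], [3, 1], [4, 4]], dtype=float)
print(depth2_antisym(X))        # 3.0
print(depth2_antisym(X[::-1]))  # -3.0: reversing the token order flips the sign
```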
Two Ways to Use This in a Transformer
Part 1 — Context Compression
Divide the sequence into chunks of size w (say 64–128 tokens). Compute the log-signature of each chunk. Then combine adjacent chunks hierarchically using the Baker–Campbell–Hausdorff (BCH) formula, which is exactly the rule for composing log-signatures of concatenated paths. The result is a binary tree of summaries, with leaves covering single chunks and the root covering the whole sequence.
Attention runs over the m = n/w leaf summaries (or the full hierarchy of summaries at multiple scales), not over the n raw tokens. Cost: O(n·poly(d,k)) to build the summaries, then O((n/w)²) attention over them; for w = 128, that is a w² ≈ 16,000× reduction in the quadratic term.
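A shape-level sketch of the hierarchy in plain NumPy (my own helper names; chunks share one boundary token so no increment is lost at chunk edges, and combine is the same depth-2 rule as in the JAX code below):

```python
import numpy as np

def combine(seg1, seg2):
    """Depth-2 BCH rule for concatenating two path segments."""
    A1, B1 = seg1
    A2, B2 = seg2
    return A1 + A2, B1 + B2 + np.outer(A1, A2)

def chunk_logsig(chunk):
    """Depth-2 summary of one chunk, built segment by segment."""
    d = chunk.shape[1]
    A, B = np.zeros(d), np.zeros((d, d))
    for delta in np.diff(chunk, axis=0):
        A, B = combine((A, B), (delta, np.outer(delta, delta)))
    return A, B

def tree_reduce(summaries):
    """Merge adjacent summaries pairwise: O(log m) levels for m chunks."""
    while len(summaries) > 1:
        merged = [combine(summaries[i], summaries[i + 1])
                  for i in range(0, len(summaries) - 1, 2)]
        if len(summaries) % 2:
            merged.append(summaries[-1])   # odd chunk carried up a level
        summaries = merged
    return summaries[0]

rng = np.random.default_rng(0)
tokens, w = rng.normal(size=(9, 3)), 4
# One-token overlap between chunks preserves the boundary increments.
chunks = [tokens[i:i + w + 1] for i in range(0, len(tokens) - 1, w)]
root = tree_reduce([chunk_logsig(c) for c in chunks])
A_full, B_full = chunk_logsig(tokens)      # same path, no chunking
assert np.allclose(root[0], A_full) and np.allclose(root[1], B_full)
```

The final assertion checks the key algebraic fact: combining chunk summaries with BCH reproduces the summary of the whole concatenated path, because the combine rule is associative.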
Part 2 — Positional Encoding
For each position i, compute the log-signature of the path from position 0 to i incrementally using Chen's identity, which in log-space reads logsig(X₀…Xᵢ) = BCH(logsig(X₀…Xᵢ₋₁), logsig(Xᵢ₋₁…Xᵢ)): each prefix extends the previous one by a single segment.
This costs O(poly(d,k)) per position and O(n·poly(d,k)) total. The resulting vector is injected into the embedding layer as the positional encoding — replacing or augmenting sinusoidal or learned encodings. Because it derives from the actual token embedding path rather than just the index, it is content-aware, not purely positional.
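A sketch of the incremental computation in plain NumPy, at depth 2. The projection matrix W_proj is a hypothetical stand-in for a learned parameter, not something specified in the post:

```python
import numpy as np

def prefix_logsigs(tokens):
    """Depth-2 prefix log-signatures, one per position, via Chen/BCH updates."""
    n, d = tokens.shape
    A, B = np.zeros(d), np.zeros((d, d))
    feats = [np.zeros(d + d * d)]               # position 0: empty path
    for delta in np.diff(tokens, axis=0):
        # Extend the prefix by one segment: depth-2 BCH update.
        B = B + np.outer(A, delta) + np.outer(delta, delta)
        A = A + delta
        L = B - B.T                             # antisymmetric (Lie) part
        feats.append(np.concatenate([A, L.ravel()]))
    return np.stack(feats)                      # (n, d + d*d)

rng = np.random.default_rng(0)
tokens = rng.normal(size=(16, 8))               # 16 tokens, 8-dim
feats = prefix_logsigs(tokens)
# Hypothetical learned projection down to model dim, added to the embeddings.
W_proj = rng.normal(size=(feats.shape[1], 8)) * 0.01
embeddings = tokens + feats @ W_proj
```

Because each feature vector depends on the actual embedding increments up to position i, the encoding changes when the content changes, not just when the index does.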
Why TPUs Love This
The BCH combination rule at depth k = 2, for two segments (A₁, B₁) and (A₂, B₂), is:
A = A₁ + A₂
B = B₁ + B₂ + A₁ ⊗ A₂
Here A is the depth-1 displacement and B the depth-2 tensor; the Lie part of the log-signature is recovered at the end as the antisymmetric part B − Bᵀ, which is the bracket term [A₁, A₂] of BCH up to the conventional factor ½.
The outer product A₁ ⊗ A₂ is a rank-1 matrix multiplication — exactly what the TPU's MXU (Matrix Multiply Unit) is designed for. And because every tensor in the tree has fixed shape N_Lie(d, k), XLA can compile the entire computation without dynamic shapes.
| Step | Operations | Parallelism |
|---|---|---|
| Increments | O(nd) | fully parallel |
| Leaf init | O(nd²) outer products | fully parallel |
| Parallel scan | O(n·d²) total | O(n/2) per step |
| Antisymmetrize | O(nd²) | fully parallel |
| Total | O(nd²) | O(log n) depth |
Naive sequential scan: same O(nd²) work but O(n) depth. Using
jax.lax.associative_scan reduces depth to O(log n) — for n = 1M that
is 20 parallel steps instead of 1 million.
Working JAX Code
Here is a minimal, runnable depth-2 log-signature via associative scan in JAX:
```python
import jax
import jax.numpy as jnp
from jax import vmap


def combine(seg1, seg2):
    """BCH combination at depth 2.

    A: depth-1 (displacement), B: depth-2 tensor
    (its antisymmetric part is the area / Lévy bracket).
    """
    A1, B1 = seg1
    A2, B2 = seg2
    A = A1 + A2
    B = B1 + B2 + jnp.outer(A1, A2)  # outer product → MXU
    return A, B


@jax.jit
def log_signature_k2(tokens):
    """
    tokens: (n, d) — sequence of token embeddings
    returns: (A, L) where
        A : (n-1, d)    depth-1 prefix log-sigs
        L : (n-1, d, d) depth-2 antisymmetric part (Lie bracket)
    """
    deltas = jnp.diff(tokens, axis=0)         # (n-1, d)
    A_init = deltas
    B_init = vmap(jnp.outer)(deltas, deltas)  # (n-1, d, d)
    # O(log n) parallel depth via associative scan
    A_scan, B_scan = jax.lax.associative_scan(combine, (A_init, B_init))
    # Log-signature depth 2: keep only the antisymmetric part;
    # the symmetric leaf terms outer(delta, delta) cancel here.
    L2 = B_scan - jnp.transpose(B_scan, (0, 2, 1))
    return A_scan, L2


# Quick test
key = jax.random.PRNGKey(0)
tokens = jax.random.normal(key, (1024, 64))  # 1024 tokens, 64-dim
A, L = log_signature_k2(tokens)
print(A.shape, L.shape)  # (1023, 64) (1023, 64, 64)
```
The combine function is the entire BCH rule at depth 2. XLA compiles this
into fused MXU operations. A custom Pallas kernel can further fuse this scan with the
subsequent attention to avoid HBM round-trips entirely.
Relation to Existing Work
Rough Transformers (Moreno-Pino et al., NeurIPS 2024) use signature patching and multi-view attention for time series — an important precedent — but do not target the TPU memory hierarchy or hierarchical BCH combination. Deep Signature Transforms (Kidger et al., NeurIPS 2019) integrate signatures as learnable layers but without attention or long-context goals. FlashAttention makes exact O(n²) attention IO-efficient without changing its asymptotic cost, while Mamba and RWKV replace attention with linear-time recurrence; all of these are orthogonal to the geometric compression approach here.
The specific combination — hierarchical BCH as context compression, hardware-aware Pallas kernel, and log-signature-derived positional encoding — is, to my knowledge, new. The key open question is empirical: does depth-k truncation retain enough information for typical text tasks? For time series (Rough Transformers) the answer seems yes; for language the evidence is thinner and worth measuring carefully.
Honest Open Questions
Truncation depth. Signatures truncated at depth k lose higher-order interactions. Sufficient depth for text is not well-established; better understood in financial time series. This needs ablation.
Path regularity. Rough path theory assumes some regularity (Hölder continuity) of the path. Token embeddings in high dimensions may not satisfy this in a useful sense — the theoretical guarantees may weaken. Empirical validation on long-context benchmarks (PG19, Long-Range Arena) is necessary.
Lie algebra dimension growth. N_Lie(d, k) grows quickly with both d and k. For d = 512, k = 3: roughly 45M basis elements by the Witt formula — impractical without aggressive projection. Depth-2 is the realistic starting point; sparse Lie algebra representations are a possible extension.
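The growth is easy to quantify with the Witt formula for the dimension of the truncated free Lie algebra; a small sketch (the function names are mine):

```python
def mobius(m):
    """Möbius function via trial factorization (fine for small m)."""
    result, p = 1, 2
    while p * p <= m:
        if m % p == 0:
            m //= p
            if m % p == 0:
                return 0       # squared prime factor
            result = -result
        p += 1
    if m > 1:
        result = -result       # one remaining prime factor
    return result

def n_lie(d, k):
    """Dimension of the free Lie algebra over R^d truncated at depth k."""
    total = 0
    for n in range(1, k + 1):
        # Witt formula: (1/n) * sum over divisors m of n of mu(m) * d^(n/m)
        total += sum(mobius(m) * d ** (n // m)
                     for m in range(1, n + 1) if n % m == 0) // n
    return total

print(n_lie(2, 2))    # 3: two depth-1 elements plus one bracket [e1, e2]
print(n_lie(512, 2))  # 131,328 — manageable
print(n_lie(512, 3))  # 44,870,400 — the ~45M wall at depth 3
```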
References
Moreno-Pino et al., Rough Transformers, NeurIPS 2024.
Kidger, Bonnier, Perez Arribas, Salvi, Lyons, Deep Signature Transforms, NeurIPS 2019.
Chevyrev & Kormilitzin, A primer on the signature method in machine learning, 2016.
Lyons, Rough paths, signatures and the modelling of functions on streams, ICM 2014.
Casas & Murua, An efficient algorithm for computing the BCH series, 2009.
Google JAX / Pallas documentation.
I am a mathematician turned ML researcher. My background is in harmonic analysis on non-commutative spaces (quantum groups, von Neumann algebras) — which is how I ended up thinking about paths and algebras in the context of language models. More at ra312.github.io.