
Compressing Long Contexts
with Log-Signatures

Can a century-old idea from stochastic analysis — the path signature — solve the O(n²) attention bottleneck? A concrete proposal with working JAX code.

The Problem

Standard self-attention computes pairwise interactions between every token and every other token. That is O(n²) in sequence length. At 1 million tokens, full attention is simply infeasible. The usual responses — sparse attention, linear attention, sliding windows, state-space models — all approximate or restrict which tokens can interact.

I want to ask a different question. Instead of "which token pairs should interact?", ask: "what is the geometric summary of the path that tokens trace through embedding space?" That summary — the log-signature — is a fixed-size object, independent of sequence length. Attention over summaries instead of raw tokens breaks the O(n²) barrier.

What Is a Log-Signature?

Treat the sequence of token embeddings X₀, X₁, …, Xₙ ∈ ℝᵈ as a piecewise-linear path in ℝᵈ. The signature of this path is an infinite series of iterated integrals that encodes its shape completely — including the order in which dimensions are traversed. The log-signature is its logarithm in the tensor algebra: it lives in the free Lie algebra over ℝᵈ and is strictly more compact.

Key property
The truncated log-signature at depth k is a vector of fixed size NLie(d, k) — independent of sequence length n. For d = 512, k = 2: size = 512 + 512·511/2 ≈ 131K, far smaller than the raw sequence. For k = 1: just 512. Depth controls the expressiveness/cost tradeoff.
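The size NLie(d, k) is the dimension of the free Lie algebra over ℝᵈ truncated at depth k, computable from the Witt formula. A quick sketch in plain Python (the helper names `mobius` and `n_lie` are mine):

```python
def mobius(n):
    """Möbius function μ(n) by trial division."""
    result, p = 1, 2
    while p * p <= n:
        if n % p == 0:
            n //= p
            if n % p == 0:
                return 0          # squared prime factor
            result = -result
        p += 1
    if n > 1:
        result = -result          # one remaining prime factor
    return result

def n_lie(d, k):
    """Witt formula: dimension of the depth-k truncated free Lie algebra over R^d."""
    total = 0
    for m in range(1, k + 1):
        # number of degree-m basis elements; the sum is exactly divisible by m
        total += sum(mobius(j) * d ** (m // j) for j in range(1, m + 1) if m % j == 0) // m
    return total

print(n_lie(2, 2), n_lie(512, 2))   # 3 131328
```

For d = 512, k = 2 this gives the 512 + 512·511/2 = 131,328 figure above; for d = 2, k = 2 it gives the 3 components of the worked example below.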

A Concrete Example (d = 2, k = 2)

Four token embeddings in ℝ²: X₀ = (0,0), X₁ = (1,2), X₂ = (3,1), X₃ = (4,4).

Increments:
δX₁ = (1, 2),  δX₂ = (2, −1),  δX₃ = (1, 3)

Depth-1 (total displacement):
L¹ = 1 + 2 + 1 = 4,  L² = 2 + (−1) + 3 = 4  →  (4, 4) — order-blind, a plain sum

Depth-2 (pairwise ordered interactions):
S¹² = Σ_{p<q} δX_p¹ · δX_q² = (1)(−1) + (1)(3) + (2)(3) = 8
S²¹ = Σ_{p<q} δX_p² · δX_q¹ = (2)(2) + (2)(1) + (−1)(1) = 5

Log-sig depth-2 (only the antisymmetric part):
L¹² = S¹² − S²¹ = 8 − 5 = 3

Final: LogSig(X) = (4, 4, 3) — 3 numbers instead of 4 + 4×4 = 20

The single depth-2 number, L¹² = 3 > 0, tells us the path tended to move in dimension 1 before dimension 2. Reverse the token order and you get −3. The log-signature is order-sensitive in a compact, algebraically principled way — which is exactly what a positional encoding should be.
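The arithmetic above is easy to check numerically; a short NumPy verification (variable names are mine):

```python
import numpy as np

# Token embeddings from the worked example
X = np.array([[0, 0], [1, 2], [3, 1], [4, 4]], dtype=float)
dX = np.diff(X, axis=0)                 # increments (1,2), (2,-1), (1,3)

L1 = dX.sum(axis=0)                     # depth-1: total displacement
# Strictly ordered pairwise sums S^{ij} = sum over p<q of dX_p^i * dX_q^j
S = sum(np.outer(dX[p], dX[q]) for p in range(3) for q in range(p + 1, 3))
L12 = S[0, 1] - S[1, 0]                 # antisymmetric part: 8 - 5 = 3
print(L1, L12)                          # [4. 4.] 3.0

# Reversing the token order flips the sign of the depth-2 term
dXr = np.diff(X[::-1], axis=0)
Sr = sum(np.outer(dXr[p], dXr[q]) for p in range(3) for q in range(p + 1, 3))
print(Sr[0, 1] - Sr[1, 0])              # -3.0
```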

Two Ways to Use This in a Transformer

Part 1 — Context Compression

Divide the sequence into chunks of size w (say 64–128 tokens). Compute the log-signature of each chunk. Then combine adjacent chunks hierarchically using the Baker–Campbell–Hausdorff (BCH) formula, which is exactly the rule for composing log-signatures of concatenated paths. The result is a binary tree of summaries:

tokens:  [t1 t2 t3 t4]  [t5 t6 t7 t8]  [t9 t10 t11 t12]  [t13 t14 t15 t16]
           └── LS ──┘     └── LS ──┘      └── LS ──┘        └── LS ──┘
               └──────── BCH ────────┘         └──────── BCH ────────┘
                            └───────────── BCH ─────────────┘

Attention runs over the m = n/w leaf summaries (or the full hierarchy of summaries at multiple scales), not over the n raw tokens. Cost:

Build hierarchy: O(n · poly(d, k))  ← linear in n
Attention: O(m² · d), where m = n/w  ← quadratic in m, not n
For n = 1M, w = 64: m = 15,625 — still large. Use linear attention over summaries → O(m·d).
Total: O(n · poly(d, k))  ← subquadratic in n overall
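The chunk-and-merge scheme can be sketched in a few lines of JAX. This is a simplified sketch, not the full design: it assumes n is a multiple of w, assumes the number of chunks is a power of two, and drops the single increment that spans each chunk boundary (the names `chunk_summary` and `compress` are mine):

```python
import jax
import jax.numpy as jnp

def combine(seg1, seg2):
    """Depth-2 BCH rule: compose summaries of two consecutive path segments."""
    A1, B1 = seg1
    A2, B2 = seg2
    return A1 + A2, B1 + B2 + jnp.outer(A1, A2)

def chunk_summary(chunk):
    """(A, B) summary of one chunk of embeddings, shape (w, d)."""
    deltas = jnp.diff(chunk, axis=0)                 # (w-1, d)
    before = jnp.cumsum(deltas, axis=0) - deltas     # sum of increments before step q
    A = deltas.sum(axis=0)
    B = jnp.einsum('qi,qj->ij', before, deltas)      # sum over p<q of outer(d_p, d_q)
    return A, B

def compress(tokens, w):
    """Binary tree of BCH merges over chunk summaries."""
    n, d = tokens.shape
    A, B = jax.vmap(chunk_summary)(tokens.reshape(n // w, w, d))  # leaf summaries
    while A.shape[0] > 1:                            # merge adjacent pairs
        A, B = jax.vmap(combine)((A[0::2], B[0::2]), (A[1::2], B[1::2]))
    return A[0], B[0]
```

Because `combine` is associative, the binary tree yields the same summary as a left-to-right fold over the leaves; attention would then run over the leaf (or multi-scale) summaries rather than the raw tokens.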

Part 2 — Positional Encoding

For each position i, compute the log-signature of the path from position 0 to i incrementally using Chen's identity:

LogSig(0→i) = BCH( LogSig(0→i−1), LogSig(i−1→i) )

This costs O(poly(d,k)) per position and O(n·poly(d,k)) total. The resulting vector is injected into the embedding layer as the positional encoding — replacing or augmenting sinusoidal or learned encodings. Because it derives from the actual token embedding path rather than just the index, it is content-aware, not purely positional.
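A minimal sketch of this incremental construction, written with a sequential jax.lax.scan for clarity (in practice the flattened depth-2 part would be projected down to the model dimension by a learned linear map before being added to the embeddings; that projection, and the function names, are my assumptions):

```python
import jax
import jax.numpy as jnp

def chen_step(carry, delta):
    """Chen's identity at depth 2: extend LogSig(0->i-1) by the segment i-1->i.
    A single linear segment has no depth-2 log-signature of its own, so
    only the cross term A (x) delta enters."""
    A, B = carry
    A_new = A + delta
    B_new = B + jnp.outer(A, delta)
    return (A_new, B_new), (A_new, B_new)

def positional_logsig(tokens):
    """Per-position prefix log-signatures: depth-1 part A of shape (n-1, d)
    and antisymmetric depth-2 part L of shape (n-1, d, d)."""
    deltas = jnp.diff(tokens, axis=0)
    d = tokens.shape[1]
    init = (jnp.zeros(d), jnp.zeros((d, d)))
    _, (A, B) = jax.lax.scan(chen_step, init, deltas)
    L = B - jnp.swapaxes(B, 1, 2)      # keep only the antisymmetric part
    return A, L
```

On the worked example above (X₀ = (0,0) through X₃ = (4,4)), the final position recovers A = (4, 4) and L[0, 1] = 3, matching LogSig(X) = (4, 4, 3).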

Why TPUs Love This

The BCH combination rule at depth k=2 is:

combine((A₁, B₁), (A₂, B₂)) = (A₁ + A₂, B₁ + B₂ + A₁ ⊗ A₂)

The outer product A₁ ⊗ A₂ is a rank-1 matrix multiplication — exactly what the TPU's MXU (Matrix Multiply Unit) is designed for. And because every tensor in the tree has fixed shape NLie(d,k), XLA can compile the entire computation without dynamic shapes.

| Step | Operations | Parallelism |
|---|---|---|
| Increments | O(nd) | fully parallel |
| Leaf init | O(nd²) (outer products) | fully parallel |
| Parallel scan | O(nd²) total | O(n/2) per step |
| Antisymmetrize | O(nd²) | fully parallel |
| Total | O(nd²) | O(log n) depth |

Naive sequential scan: same O(nd²) work but O(n) depth. Using jax.lax.associative_scan reduces depth to O(log n) — for n = 1M that is 20 parallel steps instead of 1 million.

Working JAX Code

Here is a minimal, runnable depth-2 log-signature via associative scan in JAX:

import jax
import jax.numpy as jnp
import functools
from jax import vmap

def combine(seg1, seg2):
    """BCH combination at depth 2.
    A: depth-1 (displacement), B: depth-2 (area / Lévy bracket).
    """
    A1, B1 = seg1
    A2, B2 = seg2
    A = A1 + A2
    B = B1 + B2 + jnp.outer(A1, A2)   # outer product → MXU
    return A, B

@jax.jit
def log_signature_k2(tokens):
    """
    tokens: (n, d)  — sequence of token embeddings
    returns: (A, L) where
      A : (n, d)    depth-1 prefix log-sigs
      L : (n, d, d) depth-2 antisymmetric part (Lie bracket)
    """
    deltas = jnp.diff(tokens, axis=0)           # (n-1, d)
    A_init = deltas
    B_init = vmap(jnp.outer)(deltas, deltas)    # (n-1, d, d)

    # O(log n) parallel depth via associative scan
    A_scan, B_scan = jax.lax.associative_scan(combine, (A_init, B_init))

    # Log-signature depth-2: antisymmetric part only
    L2 = B_scan - jnp.transpose(B_scan, (0, 2, 1))
    return A_scan, L2

# Quick test
key = jax.random.PRNGKey(0)
tokens = jax.random.normal(key, (1024, 64))   # 1024 tokens, 64-dim
A, L = log_signature_k2(tokens)
print(A.shape, L.shape)   # (1023, 64)  (1023, 64, 64)

The combine function is the entire BCH rule at depth 2. XLA compiles this into fused MXU operations. A custom Pallas kernel can further fuse this scan with the subsequent attention to avoid HBM round-trips entirely.

Relation to Existing Work

Rough Transformers (Moreno-Pino et al., NeurIPS 2024) use signature patching and multi-view attention for time series — an important precedent — but do not target the TPU memory hierarchy or hierarchical BCH combination. Deep Signature Transforms (Kidger et al., NeurIPS 2019) integrate signatures as learnable layers, but without attention or long-context goals. FlashAttention reduces the memory traffic of exact O(n²) attention, while Mamba and RWKV replace attention with recurrence; all are orthogonal to the geometric compression approach here.

The specific combination — hierarchical BCH as context compression, hardware-aware Pallas kernel, and log-signature-derived positional encoding — is, to my knowledge, new. The key open question is empirical: does depth-k truncation retain enough information for typical text tasks? For time series (Rough Transformers) the answer seems yes; for language the evidence is thinner and worth measuring carefully.

Honest Open Questions

Truncation depth. Signatures truncated at depth k discard all interactions of order above k. What depth suffices for text is not well established (it is better understood for financial time series) and needs ablation.

Path regularity. Rough path theory assumes some regularity (Hölder continuity) of the path. Token embeddings in high dimensions may not satisfy this in a useful sense — the theoretical guarantees may weaken. Empirical validation on long-context benchmarks (PG19, Long-Range Arena) is necessary.

Lie algebra dimension growth. NLie(d, k) grows quickly with both d and k. For d = 512, k = 3 the Witt formula gives roughly 45M basis elements, impractical without aggressive projection. Depth-2 is the realistic starting point; sparse Lie algebra representations are a possible extension.

Status
This is a research proposal submitted to the 2026 Google TPU Awards program. Code, benchmarks, and a full write-up will be released publicly if the project proceeds. I welcome discussion — reach me at akylzhanov.r@gmail.com.

References

Moreno-Pino et al., Rough Transformers, NeurIPS 2024.
Kidger, Bonnier, Perez Arribas, Salvi, Lyons, Deep Signature Transforms, NeurIPS 2019.
Chevyrev & Kormilitzin, A primer on the signature method in machine learning, 2016.
Lyons, Rough paths, signatures and the modelling of functions on streams, ICM 2014.
Casas & Murua, An efficient algorithm for computing the BCH series, 2009.
Google JAX / Pallas documentation.

I am a mathematician turned ML researcher. My background is in harmonic analysis on non-commutative spaces (quantum groups, von Neumann algebras) — which is how I ended up thinking about paths and algebras in the context of language models. More at ra312.github.io.