
Scalable Spatio-Temporal SE(3) Diffusion for Long-Horizon Protein Dynamics

Conference: ICLR 2026
arXiv: 2602.02128
Code: https://bytedance-seed.github.io/ConfRover/starmd
Area: Medical Imaging
Keywords: protein conformation generation, SE(3) diffusion model, spatio-temporal attention, autoregressive trajectory generation, molecular dynamics acceleration

TL;DR

This paper proposes STAR-MD, an SE(3)-equivariant causal diffusion Transformer that achieves microsecond-scale protein dynamics trajectory generation via joint spatio-temporal attention and contextual noise perturbation. STAR-MD attains state-of-the-art performance across all metrics on the ATLAS benchmark and stably extrapolates to microsecond timescales unseen during training.

Background & Motivation

Background: Molecular dynamics (MD) simulation is the gold standard for studying protein dynamics, but it requires femtosecond-scale integration steps and becomes computationally prohibitive at long timescales (a microsecond-scale simulation needs on the order of \(10^9\) steps). Recent generative models have been proposed to accelerate MD, including MDGen (a diffusion model for 100 ns trajectories), AlphaFolding (simultaneous multi-frame generation), and ConfRover (autoregressive generation).
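
As a quick check of that step count, assume a 1 fs integration step (an assumption; 2 fs steps with bond constraints are also common):

\[
\frac{1\ \mu\text{s}}{1\ \text{fs}} = \frac{10^{-6}\ \text{s}}{10^{-15}\ \text{s}} = 10^{9}\ \text{steps}, \qquad \frac{1\ \mu\text{s}}{2\ \text{fs}} = 5 \times 10^{8}\ \text{steps}.
\]

Either way, each simulated microsecond costs hundreds of millions of sequential force evaluations.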

Limitations of Prior Work: (a) Existing methods are limited to short timescales (nanoseconds) and cannot scale to biologically relevant microsecond–millisecond regimes. (b) AlphaFold3-style Pairformer modules with triangular attention incur \(O(N^3L)\) computational cost and \(O(N^2L)\) KV-cache memory. (c) Existing architectures interleave spatial and temporal modules (space-then-time), limiting their capacity to capture non-separable spatio-temporal couplings. (d) Error accumulation in autoregressive generation of long trajectories is severe.

Key Challenge: Coarse-graining (per-residue representations instead of all-atom) renders protein dynamics non-Markovian (requiring historical memory), yet the computationally expensive pairwise feature processing obstructs modeling over longer historical contexts.

Goal: Design a protein conformation generation model that efficiently handles spatio-temporal dependencies while stably generating microsecond-scale long trajectories.

Key Insight: The paper employs the Mori–Zwanzig formalism to theoretically establish that (a) coarse-graining necessitates historical memory (non-Markovian dynamics), and (b) removing pairwise features causes the memory kernel to "inflate" and become spatio-temporally non-separable — directly motivating the use of joint spatio-temporal attention.

Core Idea: Replace alternating spatial and temporal modules with joint spatio-temporal attention, combined with causal diffusion training and contextual noise perturbation, to enable scalable long-horizon protein dynamics generation.

Method

Overall Architecture

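Given a protein sequence and an initial conformation \(\mathbf{x}_0\), STAR-MD autoregressively generates the subsequent frames according to \(p(\mathbf{x}_{1:L} \mid \mathbf{x}_0, \Delta t_{1:L}) = \prod_{\ell=1}^{L} p(\mathbf{x}_\ell \mid \mathbf{x}_{<\ell}, \Delta t_\ell)\), where \(\Delta t_\ell\) is the physical time step between frames \(\ell-1\) and \(\ell\). Each frame is generated by an SE(3) diffusion model (trained with denoising score matching), in which a causal Transformer extracts conditioning information from the historical clean frames and the current noisy frame.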
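
The factorization above can be sketched as a sampling loop. This is a minimal illustration under stated assumptions, not the paper's sampler: `denoiser(x_noisy, history, dt, t)` is a hypothetical callable that returns an estimate of the clean frame, only Cα translations in \(\mathbb{R}^3\) are handled (no rotations), a simple linear-interpolation (flow-style) update stands in for the SE(3) reverse diffusion, and the full history is passed in where the real model would reuse a KV cache.

```python
import numpy as np


def toy_denoiser(x_noisy, history, dt, t):
    """Hypothetical placeholder for the causal Transformer.

    A trained model would attend to the clean history and the noisy frame,
    conditioned on dt and t; this stub just returns the last clean frame.
    """
    return history[-1]


def sample_frame(history, dt, denoiser, rng, n_steps=20):
    """Draw one new frame of C-alpha translations, shape (N, 3).

    Assumes the interpolation x_t = (1 - t) * x0 + t * eps, with t running
    from 1 (pure noise) to 0 (clean frame). Rotations are ignored.
    """
    x = rng.standard_normal(history[-1].shape)  # t = 1: pure noise
    ts = np.linspace(1.0, 0.0, n_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        x0_hat = denoiser(x, history, dt, t)
        eps_hat = (x - (1.0 - t) * x0_hat) / t  # t > 0 inside the loop
        x = (1.0 - t_next) * x0_hat + t_next * eps_hat
    return x


def rollout(x_init, dts, denoiser, seed=0):
    """Autoregressively generate len(dts) frames after x_init.

    Frame l is sampled from p(x_l | x_{<l}, dt_l); the history list plays the
    role that the KV cache plays in the real model.
    """
    rng = np.random.default_rng(seed)
    history = [x_init]
    for dt in dts:
        history.append(sample_frame(history, dt, denoiser, rng))
    return np.stack(history)  # (len(dts) + 1, N, 3)
```

For example, `rollout(x_init, dts=[0.1, 1.0, 10.0], denoiser=toy_denoiser)` returns the initial frame plus three generated frames; with the toy stub every frame simply copies `x_init`, which makes the loop easy to check before swapping in a real model.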

Key Designs

  1. Joint Spatio-Temporal (S×T) Attention:
     • Function: Applies attention over joint tokens indexed by residue–frame pairs \((i, \ell)\), replacing alternating spatial/temporal modules.
     • Mechanism: Each token corresponds to a (residue, frame) pair, so any residue can attend directly to the features of any residue in the current or any previous frame. 2D RoPE encodes the residue and frame indices, supporting extrapolation to more frames than seen during training. The cost is \(O(N^2 L^2)\), versus \(O(N^3 L + N^2 L^2)\) for a Pairformer combined with temporal attention.
     • Design Motivation: Mori–Zwanzig theory shows that removing pairwise features makes the memory kernel spatio-temporally non-separable, so joint attention, rather than factorized attention, is needed to model this coupling.

  2. Block-Causal Attention Training:
     • Function: Enables parallel training while preserving the causal structure.
     • Mechanism: Clean and noisy frames are concatenated into one input sequence, and a block-level attention mask ensures that each noisy frame attends only to its own tokens and to the clean versions of preceding frames. Although this doubles the sequence length, a single forward pass optimizes the denoising loss for all frames simultaneously.
     • Design Motivation: Aligns parallel teacher-forcing training with sequential autoregressive inference.

  3. Contextual Noise Perturbation:
     • Function: Mitigates error accumulation in long-horizon autoregressive generation.
     • Mechanism: During training, the historical clean frames are perturbed with noise at a small level \(\tau \sim \mathcal{U}[0, 0.1]\), and the same perturbation is applied at inference. This keeps training and inference consistent and makes the model robust to its own prediction errors.
     • Design Motivation: The idea is inspired by Diffusion Forcing; the core insight is to expose the model to imperfect historical inputs during training.

  4. Continuous Time Conditioning:
     • Function: Enables a single model to cover multiple timescales.
     • Mechanism: The time step size \(\Delta t \sim \text{LogUniform}[10^{-2}, 10^1]\) ns is sampled randomly and injected into the network via AdaLN. Even with a small context window, large step sizes expose the model to long-range temporal dependencies.
     • Design Motivation: Decouples the physical duration of a trajectory from the number of context frames, eliminating the need for complex context-length extrapolation techniques.
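
The joint S×T attention described above can be sketched as single-head attention over the flattened \(L \times N\) token grid. This NumPy version is illustrative only and rests on assumptions: the 2D RoPE splits each query/key vector in half, rotating one half by the residue index and the other by the frame index (one common construction; the paper's exact split is not stated here), a frame-causal mask lets each token see its own and earlier frames, and the projection matrices are random placeholders.

```python
import numpy as np


def rope_1d(x, pos, base=10000.0):
    """Rotary position embedding over the last axis of x (size must be even).

    x: (T, d), pos: (T,) positions. Channel i of the first half is paired with
    channel i of the second half and rotated by pos * base**(-i / (d/2)).
    """
    half = x.shape[-1] // 2
    freqs = base ** (-np.arange(half) / half)   # (half,)
    angles = pos[:, None] * freqs[None, :]      # (T, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


def rope_2d(x, res_idx, frame_idx):
    """2D RoPE (assumed split): first half of the channels encodes the residue
    index, second half encodes the frame index. Requires d divisible by 4."""
    d = x.shape[-1]
    return np.concatenate(
        [rope_1d(x[..., : d // 2], res_idx), rope_1d(x[..., d // 2 :], frame_idx)],
        axis=-1,
    )


def joint_st_attention(h, wq, wk, wv):
    """Single-head joint spatio-temporal attention.

    h: (L, N, d) features for L frames of N residues. Every (frame, residue)
    pair becomes one token, so the score matrix is (L*N) x (L*N): O(N^2 L^2).
    The frame-causal mask lets each token attend to all residues in its own
    frame and in earlier frames.
    """
    L, N, d = h.shape
    tokens = h.reshape(L * N, d)                         # token index = l * N + i
    frame_idx = np.repeat(np.arange(L), N).astype(float)
    res_idx = np.tile(np.arange(N), L).astype(float)
    q = rope_2d(tokens @ wq, res_idx, frame_idx)
    k = rope_2d(tokens @ wk, res_idx, frame_idx)
    v = tokens @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    allowed = frame_idx[None, :] <= frame_idx[:, None]   # key frame <= query frame
    scores = np.where(allowed, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)          # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ v).reshape(L, N, -1)
```

Because RoPE makes query–key dot products depend only on index offsets, the same weights apply to frame positions beyond the training window, which is the property the frame-count extrapolation relies on.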
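
The block-causal training mask over the concatenated sequence \([\text{clean}_1, \dots, \text{clean}_L, \text{noisy}_1, \dots, \text{noisy}_L]\) can be built as follows. The noisy-to-clean and noisy-to-noisy rules follow the description of block-causal training; letting each clean frame attend causally to clean frames up to and including itself is an assumption here, chosen so that the clean tokens see exactly what a KV cache would hold at inference. `block_causal_mask` is a hypothetical helper name.

```python
import numpy as np


def block_causal_mask(num_frames, tokens_per_frame):
    """Boolean mask (True = may attend) for [clean frames..., noisy frames...].

    Rows are queries, columns are keys, frames are 0-indexed.
      clean frame l -> clean frames <= l         (assumed, mirrors the KV cache)
      clean frame l -> noisy frames: never
      noisy frame l -> clean frames <  l         (history only)
      noisy frame l -> noisy frame l only        (its own tokens)
    """
    frame = np.repeat(np.arange(num_frames), tokens_per_frame)
    f_q, f_k = frame[:, None], frame[None, :]
    clean_to_clean = f_k <= f_q
    clean_to_noisy = np.zeros_like(clean_to_clean)
    noisy_to_clean = f_k < f_q
    noisy_to_noisy = f_k == f_q
    top = np.concatenate([clean_to_clean, clean_to_noisy], axis=1)
    bottom = np.concatenate([noisy_to_clean, noisy_to_noisy], axis=1)
    return np.concatenate([top, bottom], axis=0)
```

Applying the mask before the softmax (e.g., `np.where(mask, scores, -np.inf)`) lets one forward pass compute the denoising loss for every frame at once, at the cost of a sequence twice as long.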
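
Contextual noise perturbation amounts to a small augmentation of the history. The sketch below assumes an additive form \(\tilde{\mathbf{x}} = \mathbf{x} + \tau \boldsymbol{\epsilon}\) on Cα translations in normalized coordinates, with one \(\tau \sim \mathcal{U}[0, 0.1]\) per frame. The paper's exact parameterization (per-frame vs. per-sequence \(\tau\), how rotations are perturbed, whether \(\tau\) is also fed to the network) is not given here, so `perturb_context` is hypothetical.

```python
import numpy as np


def perturb_context(history, rng, tau_max=0.1):
    """Perturb clean historical frames before they are used as context.

    history: (L, N, 3) C-alpha translations (assumed normalized units).
    Returns the perturbed frames and the per-frame noise levels tau.
    The same call would be applied to the model's own generated history at
    inference, so the context is equally imperfect in both phases.
    """
    num_frames = history.shape[0]
    tau = rng.uniform(0.0, tau_max, size=(num_frames, 1, 1))  # tau ~ U[0, tau_max]
    eps = rng.standard_normal(history.shape)                  # isotropic Gaussian noise
    return history + tau * eps, tau.reshape(num_frames)
```

Keeping \(\tau\) small preserves the information in the history while preventing the model from trusting it as exact, which is the mechanism credited with curbing drift on long rollouts.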
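
Continuous time conditioning needs two pieces: sampling \(\Delta t\) uniformly in log space and modulating normalized features with it. In this minimal sketch the conditioning embedding is simply \(\log \Delta t\) mapped linearly to a per-channel scale and shift; real AdaLN layers typically use a learned time embedding and MLP, and `w_scale`/`w_shift` are hypothetical weights.

```python
import numpy as np


def sample_dt(rng, low=1e-2, high=1e1, size=None):
    """Sample step sizes (ns) from LogUniform[low, high], i.e. uniform in log(dt)."""
    return np.exp(rng.uniform(np.log(low), np.log(high), size=size))


def adaln(h, dt, w_scale, w_shift, eps=1e-6):
    """Adaptive LayerNorm conditioned on the step size dt.

    h: (..., d) token features; w_scale, w_shift: (d,) hypothetical learned
    weights mapping the 1-D embedding log(dt) to a scale and a shift.
    """
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    h_norm = (h - mu) / np.sqrt(var + eps)   # plain LayerNorm without affine params
    emb = np.log(dt)
    return h_norm * (1.0 + w_scale * emb) + w_shift * emb
```

Under LogUniform[\(10^{-2}, 10^{1}\)], each of the three decades (0.01–0.1, 0.1–1, 1–10 ns) is drawn equally often, so the model sees short and long steps in balance rather than being dominated by the largest values.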

Loss & Training

An SE(3) denoising score matching loss is applied, with noise predicted separately for translations (Gaussian noise on \(\mathbb{R}^3\)) and rotations (isotropic Gaussian noise on \(SO(3)\), i.e., the \(\mathcal{IG}_{SO(3)}\) distribution). At inference, KV-caching enables efficient autoregressive generation with only \(O(NL)\) memory (versus \(O(N^2 L)\) for ConfRover).
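
The \(O(NL)\) versus \(O(N^2 L)\) cache comparison can be made concrete with a rough count. This is a crude model, not either paper's accounting: it assumes per-residue tokens cache one key and one value vector per layer, and that a pair-based architecture additionally caches \(N^2\) pair features per frame. All dimensions below are hypothetical.

```python
def kv_cache_floats(num_residues, num_frames, num_layers, d_single, d_pair=0):
    """Count floats held in a key/value cache across all layers.

    Per-residue (single) tokens: 2 * layers * N * L * d_single   -> O(N L)
    Pairwise features, if cached: 2 * layers * N^2 * L * d_pair  -> O(N^2 L)
    """
    single = 2 * num_layers * num_residues * num_frames * d_single
    pair = 2 * num_layers * num_residues ** 2 * num_frames * d_pair
    return single + pair


def to_gib(num_floats, bytes_per_float=2):
    """Convert a float count to GiB (default: 2 bytes per float, e.g. bf16)."""
    return num_floats * bytes_per_float / 2 ** 30


# Hypothetical setting: 300 residues, 1,000 cached frames, 12 layers.
single_only = kv_cache_floats(300, 1000, 12, d_single=256)
with_pair = kv_cache_floats(300, 1000, 12, d_single=256, d_pair=128)
print(f"single-only cache:   {to_gib(single_only):.1f} GiB")
print(f"single + pair cache: {to_gib(with_pair):.1f} GiB")
```

In this model the pair term is \(N \cdot d_\text{pair} / d_\text{single}\) times the single term (150× for the values above) and grows quadratically with protein size, which is why dropping pairwise features makes long cached histories feasible.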

Key Experimental Results

Main Results

ATLAS benchmark, 100 ns trajectory generation (qualitative ranking):

Method       | Cα-level Validity↑ | All-atom Validity↑ | Conformation Coverage↑ | Dynamics Fidelity↑
MDGen        | Low                | Low                | Medium                 | Low
AlphaFolding | Medium             | Medium             | Medium                 | Medium
ConfRover    | High               | High               | High                   | High
STAR-MD      | Highest            | Highest            | Highest                | Highest

STAR-MD achieves state-of-the-art performance across all metrics.

Ablation Study

Configuration                                         | Observation
Remove joint S×T attention (→ alternating modules)    | Significant performance drop, validating the value of joint spatio-temporal modeling
Remove contextual noise perturbation                  | Severe stability degradation on long trajectories (>250 ns)
Remove continuous time conditioning                   | Reduced generalization across timescales
Remove block-causal training                          | Substantially lower training efficiency

Key Findings

  • Long-horizon extrapolation: STAR-MD maintains high structural quality at the microsecond scale (1 μs = 10× training length), whereas baseline methods exhibit catastrophic degradation beyond 250 ns.
  • Contextual noise perturbation is critical for long-horizon stability — without it, even STAR-MD degrades on long trajectories.
  • Joint S×T attention is both computationally more efficient (eliminating the \(O(N^3)\) term from triangular attention) and more expressive.
  • Continuous time conditioning allows a single model to cover step sizes ranging from \(10^{-2}\) to \(10^1\) ns.

Highlights & Insights

  • Rigorous correspondence between theory and architecture: The Mori–Zwanzig formalism is used to motivate joint spatio-temporal attention (memory kernel inflation and non-separability), rather than relying solely on intuition or ablations. This "theory-first → architecture-guided" methodology is exemplary.
  • Training–inference alignment via contextual noise perturbation: The core idea is simple — add noise to historical frames during training as well — yet it is critical for long-horizon stability. This is analogous to scheduled sampling but naturally realized within the diffusion framework.
  • Cascading benefits of removing Pairformer: Eliminating pairwise features reduces both computation (\(O(N^3) \to O(N^2)\)) and KV-cache (\(O(N^2L) \to O(NL)\)), making long-trajectory generation memory-feasible.
  • Elegance of continuous time conditioning: By sampling step sizes from a LogUniform distribution, the model learns long-range temporal dependencies even with a small context window — effectively converting temporal extrapolation into a conditioning problem.

Limitations & Future Work

  • Validation is conducted only at the Cα-level coarse-grained representation; fine-grained side-chain dynamics may be missed without full all-atom modeling.
  • The quality of MD trajectories in the ATLAS dataset is bounded by force field accuracy, imposing an upper limit on the dynamics the model can learn.
  • The \(O(N^2 L^2)\) complexity of joint S×T attention may become a bottleneck for very large proteins or extremely long trajectories.
  • Training is performed exclusively on 100 ns scale data; the reliability of microsecond extrapolation lacks direct physical validation (e.g., accuracy of free energy surfaces).
  • Comparisons with enhanced sampling methods such as Metadynamics are absent.

Related Work & Insights

  • vs. ConfRover: ConfRover uses a Pairformer + IPA autoregressive architecture with KV-caching, but incurs \(O(N^3L)\) computation and \(O(N^2L)\) KV-cache overhead. By replacing the Pairformer with S×T attention, STAR-MD improves on it in both efficiency and performance.
  • vs. MDGen: MDGen anchors trajectories to keyframes and uses a standard Transformer, but it trails the other methods on ATLAS. STAR-MD instead keeps a fully autoregressive structure and is more effective.
  • vs. AlphaFolding: AlphaFolding generates multiple frames simultaneously but discards memory of earlier windows. STAR-MD preserves the full historical context via KV-caching.
  • Analogy to video generation: STAR-MD's architectural design (causal diffusion Transformer + block-causal attention training + noise perturbation for drift prevention) closely parallels techniques in video generation, suggesting convergence between protein dynamics generation and video generation methodology.

Rating

  • Novelty: ⭐⭐⭐⭐⭐ The Mori–Zwanzig theory-driven architecture design is highly elegant, and microsecond-scale extrapolation represents a breakthrough.
  • Experimental Thoroughness: ⭐⭐⭐⭐⭐ Multi-timescale evaluation (100 ns / 250 ns / 1 μs), comprehensive ablations, and multi-dimensional metrics covering structural quality, coverage, and dynamics fidelity.
  • Writing Quality: ⭐⭐⭐⭐ Theoretical and methodological exposition is clear, though notation density requires domain background.
  • Value: ⭐⭐⭐⭐⭐ Represents a significant advance in protein dynamics simulation and may open a new paradigm for generative model-driven drug discovery.