
Learning Spatial Decay for Vision Transformers

Conference: AAAI 2026 · arXiv: 2508.09525 · Code: None · Area: LLM NLP · Keywords: Vision Transformer, spatial decay, attention mechanism, content-aware gating, image classification

TL;DR

This paper proposes the Spatial Decay Transformer (SDT), which for the first time adapts data-dependent spatial decay mechanisms from 1D sequence modeling to 2D vision Transformers. Through a Context-Aware Gating (CAG) module that generates dynamic, content-dependent decay intensities for patch interactions, SDT consistently outperforms strong baselines such as RMT on ImageNet-1K classification and generation tasks.

Background & Motivation

Background: Vision Transformers (ViTs) achieve global receptive fields via self-attention, but self-attention is permutation-equivariant and therefore blind to the 2D spatial structure of images unless positional information is injected, leaving the model to learn basic spatial relationships from data alone.

Limitations of Prior Work: Methods such as RMT introduce fixed, data-independent spatial decay matrices based on Manhattan distance. This strategy is fundamentally rigid — it imposes the same spatial decay pattern regardless of image content, and cannot adaptively focus on semantically relevant regions.

Key Challenge: Semantically related regions should maintain strong attention connections even when spatially distant, while irrelevant regions should be suppressed even when spatially adjacent. Fixed decay cannot achieve this flexibility.

Inspiration from LLMs: Works such as GLA, HGRN2, Mamba2, and Forgetting Transformer demonstrate that content-aware, data-dependent gating significantly outperforms data-independent fixed positional biases.

Key Insight: Adapting 1D data-dependent decay to the 2D spatial domain introduces unique challenges: bidirectional spatial dependencies, non-causal relationships, and 2D topological complexity.

Core Idea: Design a Context-Aware Gating (CAG) mechanism that generates dynamic, content-dependent decay intensities for each pair of patch interactions, jointly incorporating Manhattan-distance spatial priors and learned content representations.

Method

Overall Architecture

SDT adopts a four-stage hierarchical design (SDT-H) that progressively reduces spatial resolution while increasing feature dimensionality. Each stage consists of several Spatial Decay Layers, each comprising a Spatial Decay Attention (SDA) module and an FFN. The first two high-resolution stages use a decomposed implementation to reduce computational cost, while the latter two low-resolution stages employ full 2D spatial decay.
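
As a structural sketch, the hierarchy might look as follows in PyTorch. All depths, widths, and head counts here are illustrative assumptions rather than the paper's configuration, and plain softmax attention stands in for the SDA module detailed below.

```python
import torch
import torch.nn as nn

class SpatialDecayLayer(nn.Module):
    """One layer of a stage: attention (a stand-in for SDA) followed by an FFN."""
    def __init__(self, dim: int, decomposed: bool):
        super().__init__()
        self.decomposed = decomposed  # True in stages 1-2 (1D decay), False in 3-4 (2D decay)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x):  # x: (B, L, dim)
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        return x + self.ffn(self.norm2(x))

class SDTH(nn.Module):
    """Four-stage hierarchy: resolution halves and width grows between stages."""
    def __init__(self, dims=(64, 128, 256, 512), depths=(2, 2, 6, 2)):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dims[0], kernel_size=4, stride=4)
        self.stages = nn.ModuleList(
            nn.Sequential(*[SpatialDecayLayer(d, decomposed=i < 2) for _ in range(n)])
            for i, (d, n) in enumerate(zip(dims, depths)))
        self.downsamples = nn.ModuleList(
            nn.Conv2d(dims[i], dims[i + 1], 2, 2) if i < 3 else nn.Identity()
            for i in range(4))

    def forward(self, x):  # x: (B, 3, H, W)
        x = self.patch_embed(x)
        for stage, down in zip(self.stages, self.downsamples):
            b, c, h, w = x.shape
            x = stage(x.flatten(2).transpose(1, 2))          # attention on (B, L, C) tokens
            x = down(x.transpose(1, 2).reshape(b, c, h, w))  # back to 2D, then downsample
        return x
```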

Key Designs

  1. Context-Aware Gating (CAG):

    • Function: Generates content-dependent decay intensities for each pair of patch interactions.
    • Mechanism: Input features \(\mathbf{X}\) are projected via learnable weights to produce head-specific decay logits \(\mathbf{F} = \mathbf{X}\mathbf{W}_g\), which are then transformed via log-sigmoid into bounded decay intensities \(\mathbf{G} = \log\sigma(\mathbf{F}) \in (-\infty, 0]\).
    • Design Motivation: Different heads can learn different types of spatial relationships (e.g., local texture vs. global object structure); the log-sigmoid transformation ensures gradient stability. (A PyTorch sketch of CAG and the fusion mask follows this list.)
  2. Spatial-Content Fusion Framework:

    • Function: Unifies fixed 2D geometric priors with adaptive content representations.
    • Mechanism: The combined decay is computed as \(\mathbf{M}_{\text{combined}}[i,j] = \frac{1}{2}(\mathbf{G}[i,:] + \mathbf{G}[j,:]) \cdot d_M(\mathbf{p}_i, \mathbf{p}_j) \cdot \alpha\), where the average of the gating vectors at two positions (ensuring symmetry and reciprocity) is multiplied by the Manhattan distance. The final decay is \(\mathbf{M}_{\text{decay}}[i,j] = -|\mathbf{M}_{\text{combined}}[i,j]|\).
    • Design Motivation: The negative absolute value ensures that attention scores are only attenuated rather than amplified, maintaining stable gradient flow.
  3. Efficient Decomposed Implementation:

    • Function: Addresses the memory overhead of \(O(L^2)\) spatial decay masks in high-resolution stages.
    • Mechanism: Horizontal and vertical 1D data-dependent decays are computed separately, reducing the decay-mask complexity from \(O(L^2)\) to \(O(H^2 + W^2)\), where \(L = H \times W\).
    • Design Motivation: Substantially reduces memory consumption in the first two high-resolution stages while preserving the data-dependent property. (A sketch of an assumed axial factorization appears as the second code block after this list.)
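
Below is a minimal PyTorch sketch of CAG and the spatial-content fusion mask, written directly from the two formulas above. The shape of \(\mathbf{W}_g\) (one logit per head), the value of \(\alpha\), and the additive application of the mask to pre-softmax attention logits are assumptions on my part, not confirmed implementation details.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def manhattan_distances(H: int, W: int) -> torch.Tensor:
    """(L, L) Manhattan distances d_M(p_i, p_j) between all patches of an H x W grid."""
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()  # (L, 2)
    return (coords[:, None, :] - coords[None, :, :]).abs().sum(-1)      # (L, L)

class ContextAwareGating(nn.Module):
    """Builds the additive decay mask M_decay from the two formulas above.
    W_g emitting one logit per head and the scalar alpha are assumed details."""
    def __init__(self, dim: int, num_heads: int, alpha: float = 0.1):
        super().__init__()
        self.w_g = nn.Linear(dim, num_heads, bias=False)  # F = X W_g (head-specific logits)
        self.alpha = alpha                                # assumed global decay scale

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        # x: (B, L, dim) with L = H * W; returns a (B, heads, L, L) additive mask
        g = F.logsigmoid(self.w_g(x))                     # G = log sigma(F), in (-inf, 0]
        sym = 0.5 * (g.unsqueeze(2) + g.unsqueeze(1))     # (B, L, L, heads), symmetric in i, j
        dist = manhattan_distances(H, W).to(x.device)     # fixed geometric prior
        combined = sym * dist[None, :, :, None] * self.alpha
        return -combined.abs().permute(0, 3, 1, 2)        # only ever attenuates scores
```

Under this reading, the mask is added to the scaled dot-product logits before the softmax, e.g. `scores = q @ k.transpose(-2, -1) / head_dim ** 0.5 + mask`.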
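
And a sketch of the decomposed variant. The paper's exact factorization is not reproduced here; this assumes an axial scheme in which each row (length \(W\)) and each column (length \(H\)) gets its own 1D decay mask, so only small distance tables are materialized.

```python
import torch

def axial_decay_mask(gates: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    """1D data-dependent decay along a single spatial axis.
    gates: (..., N, heads) log-sigmoid gate values for the N tokens on that axis.
    Returns an additive mask of shape (..., heads, N, N)."""
    n = gates.shape[-2]
    idx = torch.arange(n, device=gates.device).float()
    dist = (idx[:, None] - idx[None, :]).abs()               # 1D |i - j| distances
    sym = 0.5 * (gates.unsqueeze(-2) + gates.unsqueeze(-3))  # symmetric pairwise average
    return -(sym * dist[..., None] * alpha).abs().movedim(-1, -3)
```

Applying this per row (N = W) and per column (N = H) materializes only \(O(W^2)\)- and \(O(H^2)\)-sized tables, matching the \(O(H^2 + W^2)\) claim above.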

Loss & Training

Standard training on ImageNet-1K for 300 epochs with the AdamW optimizer. RoPE and Local Position Encoding (depth-wise convolution) are incorporated to enhance positional awareness.
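
For concreteness, a hypothetical optimizer setup consistent with this description; the learning rate, weight decay, and cosine schedule are assumptions borrowed from common DeiT-style ImageNet recipes, not values reported in the paper.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # placeholder; substitute an SDT-H variant here
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)  # assumed values
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)   # 300 epochs
```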

Key Experimental Results

Main Results: ImageNet-1K Classification

| Model | Params | FLOPs | Top-1 Acc |
|---|---|---|---|
| RMT-T | 14M | 2.5G | 82.4% |
| SDT-H-T | 14M | 2.7G | 82.7% |
| RMT-S | 27M | 4.5G | 84.1% |
| SDT-H-S | 27M | 4.8G | 84.2% |
| RMT-B | 54M | 9.7G | 85.0% |
| SDT-H-B | 54M | 10.8G | 85.1% |

Ablation Study: Data-Dependent vs. Data-Independent Decay

| Configuration | Top-1 Acc | Note |
|---|---|---|
| Fixed decay (RMT) | 82.4% | Fixed Manhattan-distance decay |
| Data-dependent (CAG) | 82.7% | Content-aware decay (+0.3%) |
| w/o spatial prior | 82.3% | Distance term removed; content gate only |
| w/o content gate | 82.1% | CAG removed; distance only |

Key Findings

  • SDT consistently outperforms RMT at all scales (T/S/B), demonstrating the superiority of data-dependent spatial decay over fixed decay.
  • Spatial priors and content gating are mutually indispensable — their combination outperforms either component used in isolation.
  • Improvements on generation tasks (DiT integration) indicate the generality of the proposed method.
  • The decomposed implementation effectively reduces memory in the first two stages with negligible impact on final accuracy.

Highlights & Insights

  • This work is the first to successfully transfer data-dependent decay from LLMs to 2D vision Transformers, bridging attention mechanism design between NLP and CV. This demonstrates that attention innovations in LLMs can be systematically transferred to visual tasks.
  • The spatial-content fusion framework is elegant in design — symmetry is ensured by averaging the gating vectors of two positions, and geometric priors are preserved by multiplying with distance. This "prior × adaptive" paradigm is generalizable to other scenarios requiring structural priors.

Limitations & Future Work

  • Performance gains at the Tiny/Small/Base scales are modest (+0.1–0.3%); whether larger gains emerge at greater scale remains unexplored.
  • FLOPs increase noticeably (e.g., Base scale: 9.7G → 10.8G), necessitating further efficiency optimization.
  • Evaluation is limited to ImageNet-1K classification and DiT generation; downstream dense prediction tasks such as detection and segmentation are absent.
  • The decomposed implementation in the first two stages may sacrifice some global interaction information.
  • The log-sigmoid transformation constrains decay intensity to \((-\infty, 0]\), but whether this monotonic constraint limits model expressiveness remains unexamined.
  • The appropriateness of Manhattan distance as a spatial metric for specific vision tasks (e.g., super-resolution, optical flow) warrants further investigation.
  • Speed comparisons with other efficient attention methods (e.g., FlashAttention, Linear Attention) are lacking.

Comparison with Related Work

  • vs. RMT: RMT employs fixed Manhattan-distance decay; SDT augments this with dynamic content-dependent gating, realizing an adaptive pattern wherein semantically close regions receive stronger attention and semantically distant ones are suppressed.
  • vs. Forgetting Transformer: FOX introduces learnable forgetting gates in 1D sequences; SDT extends this idea to the 2D spatial domain, addressing the challenges of bidirectional dependencies and non-causal relationships.
  • vs. Swin Transformer: Swin introduces locality via window attention; SDT implicitly encodes a locality preference within global attention through spatial decay, offering greater flexibility.

Rating

  • Novelty: ⭐⭐⭐⭐ First adaptation of 1D data-dependent decay to 2D visual tasks, with clear theoretical motivation.
  • Experimental Thoroughness: ⭐⭐⭐ Covers classification and generation tasks, but lacks dense prediction evaluations (detection/segmentation).
  • Writing Quality: ⭐⭐⭐⭐ Mathematical derivations are rigorous; the progression from 1D to 2D is logically clear.
  • Value: ⭐⭐⭐ Modest performance gains, but the direction is instructive and establishes a new paradigm for visual attention design.

Supplementary Notes

  • The learnable spatial decay bias introduced by CAG could plausibly generalize to temporal decay in video understanding and to modality-aware decay in cross-modal attention.
  • A comparison with Mamba-style state space models would be meaningful — both introduce positionally-aware biases, but via different mechanisms.