
ChA-MAEViT: Unifying Channel-Aware Masked Autoencoders and Multi-Channel Vision Transformers for Improved Cross-Channel Learning

Conference: NeurIPS 2025 · arXiv: 2503.19331 · Code: GitHub · Area: Multi-Channel Image Processing / Remote Sensing / Cell Microscopy · Keywords: Masked Autoencoder, Multi-Channel Imaging, Vision Transformer, Cross-Channel Learning, Self-Supervised Learning

TL;DR

This paper proposes ChA-MAEViT, which enhances cross-channel feature learning for multi-channel images (MCI) through four key components: dynamic channel-patch joint masking, memory tokens, hybrid token fusion, and a channel-aware decoder. The method outperforms the state of the art by 3.0–21.5% across three satellite and microscopy datasets.

Background & Motivation

  • Core Problem: Multi-channel images (MCI), such as multispectral+LiDAR in satellite remote sensing and fluorescence+brightfield in cell microscopy, exhibit variable channel counts and types at both training and inference time, necessitating a single model capable of adapting to diverse channel configurations.
  • Limitations of Prior Work: Existing MCI-MAE methods (e.g., CA-MAE) rely solely on random patch masking, assuming significant redundancy across channels — an assumption that holds for natural RGB images but fails for MCI, where channels are often complementary with minimal feature overlap. Attention analysis reveals that patches primarily attend to their own channels (diagonal pattern), indicating that cross-channel interactions are not effectively learned.
  • Key Challenge: Existing methods fail to model the complex relationships among heterogeneous channels and exhibit insufficient robustness to missing channels.
  • Key Insight: By simultaneously masking channels and patches, the model is compelled to reconstruct missing information from other channels, thereby strengthening cross-channel dependency learning.

Method

Dynamic Channel-Patch Masking (DCP Masking)

Core Idea: The masking strategy is decomposed into random patch masking (fixed ratio \(r_p\), e.g., 75%, with positions sampled independently per channel) and dynamic channel masking (uniformly sampling \(k \sim \mathcal{U}\{0,...,c-1\}\) channels to be fully masked). Two hyperparameters \(p_\text{patch}\) and \(p_\text{channel}\) govern the probability of applying each masking type:

  • \(p_\text{patch}=p_\text{channel}=0\): both masks are merged into a unified mask
  • \(p_\text{patch}=p_\text{channel}=0.5\): the two masks are applied alternately

Unlike Hierarchical Channel Sampling, masked channels serve as supervision signals for reconstruction rather than being discarded, enabling the model to directly learn inter-channel relationships.
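Below is a minimal sketch of DCP Masking, assuming one plausible reading of the two hyperparameters: with probability \(p_\text{patch}\) only patch masking is applied, with probability \(p_\text{channel}\) only channel masking is applied, and otherwise both are applied jointly (so \(p_\text{patch}=p_\text{channel}=0\) always yields the unified mask, while \(0.5/0.5\) alternates between the two). The tensor layout (B, C, N) and all names (`dcp_mask`, `r_p`) are illustrative, not the authors' API.

```python
import torch

def dcp_mask(B: int, C: int, N: int, r_p: float = 0.75,
             p_patch: float = 0.5, p_channel: float = 0.5) -> torch.Tensor:
    """Boolean mask of shape (B, C, N); True marks tokens to reconstruct."""
    u = torch.rand(()).item()
    use_patch = u < p_patch or u >= p_patch + p_channel   # patch-only or joint
    use_channel = u >= p_patch                            # channel-only or joint

    mask = torch.zeros(B, C, N, dtype=torch.bool)
    if use_patch:
        # Random patch masking: fixed ratio r_p, positions sampled
        # independently per channel.
        n_masked = int(r_p * N)
        scores = torch.rand(B, C, N)
        idx = scores.argsort(dim=-1)[..., :n_masked]
        mask.scatter_(-1, idx, True)
    if use_channel:
        # Dynamic channel masking: fully mask k ~ U{0, ..., C-1}
        # randomly chosen channels per sample.
        for b in range(B):
            k = torch.randint(0, C, (1,)).item()
            chans = torch.randperm(C)[:k]
            mask[b, chans, :] = True
    return mask
```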

Memory Tokens

\(l\) learnable memory embeddings (default: 4) are introduced to store global cross-channel information as long-term memory. During training, they aggregate channel features via self-attention; at inference time, they assist in handling missing channels. Attention analysis reveals that different channel types specialize to different memory tokens (e.g., the VH channel attends to token 8, while the Lee-filtered channel attends to token 1).
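A minimal sketch of how such memory tokens might be prepended to the token sequence in a standard ViT pipeline is given below; only the default count of 4 comes from the paper, and the module and variable names are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MemoryTokens(nn.Module):
    def __init__(self, num_tokens: int = 4, dim: int = 768):
        super().__init__()
        # Learnable embeddings that persist across inputs ("long-term memory").
        self.tokens = nn.Parameter(torch.zeros(1, num_tokens, dim))
        nn.init.trunc_normal_(self.tokens, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, L, D) flattened channel-patch tokens. Prepending the memory
        # tokens lets self-attention aggregate global cross-channel context
        # into them during training.
        B = x.shape[0]
        return torch.cat([self.tokens.expand(B, -1, -1), x], dim=1)
```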

Channel-Aware Decoder

A single shared decoder processes tokens from all channels simultaneously (in contrast to CA-MAE's per-channel independent decoders), injecting channel-specific information by adding patch tokens to their corresponding channel tokens. Only 1–2 Transformer blocks are required. The loss function is a weighted combination of pixel-space L2 loss and Fourier-space L1 loss.
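The reconstruction loss can be sketched as below, assuming the weighted pixel-space L2 / Fourier-space L1 combination described above; the weight name `lambda_fft` and its default value are placeholders, since the summary does not give the actual weighting.

```python
import torch
import torch.nn.functional as F

def recon_loss(pred: torch.Tensor, target: torch.Tensor,
               lambda_fft: float = 0.1) -> torch.Tensor:
    # pred, target: (B, C, H, W) reconstructed vs. original channels.
    l2_pix = F.mse_loss(pred, target)                 # pixel-space L2 term
    # Fourier-space L1 term: fft2 operates over the last two dims; the
    # complex difference magnitude is averaged over all frequencies.
    l1_freq = (torch.fft.fft2(pred) - torch.fft.fft2(target)).abs().mean()
    return l2_pix + lambda_fft * l1_freq
```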

Hybrid Token Fusion Module

Learnable queries \(\mathbf{q}_\text{patch}\) perform cross-attention over patch tokens; the result is then element-wise multiplied with the [CLS] token and enhanced via an MLP: \(f_\text{final} = \text{Linear}(\text{GELU}(\text{Linear}(f_\text{fusion})))\).
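A minimal sketch of this fusion step follows, assuming a single learnable query and standard multi-head cross-attention; the shapes, head count, and module names are illustrative assumptions, and the MLP matches the \(\text{Linear}(\text{GELU}(\text{Linear}(\cdot)))\) form above.

```python
import torch
import torch.nn as nn

class HybridTokenFusion(nn.Module):
    def __init__(self, dim: int = 768, num_heads: int = 8):
        super().__init__()
        self.q_patch = nn.Parameter(torch.zeros(1, 1, dim))  # learnable query
        nn.init.trunc_normal_(self.q_patch, std=0.02)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(),
                                 nn.Linear(dim, dim))

    def forward(self, cls_tok: torch.Tensor, patches: torch.Tensor) -> torch.Tensor:
        # cls_tok: (B, D); patches: (B, L, D)
        B = patches.shape[0]
        q = self.q_patch.expand(B, -1, -1)
        pooled, _ = self.attn(q, patches, patches)  # cross-attention pooling
        f_fusion = pooled.squeeze(1) * cls_tok      # element-wise gate with [CLS]
        return self.mlp(f_fusion)                   # f_final
```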

Overall Training Objective

\[\mathcal{L}_\text{final} = (1-\lambda_\text{recon}) \cdot (\mathcal{L}_\text{task} + \lambda_d \cdot \mathcal{L}_d) + \lambda_\text{recon} \cdot \mathcal{L}_\text{recon}\]

where \(\lambda_\text{recon}=0.99\) and \(\lambda_d=0.001\).

Key Experimental Results

Main Results on Three Datasets (Classification / Representation Learning Accuracy)

| Method | CHAMMI Avg | JUMP-CP Full | JUMP-CP Partial | So2Sat Full | So2Sat Partial |
|---|---|---|---|---|---|
| DiChaViT | 69.77 | 69.19 | 57.98 | 63.36 | 47.76 |
| CA-MAE+Sup | 59.15 | 69.54 | 20.93 | 64.21 | 15.75 |
| ChA-MAEViT | 74.63 | 90.73 | 68.05 | 67.44 | 52.11 |

Gains: CHAMMI +5.0%, JUMP-CP Full +21.5%, So2Sat Full +3.0%.

Ablation Study

| Variant | CHAMMI Avg | JUMP-CP Full | So2Sat Full |
|---|---|---|---|
| Full ChA-MAEViT | 74.63 | 90.73 | 67.44 |
| w/o DCP Masking | 70.51 | 88.01 | 64.50 |
| w/o Memory Tokens | 73.62 | 87.81 | 65.18 |
| w/o Channel-Aware Decoder | 72.95 | 87.52 | 65.78 |
| w/o Hybrid Token Fusion | 73.84 | 88.25 | 65.48 |

Removing DCP Masking incurs the largest performance drop (CHAMMI −4.12%, JUMP-CP −2.72%).

Robustness to Missing Channels (JUMP-CP, trained on 8 channels)

| Method | 8 ch | 7 ch | 6 ch | 5 ch | 4 ch |
|---|---|---|---|---|---|
| DiChaViT | 69.19 | 61.91 | 54.49 | 46.35 | 38.00 |
| ChA-MAEViT | 90.73 | 83.36 | 74.55 | 63.46 | 50.85 |

38-Cloud Segmentation Task

| Method | Accuracy | IoU | F1 |
|---|---|---|---|
| DiChaViT | 0.951 | 0.857 | 0.923 |
| ChA-MAEViT | 0.964 | 0.894 | 0.944 |

Highlights & Insights

  1. Attention patterns validate the design motivation: After applying DCP Masking, patch attention shifts from concentrating within the same channel (a diagonal pattern) to spreading evenly across all channels, directly confirming that cross-channel interactions are effectively activated.
  2. Specialization of memory tokens: Different channel types automatically focus on distinct memory tokens (e.g., SAR VH → token 8, optical → token 1), reflecting an implicit division of roles among channels.
  3. A single shared decoder outperforms independent decoders: This design is more scalable (So2Sat has 18 channels) and achieves superior performance.
  4. Self-supervised and supervised learning are complementary: Adding DCP Masking alone to DiChaViT already surpasses all other SSL methods by 0.6–5.6%.

Limitations & Future Work

  1. Limited task coverage: Only classification and binary cloud segmentation are evaluated; other dense prediction tasks such as object detection are not explored.
  2. Computational overhead is not thoroughly analyzed: DCP Masking requires additional mask-sampling logic; its impact on large-scale deployment remains unclear.
  3. Assumptions about inter-channel relationships: For completely unrelated channel combinations (e.g., acoustic + optical), the benefit of cross-channel learning has not been established.
  4. Datasets are biased toward remote sensing and biology: Generalization to other MCI scenarios (e.g., multi-sensor fusion in robotics) has not been validated.
Related Work & Context

  • Evolution of MAE for MCI: from standard MAE's random patch masking → CA-MAE's per-channel independent decoding → this work's joint channel-patch masking with a shared decoder.
  • Channel-adaptive ViTs: ChannelViT and DiChaViT address variable channel counts; this work builds upon them by incorporating a self-supervised objective to strengthen feature learning.
  • Broader Inspiration: The DCP Masking concept can be generalized to multimodal learning (e.g., joint masked pre-training over vision, language, and audio).

Rating

⭐⭐⭐⭐ — The method is systematically designed with four complementary components, and experiments demonstrate substantial improvements across three benchmarks (notably JUMP-CP +21.5%). Attention analysis clearly validates the design motivation. The primary limitation is the relatively niche application domain (MCI).