MIRNet: Integrating Constrained Graph-Based Reasoning with Pre-training for Diagnostic Medical Imaging¶
Conference: AAAI 2026 | arXiv: 2511.10013 | Code: GitHub | Area: Medical Image Analysis / Tongue Diagnosis
Keywords: tongue diagnosis, graph attention network, self-supervised pre-training, clinically constrained optimization, multi-label classification
TL;DR¶
MIRNet is a framework that integrates self-supervised masked autoencoder (MAE) pre-training with constraint-aware graph attention network (GAT) reasoning for multi-label tongue diagnosis. The paper also introduces TongueAtlas-4K, a benchmark dataset of 4,000 images with 22 labels. On this benchmark, MIRNet improves Macro Recall by 77.8% and Macro-F1 by 33.2% over the strongest baseline.
Background & Motivation¶
Medical image diagnosis requires combining fine-grained visual pattern recognition with domain knowledge-driven reasoning, particularly for understanding complex relationships among statistically correlated diagnostic labels and clinical priors. Tongue analysis is a critical component of Traditional Chinese Medicine (TCM) diagnosis — for instance, a "pale tongue" frequently co-occurs with a "white coating" — yet such domain knowledge remains underutilized in existing methods.
Existing tongue diagnosis approaches suffer from four interrelated problems:
Problem 1: Label Scarcity. Annotating professional medical images is costly and time-consuming, severely limiting supervised learning. Prior work such as jiang2022deep used a dataset of 8,676 annotated images that was never publicly released and covered only 7 categories.
Problem 2: Severe Label Imbalance. The prevalence of different symptoms in tongue diagnosis varies enormously — e.g., "white coating" appears in 78.38% of cases while "dark-red tongue" appears in only 2.15% — resulting in extremely poor detection of rare conditions.
Problem 3: Insufficient Label Correlation Modeling. Diagnostic labels exhibit significant statistical co-occurrence patterns (e.g., "thin tongue" and "enlarged tongue" are mutually exclusive), yet existing approaches (e.g., independent ResNet with late fusion, bidirectional RNN modeling) lack systematic modeling of label dependencies.
Problem 4: Absence of Clinical Plausibility Constraints. Predictions may violate medical common sense (e.g., simultaneously predicting mutually exclusive diagnoses), and existing models have no mechanism to prevent such unreasonable outputs.
Core Idea: Design an end-to-end framework that systematically addresses all four challenges: MAE pre-training for label scarcity, GAT for label dependency modeling, constraint-aware optimization for clinical plausibility, and asymmetric loss for label imbalance.
Method¶
Overall Architecture¶
MIRNet adopts an encoder-decoder architecture: a MAE pre-trained ViT encoder extracts image features; a graph constructed from label co-occurrence is processed by a GAT decoder to model inter-label dependencies; and a constraint-aware multi-objective loss function is used for joint training. The overall pipeline is: Image → MAE Encoder → Visual Features → Label Graph + GAT → Graph-Refined Features → Constraint-Aware Optimization → Multi-label Prediction.
Key Designs¶
- **Masked Autoencoder (MAE) Pre-training**
- Function: Learn transferable visual representations on large-scale unannotated tongue images to address label scarcity.
- Mechanism: The input image is divided into \(N\) non-overlapping patches; 75% are randomly masked. The ViT encoder processes visible patches to produce features, and a lightweight decoder reconstructs the masked regions. The loss is the pixel-level MSE over masked patches: \(\mathcal{L}_{\text{MAE}} = \frac{1}{|\mathcal{M}|}\sum_{i \in \mathcal{M}} \|\mathbf{x}_i - \hat{\mathbf{x}}_i\|_2^2\)
- Design Motivation: Pre-training on 15,905 unannotated images enables the encoder to learn anatomical features of the tongue, providing a strong initialization for downstream tasks. The backbone is ViT-Base-Patch16-224 with embed_dim=768, depth=12, and num_heads=12.
- The visual features produced by the pre-trained encoder are used to initialize label node embeddings in the GAT.
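The MAE recipe above (random 75% masking, pixel-level MSE restricted to masked patches) can be sketched in a few lines of numpy. This is a minimal illustration of the loss computation, not the paper's implementation; the function names and tensor layout are assumptions.

```python
import numpy as np

def random_mask(batch, num_patches, ratio=0.75, seed=0):
    """MAE-style masking: mark `ratio` of patches per image as masked (1)."""
    rng = np.random.default_rng(seed)
    num_masked = int(num_patches * ratio)
    mask = np.zeros((batch, num_patches))
    for b in range(batch):
        idx = rng.permutation(num_patches)[:num_masked]
        mask[b, idx] = 1.0
    return mask

def mae_loss(patches, recon, mask):
    """Pixel-level MSE averaged over masked patches only.

    patches: (B, N, D) ground-truth patch pixels
    recon:   (B, N, D) decoder reconstruction
    mask:    (B, N) binary, 1 = masked
    """
    per_patch = ((recon - patches) ** 2).mean(axis=-1)  # (B, N)
    return float((per_patch * mask).sum() / mask.sum())
```

Only masked patches contribute to the loss, so the encoder never gets direct supervision on the patches it actually sees; this is what forces it to learn contextual anatomy rather than identity mappings.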
- **Label Correlation Modeling via Graph Attention Network**
- Function: Propagate information over a graph of diagnostic labels to capture high-order inter-label correlations.
- Mechanism: A label co-occurrence graph is first constructed from the training annotation matrix. The co-occurrence matrix \(\mathbf{M} = \mathbf{Y}^\top\mathbf{Y} - \text{diag}(\mathbf{Y}^\top\mathbf{Y})\) is thresholded dynamically (25th percentile of non-zero co-occurrence values) to form a sparse adjacency matrix. Two-layer GATv2Conv then performs label propagation, aggregating neighbor information via attention coefficients \(\alpha_{ij}\) at each layer. The final prediction concatenates the initial MAE features \(\mathbf{v}_k^{(0)}\) with the graph-refined features \(\mathbf{v}_k^{(L)}\) and passes them through a classification head: \(\hat{y}_k = \sigma(\mathbf{w}_k^\top [\mathbf{v}_k^{(0)} \| \mathbf{v}_k^{(L)}] + b_k)\)
- Design Motivation: This design preserves local visual evidence while injecting context-aware label correlations. Two further enhancements are introduced: rare label augmentation (rescaling attention weights by the log of inverse frequency) and correlation confidence weighting (weighting attention edges by normalized co-occurrence frequency).
- **Constraint-Aware Optimization**
- Function: Incorporate clinical knowledge into training as differentiable constraints to ensure clinically plausible predictions.
- Mechanism: The unified optimization objective is \(\min_{\theta,\phi} \mathcal{L}_{\text{ASL}} + \lambda_1 \mathcal{L}_{\text{constraint}} + \lambda_2 \mathcal{L}_{\text{prior}}\), comprising three components:
- Asymmetric Loss (ASL): Addresses positive-negative imbalance via frequency-based weighting \(\gamma_k = \sqrt{\tau/\mathbb{P}(y_k=1)}\) and asymmetric focusing parameters \(\zeta_+ < \zeta_-\).
- Clinical Knowledge Constraints: Mutual exclusion (\(p_a \cdot p_b\)), co-occurrence (\(|p_a - p_b|\)), and implication (\(p_a \cdot (1-p_b)\)) constraints are encoded as differentiable penalty terms via Lagrangian relaxation, incurring loss only when violated.
- Statistical Prior Regularization: KL divergence \(\text{KL}(q(\mathbf{y}|\mathbf{X}) \| p_{\text{data}}(\mathbf{y}))\) aligns the predicted marginal distribution with empirical class priors.
- Design Motivation: With \(\lambda_1=0.1\) and \(\lambda_2=0.05\), constraints are satisfied naturally through gradient optimization without hard-coded rules.
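The three constraint penalties and the prior regularizer have simple closed forms, sketched below in numpy. Function names, the rule-tuple format, and the \(\tau\) value are illustrative; the asymmetric focusing terms \(\zeta_+, \zeta_-\) of the full ASL are omitted for brevity.

```python
import numpy as np

def mutual_exclusion(pa, pb):
    return pa * pb                 # penalized when both probabilities are high

def co_occurrence(pa, pb):
    return abs(pa - pb)            # penalized when the two probabilities diverge

def implication(pa, pb):
    return pa * (1.0 - pb)         # a implies b: penalized when pa high, pb low

def constraint_loss(p, rules):
    """p: per-label probabilities; rules: list of (kind, i, j) label-index pairs."""
    fns = {"excl": mutual_exclusion, "cooc": co_occurrence, "impl": implication}
    return sum(fns[k](p[i], p[j]) for k, i, j in rules)

def prior_kl(q, p, eps=1e-7):
    """Per-label Bernoulli KL(q || p) between predicted marginals and class priors."""
    q, p = np.clip(q, eps, 1 - eps), np.clip(p, eps, 1 - eps)
    return float(np.sum(q * np.log(q / p) + (1 - q) * np.log((1 - q) / (1 - p))))

def asl_gamma(pos_freq, tau=0.1):
    """Frequency-based weighting gamma_k = sqrt(tau / P(y_k = 1))."""
    return np.sqrt(tau / pos_freq)
```

Each penalty is zero (or near-zero) whenever the constraint is satisfied, so gradients only flow when a prediction actually violates clinical knowledge; the total objective is then `L_ASL + 0.1 * constraint_loss + 0.05 * prior_kl`.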
- **Boosting Ensemble Strategy**
- Function: Further improve classification performance on low-frequency labels.
- Mechanism: A base model is trained on the full dataset; a second model is fine-tuned on augmented data exclusively for classes with F1 < 0.5. Final predictions use the second model's output for the 5 worst-performing labels and the base model's output for the remainder.
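The prediction-merging step of the boosting strategy reduces to column replacement. A small sketch under assumed names (`base_pred` and `boosted_pred` as samples-by-labels probability matrices):

```python
import numpy as np

def worst_k_labels(per_label_f1, k=5):
    """Indices of the k labels with the lowest validation F1."""
    return np.argsort(per_label_f1)[:k]

def merge_predictions(base_pred, boosted_pred, worst_labels):
    """Use the boosted model's outputs for the worst-performing labels
    and the base model's outputs everywhere else."""
    merged = base_pred.copy()
    merged[:, worst_labels] = boosted_pred[:, worst_labels]
    return merged
```

Because only the worst columns are swapped, the ensemble can trade a small drop in Example-F1 (as the main table shows) for large gains on rare labels.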
Loss & Training¶
- AdamW optimizer, base learning rate \(1 \times 10^{-3}\), batch size 200, layer-wise learning rate decay with \(\text{layer\_decay}=0.75\), trained for 200 epochs.
- Training performed on an NVIDIA A800 GPU.
- Dataset split 80/10/10 into train/validation/test sets.
- Image preprocessing includes rigorous color correction, DeepLabV3+ segmentation, and manual refinement.
Key Experimental Results¶
Main Results¶
| Model | Example-F1 | Micro-F1 | Macro-F1 | Macro Recall | Macro PR-AUC |
|---|---|---|---|---|---|
| LGAN | 0.634 | 0.640 | 0.397 | 0.369 | 0.492 |
| Faster R-CNN | 0.651 | 0.662 | 0.381 | 0.339 | 0.493 |
| DenseNet121 | 0.648 | 0.657 | 0.403 | 0.364 | 0.351 |
| C-GMVAE | 0.634 | 0.647 | 0.346 | 0.305 | 0.526 |
| MIRNet | 0.680 | 0.683 | 0.525 | 0.599 | 0.527 |
| MIRNet-Boosting | 0.675 | 0.678 | 0.537 | 0.655 | 0.543 |
Ablation Study¶
| Configuration | Example-F1 Change | Macro-F1 Change | Macro Recall Change | Notes |
|---|---|---|---|---|
| MIRNet-C (w/o constraints) | −3.2% | −4.4% | — | Label consistency degraded |
| MIRNet-G (GAT → MLP) | — | −3.2% | −8.1% | Label dependency modeling is indispensable |
| MIRNet-P (w/o pre-training) | — | −23.0% | −29.0% | Most severe impact; pre-training is critical |
Key Findings¶
- MIRNet-Boosting outperforms the strongest baseline by 33.2% in Macro-F1 and 77.8% in Macro Recall, a remarkable margin.
- Even without Boosting, MIRNet surpasses all baselines (Macro Recall +62.5%, Macro-F1 +30.4%).
- Improvements are especially pronounced on rare labels: dark-red tongue (2.15% prevalence) improves from F1 < 0.25 across all baselines to 0.68; gray-black coating improves from near zero to 0.71.
- Dimension-level missed detections are substantially reduced: tongue shape missed detections drop from up to 209 cases in baselines to 14 in MIRNet-Boosting (a 93.3% reduction).
- Ablation analysis confirms that the three components are complementary: pre-training has the greatest impact on rare classes (Macro Recall −29%), constraints provide the largest overall gain (preventing inconsistent labels), and GAT maintains the precision-recall balance.
- MIRNet performs consistently across all four diagnostic dimensions: tongue color 0.81, tongue shape 0.77, coating texture 0.76, coating color 0.84, compared to baseline averages of 0.59, 0.43, 0.51, and 0.68, respectively.
Highlights & Insights¶
- Introducing the "learning-with-reasoning integration" paradigm — analogous to physical constraints in DRNets/PINNs — into medical image diagnosis represents a valuable cross-domain transfer of ideas.
- The constraint-aware optimization design is elegant: softening hard constraints into differentiable loss terms via \(\max(0, \phi_j)\) preserves the semantic meaning of the constraints without obstructing gradient flow.
- The two auxiliary techniques in the GAT — rare label augmentation and correlation confidence weighting — are simple yet effective and are worth adopting in other multi-label classification settings.
- The release of TongueAtlas-4K (4,000 images, 22 labels, consensus annotations from 10 experts) fills a critical gap in the absence of public benchmarks for tongue analysis.
Limitations & Future Work¶
- Although the paper claims the framework generalizes to other medical imaging tasks, all experiments are conducted exclusively on tongue data; validation on other modalities such as X-ray and CT is absent.
- Constraint rules (mutual exclusion, co-occurrence, implication) must be defined manually, requiring domain experts to redesign them for new clinical settings.
- The granularity of 22 labels may still be insufficient for practical TCM clinical use.
- The threshold for the Boosting strategy (F1 < 0.5) and the number of replaced labels (5) appear to be empirical choices whose robustness has not been thoroughly validated.
- The label graph construction relies on training set statistics and may be unreliable when the training set is small or the label distribution is skewed.
Related Work & Insights¶
- MIRNet shares the same philosophical lineage as DRNets, which embed thermodynamic priors into deep learning for materials discovery; MIRNet analogously embeds TCM clinical priors for tongue diagnosis.
- Compared to LGAN (CNN + bidirectional RNN for label correlation) and IFRCNet (dilated convolutions + attention), MIRNet's graph-based reasoning is more explicit and interpretable.
- The ASL loss function's strategy for handling imbalance in multi-label classification is broadly applicable to other domains.
- The decisive impact of pre-training on performance (−23% Macro-F1 in ablation) further corroborates the importance of self-supervised pre-training in annotation-scarce settings.
Rating¶
- Novelty: ⭐⭐⭐⭐
- Experimental Thoroughness: ⭐⭐⭐⭐
- Writing Quality: ⭐⭐⭐⭐
- Value: ⭐⭐⭐⭐