RestoreGrad: Signal Restoration Using Conditional Denoising Diffusion Models with Jointly Learned Prior
Conference: ICML 2025
arXiv: 2502.13574
Code: To be confirmed
Area: Diffusion Models / Signal Restoration
Keywords: Diffusion Models, Signal Restoration, Learnable Prior, VAE-DDPM Integration, Speech Enhancement, Image Restoration
TL;DR
RestoreGrad jointly learns the prior distribution of a conditional DDPM (instead of fixing it to a standard Gaussian) using a Prior Net and a Posterior Net. By exploiting the correlation between degraded and clean signals to build a more informative prior, it achieves 5-10× faster training convergence and needs 2-2.5× fewer inference steps in speech enhancement and image restoration.
Background & Motivation
Conditional DDPMs have demonstrated outstanding performance in signal restoration (speech enhancement, image restoration), but face two critical bottlenecks:
Slow Convergence: Standard DDPMs use the standard Gaussian \(\mathcal{N}(\mathbf{0}, \mathbf{I})\) as the prior distribution. Because this prior is far from the real data distribution, training needs many iterations to converge.
Many Inference Steps: The same gap between prior and data distribution also forces the reverse process to take more steps to reconstruct the signal.
Limitations of Prior Work: PriorGrad proposed a rule-based method to extract prior information from conditional inputs. However, such handcrafted priors require domain knowledge, and suitable rules cannot be easily designed for all tasks.
Key Insight: In signal restoration tasks, the degraded signal \(\mathbf{y}\) is inherently a corrupted version of the clean signal \(\mathbf{x}_0\), exhibiting a strong correlation. Standard Gaussian priors completely discard this correlated information, resulting in a loss of efficiency.
Method
Core Idea: VAE + DDPM Integration
RestoreGrad seamlessly integrates DDPM into a VAE framework:
- DDPM as Decoder: Leveraging the strong generative capability of diffusion models for signal reconstruction.
- VAE Encoder for Prior Learning: Jointly learning a better latent space through a Prior Net (\(\psi\)) and a Posterior Net (\(\phi\)).
Architecture Design
The framework consists of three learnable modules:
| Module | Parameterization | Input | Function |
|---|---|---|---|
| Conditional DDPM (\(\theta\)) | Noise estimation network | \(\mathbf{x}_t, \mathbf{y}, t\) | Predicts noise, performs reverse denoising |
| Prior Net (\(\psi\)) | Prior encoder | \(\mathbf{y}\) | Provides informative prior during inference |
| Posterior Net (\(\phi\)) | Posterior encoder | \(\mathbf{x}_0, \mathbf{y}\) | Utilizes clean signals to assist learning during training |
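Both the prior and posterior distributions are modeled as zero-mean Gaussians: \(p_\psi(\boldsymbol{\epsilon}|\mathbf{y}) = \mathcal{N}(\mathbf{0}, \boldsymbol{\Sigma}_{\text{prior}})\) and \(q_\phi(\boldsymbol{\epsilon}|\mathbf{x}_0, \mathbf{y}) = \mathcal{N}(\mathbf{0}, \boldsymbol{\Sigma}_{\text{post}})\).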
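Because both distributions are zero-mean, drawing noise from either one reduces to scaling standard normal noise. The sketch below assumes a diagonal covariance parameterized by a per-element log-variance; the summary does not state the covariance structure, and `sample_zero_mean_gaussian` is a hypothetical helper, not the authors' code.

```python
import numpy as np


def sample_zero_mean_gaussian(log_var: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Draw eps ~ N(0, diag(exp(log_var))).

    Hypothetical helper: assumes the Prior/Posterior Net outputs a per-element
    log-variance, which keeps the (diagonal) covariance strictly positive.
    """
    std = np.exp(0.5 * log_var)               # elementwise standard deviation
    z = rng.standard_normal(log_var.shape)    # z ~ N(0, I)
    return std * z                            # eps ~ N(0, diag(std**2))


rng = np.random.default_rng(0)
# Stand-in for a Prior Net output on an 8-sample signal (variances from 0.5 to 2.0).
log_var_prior = np.log(np.linspace(0.5, 2.0, 8))
batch = np.tile(log_var_prior, (20000, 1))    # 20000 independent draws
samples = sample_zero_mean_gaussian(batch, rng)
print(samples.mean(axis=0).round(2))          # close to 0 for every element
print(samples.var(axis=0).round(2))           # close to exp(log_var_prior)
```

The log-variance parameterization is a common choice for learned Gaussian encoders because any real network output maps to a valid variance.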
New ELBO Derivation
By integrating the diffusion process into the VAE framework (Proposition 3.1), a new lower bound on the conditional data log-likelihood is derived (Eq. 9). It consists of two components:
- ELBO of the conditional DDPM (reconstruction term)
- KL divergence between the Prior Net and Posterior Net (prior matching term)
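For two zero-mean Gaussians over a \(d\)-dimensional signal (\(d\) is introduced here only for this derivation), the KL divergence in the prior matching term has a standard closed form. Taking the usual VAE direction (posterior relative to prior) and comparing term by term with the prior matching loss \(\mathcal{L}_{\text{PM}}\) shows that the loss is twice the KL plus a constant, so minimizing one minimizes the other. This is a reading aid, not a reproduction of the paper's Eq. 9.

\[
D_{\mathrm{KL}}\big(\mathcal{N}(\mathbf{0}, \boldsymbol{\Sigma}_{\text{post}}) \,\|\, \mathcal{N}(\mathbf{0}, \boldsymbol{\Sigma}_{\text{prior}})\big)
= \frac{1}{2}\left[\text{tr}(\boldsymbol{\Sigma}^{-1}_{\text{prior}}\boldsymbol{\Sigma}_{\text{post}}) - d + \log\frac{|\boldsymbol{\Sigma}_{\text{prior}}|}{|\boldsymbol{\Sigma}_{\text{post}}|}\right]
\quad\Longrightarrow\quad
\mathcal{L}_{\text{PM}} = 2\,D_{\mathrm{KL}} + d
\]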
Loss & Training
The final joint training objective (Eq. 11) consists of three terms:
- Latent Regularization (LR): \(\mathcal{L}_{\text{LR}} = \bar{\alpha}_T \|\mathbf{x}_0\|^2_{\boldsymbol{\Sigma}^{-1}_{\text{post}}} + \log|\boldsymbol{\Sigma}_{\text{post}}|\). The \(\log|\boldsymbol{\Sigma}_{\text{post}}|\) term penalizes large covariances, preventing the posterior covariance from growing without bound.
- Denoising Matching (DM): \(\mathcal{L}_{\text{DM}} = \|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, \mathbf{y}, t)\|^2_{\boldsymbol{\Sigma}^{-1}_{\text{post}}}\), which trains the DDPM to predict the ground-truth noise, weighted by the inverse posterior covariance.
- Prior Matching (PM): \(\mathcal{L}_{\text{PM}} = \log\frac{|\boldsymbol{\Sigma}_{\text{prior}}|}{|\boldsymbol{\Sigma}_{\text{post}}|} + \text{tr}(\boldsymbol{\Sigma}^{-1}_{\text{prior}}\boldsymbol{\Sigma}_{\text{post}})\), aligning prior and posterior distributions.
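The three terms are combined as \(\eta\,\mathcal{L}_{\text{LR}} + \mathcal{L}_{\text{DM}} + \lambda\,\mathcal{L}_{\text{PM}}\), where \(\eta, \lambda > 0\) are hyperparameters: \(\eta\) controls the regularization strength and \(\lambda\) the prior matching weight.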
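A minimal NumPy sketch of one training-loss evaluation follows. It assumes diagonal covariances (so determinants and traces become sums over elements), assumes the weighting \(\eta\,\mathcal{L}_{\text{LR}} + \mathcal{L}_{\text{DM}} + \lambda\,\mathcal{L}_{\text{PM}}\) (consistent with how \(\eta\) and \(\lambda\) are described), and assumes the forward-diffusion noise is drawn from \(\mathcal{N}(\mathbf{0}, \boldsymbol{\Sigma}_{\text{post}})\), which matches the \(\boldsymbol{\Sigma}^{-1}_{\text{post}}\) weighting in \(\mathcal{L}_{\text{DM}}\). `restoregrad_losses`, `toy_eps_model`, the noise schedule, and the fixed covariances are hypothetical stand-ins for the trained networks.

```python
import numpy as np


def restoregrad_losses(x0, eps, eps_hat, var_post, var_prior, alpha_bar_T, eta, lam):
    """LR, DM, and PM terms for one sample, assuming diagonal covariances.

    var_post / var_prior hold the positive diagonals of Sigma_post / Sigma_prior.
    The weighting eta * LR + DM + lam * PM is an assumption, not the authors' code.
    """
    inv_post = 1.0 / var_post
    # LR: alpha_bar_T * ||x0||^2 in the Sigma_post^{-1} norm, plus log|Sigma_post|
    l_lr = alpha_bar_T * np.sum(inv_post * x0 ** 2) + np.sum(np.log(var_post))
    # DM: squared noise-prediction error in the Sigma_post^{-1} norm
    l_dm = np.sum(inv_post * (eps - eps_hat) ** 2)
    # PM: log(|Sigma_prior| / |Sigma_post|) + tr(Sigma_prior^{-1} Sigma_post)
    l_pm = np.sum(np.log(var_prior) - np.log(var_post)) + np.sum(var_post / var_prior)
    total = eta * l_lr + l_dm + lam * l_pm
    return total, l_lr, l_dm, l_pm


def toy_eps_model(x_t, y, t):
    """Hypothetical stand-in for epsilon_theta(x_t, y, t); always predicts zero noise."""
    return np.zeros_like(x_t)


rng = np.random.default_rng(0)
d = 16
x0 = rng.standard_normal(d)                       # clean signal (toy)
y = x0 + 0.3 * rng.standard_normal(d)             # degraded signal (toy)
var_post = np.full(d, 0.8)                        # pretend Posterior Net output
var_prior = np.full(d, 1.2)                       # pretend Prior Net output
betas = np.linspace(1e-4, 0.05, 50)               # illustrative noise schedule
alpha_bar = np.cumprod(1.0 - betas)

t = rng.integers(len(betas))                      # random diffusion step
eps = np.sqrt(var_post) * rng.standard_normal(d)  # eps ~ N(0, Sigma_post)
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
eps_hat = toy_eps_model(x_t, y, t)
total, l_lr, l_dm, l_pm = restoregrad_losses(
    x0, eps, eps_hat, var_post, var_prior, alpha_bar[-1], eta=1.0, lam=0.5)
print(f"total={total:.3f} LR={l_lr:.3f} DM={l_dm:.3f} PM={l_pm:.3f}")
```

Note the tension the LR term resolves: the \(\boldsymbol{\Sigma}^{-1}_{\text{post}}\)-weighted terms shrink as the posterior variance grows, while \(\log|\boldsymbol{\Sigma}_{\text{post}}|\) grows, so the optimum is a finite covariance.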
Inference Procedure
During inference, the Posterior Net is not used (no clean signal \(\mathbf{x}_0\) is available), so the DDPM draws its noise from the Prior Net instead: \(\boldsymbol{\epsilon} \sim p_\psi(\boldsymbol{\epsilon}|\mathbf{y}) = \mathcal{N}(\mathbf{0}, \boldsymbol{\Sigma}_{\text{prior}})\).
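The sampling loop can be sketched as standard ancestral DDPM sampling in which every Gaussian draw (the starting \(\mathbf{x}_T\) and the per-step noise) comes from the Prior Net's distribution rather than \(\mathcal{N}(\mathbf{0}, \mathbf{I})\). This is a sketch under stated assumptions: a diagonal \(\boldsymbol{\Sigma}_{\text{prior}}\), the common choice \(\sigma_t^2 = \beta_t\), and a hypothetical `toy_eps_model` in place of the trained \(\boldsymbol{\epsilon}_\theta\); the schedule values are illustrative only.

```python
import numpy as np


def sample_with_learned_prior(eps_model, y, var_prior, betas, rng):
    """Ancestral DDPM sampling with all noise drawn from N(0, diag(var_prior)).

    Assumptions: var_prior is the diagonal covariance predicted from y by the
    Prior Net; sigma_t^2 = beta_t; eps_model(x_t, y, t) stands in for epsilon_theta.
    The Posterior Net is not needed here.
    """
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    std_prior = np.sqrt(var_prior)
    x = std_prior * rng.standard_normal(y.shape)       # x_T ~ N(0, Sigma_prior)
    for t in reversed(range(len(betas))):
        eps_hat = eps_model(x, y, t)
        # Mean of p_theta(x_{t-1} | x_t) under the epsilon parameterization
        x = (x - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_hat) / np.sqrt(alphas[t])
        if t > 0:
            # Inject noise from the learned prior instead of N(0, I)
            x = x + np.sqrt(betas[t]) * std_prior * rng.standard_normal(y.shape)
    return x


def toy_eps_model(x_t, y, t):
    """Hypothetical stand-in for a trained noise-prediction network."""
    return np.zeros_like(x_t)


rng = np.random.default_rng(0)
y = np.sin(np.linspace(0, 4 * np.pi, 32))     # toy degraded signal
var_prior = 0.5 + 0.5 * np.abs(y)             # pretend Prior Net output
betas = np.linspace(1e-4, 0.5, 6)             # short 6-step schedule (illustrative)
x0_hat = sample_with_learned_prior(toy_eps_model, y, var_prior, betas, rng)
print(x0_hat.shape)                           # (32,)
```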
Key Experimental Results
Speech Enhancement (VoiceBank+DEMAND)
| Method | Training Epochs | Inference Steps | PESQ↑ | CSIG↑ | CBAK↑ | COVL↑ | SI-SNR↑ |
|---|---|---|---|---|---|---|---|
| CDiffuSE (Baseline) | 445 | 6 | 2.44 | 3.66 | 2.83 | 3.03 | - |
| + PriorGrad | 96 | 6 | 2.42 | 3.67 | 2.93 | 3.03 | 14.21 |
| + RestoreGrad | 96 | 6 | 2.51 | 3.80 | 3.00 | 3.14 | 14.74 |
| + RestoreGrad | 96 | 3 | 2.50 | 3.75 | 2.99 | 3.11 | 14.65 |
Key points: RestoreGrad surpasses the baseline's 445-epoch result after only 96 epochs, and its performance barely degrades when the inference steps are halved from 6 to 3.
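The savings in the table can be recomputed in a few lines; every number below is copied from the VoiceBank+DEMAND table, and only PESQ is shown for brevity.

```python
# Numbers copied from the VoiceBank+DEMAND table above.
baseline = {"epochs": 445, "steps": 6, "pesq": 2.44}   # CDiffuSE
rg_6 = {"epochs": 96, "steps": 6, "pesq": 2.51}        # + RestoreGrad, 6 steps
rg_3 = {"epochs": 96, "steps": 3, "pesq": 2.50}        # + RestoreGrad, 3 steps

epoch_ratio = baseline["epochs"] / rg_6["epochs"]
step_ratio = rg_6["steps"] / rg_3["steps"]
pesq_gain = rg_6["pesq"] - baseline["pesq"]
pesq_cost_of_halving = rg_6["pesq"] - rg_3["pesq"]

print(f"{epoch_ratio:.2f}x fewer training epochs")              # 4.64x fewer training epochs
print(f"{step_ratio:.1f}x fewer inference steps")                # 2.0x fewer inference steps
print(f"PESQ gain over baseline: {pesq_gain:+.2f}")              # PESQ gain over baseline: +0.07
print(f"PESQ lost by halving steps: {pesq_cost_of_halving:.2f}") # PESQ lost by halving steps: 0.01
```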
Image Restoration (AllWeather → Multi-weather Evaluation)
| Method | Snow100K-L PSNR/SSIM | Outdoor-Rain PSNR/SSIM | RainDrop PSNR/SSIM |
|---|---|---|---|
| WeatherDiff (1775 epochs) | 30.09/0.904 | 29.64/0.931 | 30.71/0.931 |
| + RestoreGrad (887 epochs) | 30.82/0.916 | 30.83/0.941 | 31.78/0.939 |
| + RestoreGrad (1551 epochs) | 31.16/0.918 | 30.70/0.942 | 32.26/0.941 |
Key points: With half the training epochs (887 vs. 1775), RestoreGrad outperforms the baseline on every metric across all three datasets; longer training (1551 epochs) brings further gains on most metrics, making the results comparable to the weather-specific DTPM model.
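The PSNR gains over WeatherDiff can be recomputed directly from the table (values copied verbatim, SSIM omitted for brevity):

```python
# PSNR (dB) copied from the multi-weather table above.
psnr = {
    "Snow100K-L":   {"WeatherDiff": 30.09, "RestoreGrad_887": 30.82, "RestoreGrad_1551": 31.16},
    "Outdoor-Rain": {"WeatherDiff": 29.64, "RestoreGrad_887": 30.83, "RestoreGrad_1551": 30.70},
    "RainDrop":     {"WeatherDiff": 30.71, "RestoreGrad_887": 31.78, "RestoreGrad_1551": 32.26},
}

print(f"Training-epoch ratio: {1775 / 887:.2f}")   # Training-epoch ratio: 2.00
gains = {}
for name, row in psnr.items():
    # (gain at 887 epochs, gain at 1551 epochs) relative to WeatherDiff
    gains[name] = (row["RestoreGrad_887"] - row["WeatherDiff"],
                   row["RestoreGrad_1551"] - row["WeatherDiff"])
    print(f"{name}: {gains[name][0]:+.2f} dB at 887 epochs, "
          f"{gains[name][1]:+.2f} dB at 1551 epochs")
```

Outdoor-Rain is the one case where the longer run scores a lower PSNR (30.70 vs. 30.83 dB), although its SSIM still rises slightly.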
Encoder Overhead Analysis
| Encoder Size (# Params) | PESQ↑ | SI-SNR↑ | Latency Overhead | VRAM Overhead |
|---|---|---|---|---|
| Tiny (24K) | 2.48 | 13.74 | 1.9% | 6.5% |
| Base (93K) | 2.51 | 14.74 | 2.2% | 10.3% |
| Large (370K) | 2.54 | 15.01 | 2.6% | 18.2% |
Encoder overhead is minimal (<3% latency, <19% VRAM), and performance steadily improves as the encoder size increases.
Posterior Net Ablation Study
| Configuration | PESQ↑ | COVL↑ | SI-SNR↑ |
|---|---|---|---|
| CDiffuSE Baseline | 2.32 | 2.89 | 11.84 |
| + PriorGrad | 2.42 | 3.03 | 14.21 |
| + RestoreGrad (Full) | 2.51 | 3.14 | 14.74 |
| + RestoreGrad w/o Posterior (η=0) | Diverged | Diverged | Diverged |
| + RestoreGrad w/o Posterior (η=1) | 2.48 | 3.12 | 13.29 |
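The Posterior Net is crucial for stable training and good performance: without it, training diverges at η=0, and even at η=1 SI-SNR drops from 14.74 to 13.29 and PESQ from 2.51 to 2.48.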
Highlights & Insights
- Theoretical Elegance of VAE-DDPM Integration: It seamlessly embeds DDPM into the VAE framework through a new ELBO derivation, inheriting both the generation capability of DDPM and the modeling efficiency of VAE.
- Extremely Low Overhead: The encoder has only 0.3%-2% as many parameters as the DDPM and adds <3% latency, yet it substantially accelerates convergence.
- Cross-modal Versatility: The same framework is effective for both speech (1D waveforms) and images (2D). Visualizations of the covariance learned by the Prior Net show that it automatically captures signal structures.
- Dual Acceleration in Training and Inference: 5-10× training acceleration + 2-2.5× reduction in inference steps, demonstrating outstanding practical value.
- Clever Design of the Posterior Net: Ground-truth information is utilized during training to guide prior learning, and the Posterior Net is discarded during inference, incurring zero additional inference cost.
Limitations & Future Work
- Restricted Prior Formulation: Currently, it assumes a zero-mean Gaussian and only learns the covariance. This might be insufficiently flexible for scenarios with non-zero means or non-Gaussian distributions.
- Evaluation Limited to Signal Restoration Tasks: Broader applications such as conditional generation (e.g., text-to-image, text-to-speech) have not been explored.
- Hyperparameter Sensitivity: Although performance is robust over a wide range of \(\eta\), both \(\eta\) and \(\lambda\) still need tuning, and their optimal values may vary across tasks.
- Simple Encoder Architecture: A fixed ResNet-20 is used, without exploring stronger encoders or adaptive architectures.
- Lack of Comparison with Recent Methods: No comparison has been made with newer paradigms such as consistency models or flow matching.
Related Work & Insights
- PriorGrad (Lee et al., 2022): A pioneer in handcrafted priors; RestoreGrad can be regarded as its learnable upgrade.
- CDiffuSE (Lu et al., 2022): DDPM baseline for speech enhancement.
- WeatherDiff (Özdenizci & Legenstein, 2023): DDPM baseline for multi-weather image restoration.
- DTPM (Ye et al., 2024): Diffusion texture prior model, which requires large-scale pre-training.
Rating
- Novelty: ⭐⭐⭐⭐ — The approach of VAE-DDPM integration to learn priors is novel, accompanied by a complete theoretical derivation.
- Experimental Thoroughness: ⭐⭐⭐⭐ — Verified across modalities with comprehensive ablation studies and OOD generalization tests.
- Writing Quality: ⭐⭐⭐⭐ — The mathematical derivations are clear, and the presentation of figures and tables is rich.
- Value: ⭐⭐⭐⭐ — Highly practical plug-and-play acceleration for existing conditional DDPMs.