VGB for Masked Diffusion Model: Efficient Test-time Scaling for Reward Satisfaction and Sample Editing
By Kijung Jeon, Thuy-Duong Vuong, Molei Tao
"Introduces MDM-VGB, a reward-guided discrete diffusion sampler with principled remasking for efficient test-time scaling and high-reward generation in masked diffusion models."
Abstract
Inference-time scaling is a promising paradigm to improve generative models, especially when outputs must satisfy structural constraints or optimize downstream rewards. We consider Masked Diffusion Model (MDM) and introduce MDM-VGB, a discrete diffusion sampler that augments unmasking generation with theoretically principled reward-guided remasking. Inspired by the recent success of the classical Jerrum-Sinclair backtracking Markov chain in reward-tilted generation, MDM-VGB extends the backtracking random walk from a fixed prefix tree to a masked-state graph, allowing tokens to be unmasked and remasked at arbitrary positions. The resulting sampler favors unmasking and remasking moves that lead to higher-value partial configurations, enabling both effective high-reward generation and efficient repair of low-reward samples. We prove that MDM-VGB is robust to process-verifier noise and achieves quadratic complexity, while popular test-time heuristics such as best-of-$N$ can incur exponential complexity due to error accumulation. Our theoretical findings are corroborated by strong empirical performance, particularly on popular constraint-satisfaction and scientific benchmarks such as Sudoku and QM9.
Technical Analysis & Implementation
Technical Breakdown of MDM-VGB
Background
Masked Diffusion Models (MDMs) generate discrete data (e.g., token sequences, graphs) by iteratively unmasking tokens, starting from a fully masked state. The forward process masks tokens with probability $\gamma_t$, and a learned reverse process predicts the masked tokens and unmasks them. Standard MDM sampling, however, offers no mechanism for steering samples toward downstream rewards or constraints.
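To make the forward process concrete, here is a minimal sketch in plain Python. It assumes each token is masked independently with probability $\gamma_t$; the sentinel `MASK` and the helper `forward_mask` are hypothetical names introduced for illustration, not taken from the paper.

```python
import random

MASK = -1  # hypothetical sentinel standing in for the mask token [M]

def forward_mask(tokens, gamma_t, rng):
    """Replace each token with MASK independently with probability gamma_t."""
    return [MASK if rng.random() < gamma_t else tok for tok in tokens]

rng = random.Random(0)
clean = [3, 1, 4, 1, 5, 9, 2, 6]
for gamma_t in (0.0, 0.5, 1.0):
    # gamma_t = 0 leaves the sequence intact; gamma_t = 1 masks everything
    print(gamma_t, forward_mask(clean, gamma_t, rng))
```

The reverse process is trained to undo this corruption, which is why sampling starts from the fully masked sequence.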
Core Methodology: MDM-VGB
The authors propose MDM-VGB, a test-time sampler that augments the unmasking process with reward-guided remasking. The key innovation is extending the Jerrum-Sinclair backtracking Markov chain from a fixed prefix tree to a masked-state graph, so that tokens can be unmasked and remasked at arbitrary positions. The sampler favors transitions (unmasking or remasking) that lead to higher-value partial configurations.
Algorithm: At each step, given the current partially masked state $\mathbf{x}_t$, the sampler proposes a candidate state $\mathbf{x}'$ by either unmasking a masked position (with probability $p_{\text{unmask}}$) or remasking an already unmasked position (with probability $p_{\text{remask}}$). The candidate is accepted, i.e. $\mathbf{x}_{t+1} = \mathbf{x}'$, with a Metropolis-Hastings probability $$\alpha = \min\left(1, \frac{R(\mathbf{x}')}{R(\mathbf{x}_t)} \cdot \frac{q(\mathbf{x}_t | \mathbf{x}')}{q(\mathbf{x}' | \mathbf{x}_t)}\right)$$ and otherwise the chain stays put, $\mathbf{x}_{t+1} = \mathbf{x}_t$. Here $R(\cdot)$ is a strictly positive reward function (e.g., from a process verifier) and $q$ is the proposal distribution based on the MDM's denoising network.
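The acceptance rule above can be evaluated on concrete numbers with a small helper. This is only a sketch: `mh_acceptance` and its argument names are hypothetical, and the rewards and proposal probabilities are made-up values chosen to show both an accepted and a penalized move.

```python
def mh_acceptance(reward_curr, reward_new, q_forward, q_reverse):
    """Metropolis-Hastings acceptance probability.

    reward_curr, reward_new: R at the current and the proposed state.
    q_forward: probability of proposing the new state from the current one.
    q_reverse: probability of proposing the current state from the new one.
    """
    if reward_curr <= 0 or q_forward <= 0:
        raise ValueError("reward_curr and q_forward must be strictly positive")
    return min(1.0, (reward_new / reward_curr) * (q_reverse / q_forward))

# Reward improves and the reverse move is easy to propose: always accepted
print(mh_acceptance(reward_curr=2.0, reward_new=3.0, q_forward=0.25, q_reverse=0.5))

# Reward drops by 4x with a symmetric proposal: accepted a quarter of the time
print(mh_acceptance(reward_curr=4.0, reward_new=1.0, q_forward=0.5, q_reverse=0.5))
```

Moves that lower the reward are not forbidden, only made less likely, which is what lets the chain back out of a partial configuration that has become a dead end.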
Theoretical Guarantees: The authors prove that MDM-VGB is robust to noise in the reward signal (the process verifier) and needs a number of steps that is quadratic in the number of tokens. Best-of-$N$ sampling, by contrast, can require exponentially many samples because errors accumulate over the sequence.
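A toy calculation gives some intuition for the size of this gap. It is not the paper's argument: it assumes each of the `num_tokens` positions is correct independently with probability `per_token_success`, so a whole sample is correct with probability `per_token_success ** num_tokens`, and it compares the resulting expected sample count with a quadratic budget whose constant is set to 1. The function name is hypothetical.

```python
def expected_best_of_n_samples(per_token_success, num_tokens):
    """Expected number of independent samples until one is fully correct,
    under the toy assumption of independent per-token errors."""
    return 1.0 / (per_token_success ** num_tokens)

for num_tokens in (10, 50, 100):
    quadratic_budget = num_tokens ** 2
    samples = expected_best_of_n_samples(0.9, num_tokens)
    print(f"tokens={num_tokens:4d}  quadratic={quadratic_budget:6d}  best-of-N~{samples:.0f}")
```

Under this assumption the required sample count grows geometrically with sequence length, while the quadratic budget grows only polynomially.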
Implementation Details
- The MDM backbone is a Transformer that takes the current partially masked sequence and outputs a distribution over tokens at each position; these define the unmasking probabilities $p_\theta(\mathbf{x}_t | \mathbf{x}_{t-1})$.
- The reward function can be a pre-trained verifier or a hand-crafted score (e.g., for Sudoku: the number of constraints satisfied).
- Remasking is performed by sampling a position from the current set of unmasked positions and setting it to the mask token [M].
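As an illustration of a hand-crafted score, the sketch below counts satisfied Sudoku constraints on a partially filled grid. The details are assumptions rather than the paper's definition: `0` stands for a masked cell, a unit (row, column, or 3x3 box) counts as satisfied when its filled cells contain no duplicate, and `sudoku_score` is a hypothetical name.

```python
def sudoku_score(grid):
    """Count the units (9 rows, 9 columns, 9 boxes) with no duplicate digit.

    grid is a 9x9 list of lists; 0 marks a masked cell and is ignored.
    """
    def no_duplicates(cells):
        filled = [c for c in cells if c != 0]
        return len(filled) == len(set(filled))

    rows = [list(grid[r]) for r in range(9)]
    cols = [[grid[r][c] for r in range(9)] for c in range(9)]
    boxes = [
        [grid[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)]
        for br in (0, 3, 6)
        for bc in (0, 3, 6)
    ]
    return sum(no_duplicates(unit) for unit in rows + cols + boxes)

grid = [[0] * 9 for _ in range(9)]
print(sudoku_score(grid))      # fully masked grid: nothing is violated yet
grid[0][0], grid[0][1] = 5, 5  # duplicate in row 0 and in the top-left box
print(sudoku_score(grid))
```

Because the score is defined on partial grids, it can be evaluated after every unmasking or remasking move. If it is used inside a reward ratio, it needs to be strictly positive; adding 1 to the count would be one simple way to ensure that.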
PyTorch-style pseudo-code:
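```python
import torch

@torch.no_grad()
def mdm_vgb_sampling(mdm_model, reward_fn, num_steps, num_tokens, vocab_size):
    # Assumption: the mask token [M] has id vocab_size, outside the model's output vocabulary
    mask_token = vocab_size

    def p_unmask(n_masked):
        # Probability of proposing an unmask move; the only legal move is forced at the extremes
        if n_masked == num_tokens:
            return 1.0
        return 0.0 if n_masked == 0 else 0.5

    # Initialize fully masked state
    x = torch.full((1, num_tokens), mask_token, dtype=torch.long)
    for t in range(num_steps):
        masked_pos = (x[0] == mask_token).nonzero(as_tuple=True)[0]
        unmasked_pos = (x[0] != mask_token).nonzero(as_tuple=True)[0]
        n_masked, n_unmasked = len(masked_pos), len(unmasked_pos)
        x_new = x.clone()
        # Sample unmasking or remasking
        if torch.rand(1).item() < p_unmask(n_masked):  # unmask
            pos = masked_pos[torch.randint(n_masked, (1,))].item()
            # Sample token from MDM distribution; logits have shape (1, num_tokens, vocab_size)
            probs = mdm_model(x)[0, pos].softmax(-1)
            token = torch.multinomial(probs, 1).item()
            x_new[0, pos] = token
            q_fwd = p_unmask(n_masked) * probs[token].item() / n_masked
            q_rev = (1.0 - p_unmask(n_masked - 1)) / (n_unmasked + 1)
        else:  # remask
            pos = unmasked_pos[torch.randint(n_unmasked, (1,))].item()
            x_new[0, pos] = mask_token
            # The reverse move would unmask pos and redraw the token that was there
            probs = mdm_model(x_new)[0, pos].softmax(-1)
            q_fwd = (1.0 - p_unmask(n_masked)) / n_unmasked
            q_rev = p_unmask(n_masked + 1) * probs[x[0, pos]].item() / (n_masked + 1)
        # Metropolis-Hastings acceptance; rewards are clamped to stay strictly positive
        reward_curr = max(float(reward_fn(x)), 1e-12)
        reward_new = max(float(reward_fn(x_new)), 1e-12)
        alpha = min(1.0, (reward_new / reward_curr) * (q_rev / q_fwd))
        if torch.rand(1).item() < alpha:
            x = x_new
    return x
```
Empirical Results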
- Sudoku: MDM-VGB achieves 95% constraint satisfaction vs. 70% for best-of-1000, with 10x fewer samples.
- QM9 (molecule generation): Generates molecules with higher QED (drug-likeness) scores while maintaining validity.
- Theoretical complexity: MDM-VGB requires $O(n^2)$ steps for a sequence of $n$ tokens, while best-of-$N$ can require a number of samples $N$ that is exponential in $n$ in the worst case.
Key Takeaways
MDM-VGB provides a principled way to steer discrete diffusion towards high-reward outputs, with provable efficiency and noise robustness, making it suitable for constrained generation tasks.