Model Predictive Control with Differentiable World Models for Offline Reinforcement Learning
Offline Reinforcement Learning (RL) aims to learn optimal policies from fixed offline datasets, without further interaction with the environment. Existing methods train an offline policy (or value function) and apply it at inference time without further refinement. We introduce an inference-time adaptation framework inspired by model predictive control (MPC) that utilizes a pretrained policy along with a learned world model of state transitions and rewards. While existing world-model and diffusion-planning methods use learned dynamics to generate imagined trajectories during training, or to sample candidate plans at inference time, they do not use inference-time information to optimize the policy parameters on the fly. In contrast, our design is a Differentiable World Model (DWM) pipeline that enables end-to-end gradient computation through imagined rollouts for policy optimization at inference time based on MPC. We evaluate our algorithm on D4RL continuous-control benchmarks (MuJoCo locomotion tasks and AntMaze), and show that exploiting inference-time information to optimize the policy parameters yields consistent gains over strong offline RL baselines.
💡 Research Summary
Offline reinforcement learning (Offline RL) seeks to learn high‑performing policies from a fixed dataset without any further interaction with the environment. Traditional Offline RL pipelines train a policy (or a value function) once on the dataset and then deploy it unchanged at test time. This static deployment suffers from distribution shift: the states and actions encountered during execution may lie outside the support of the training data, leading to inaccurate Q‑value estimates and sub‑optimal actions.
The paper proposes a fundamentally different approach that leverages a differentiable world model (DWM) together with model‑predictive control (MPC) to perform inference‑time adaptation of the policy. The DWM consists of three differentiable components trained on the offline dataset:
- Conditional diffusion transition sampler fθ – a diffusion‑based generative model that, given a state‑action pair (sₜ, aₜ) and a random noise vector εₜ, produces a next‑state sample sₜ₊₁ = fθ(sₜ, aₜ, εₜ). The diffusion process is trained to model the conditional distribution pθ(sₜ₊₁|sₜ, aₜ) and is fully re‑parameterizable, making the sampling operation differentiable with respect to its conditioning inputs.
- Reward model rξ – a neural network that predicts the immediate reward rₜ for any (sₜ, aₜ). It is trained with a standard regression loss on the offline transitions.
- Terminal value function Qϕ – the critic associated with a pretrained policy πψ. This provides an estimate of the long‑horizon return from the final imagined state of a rollout.
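The three components can be sketched with toy linear stand-ins to make the interfaces concrete. This is a minimal illustration, not the paper's implementation: the real pipeline uses a conditional diffusion model, a reward network, and a pretrained critic, and all dimensions and weight initializations below are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy dimensions; the paper's components are neural networks.
S_DIM, A_DIM = 4, 2
W_mu = rng.normal(size=(S_DIM, S_DIM + A_DIM)) * 0.1  # mean head of transition model
w_r = rng.normal(size=(S_DIM + A_DIM,)) * 0.1          # reward model r_xi
w_q = rng.normal(size=(S_DIM,)) * 0.1                  # terminal value Q_phi

def f_theta(s, a, eps, sigma=0.05):
    """Reparameterized transition sample: s_next = mu(s, a) + sigma * eps.
    Because eps is drawn independently of (s, a), the sample is a
    deterministic, differentiable function of its conditioning inputs."""
    return W_mu @ np.concatenate([s, a]) + sigma * eps

def r_xi(s, a):
    """Immediate-reward prediction for the pair (s, a)."""
    return float(w_r @ np.concatenate([s, a]))

def q_phi(s):
    """Long-horizon return estimate at the final imagined state."""
    return float(w_q @ s)
```

The reparameterization in `f_theta` is the key property: holding εₜ fixed, the imagined next state varies smoothly with (sₜ, aₜ), which is what lets gradients flow through multi-step rollouts later.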
During deployment, the current environment state sₜ is fed to the pretrained policy πψ to generate an action. However, instead of executing this action directly, the algorithm performs the following MPC‑style loop:
- Imagined rollouts – For a horizon H (typically 3–5 steps) and a number of parallel samples N (e.g., 10), the algorithm alternates between the policy πψ and the diffusion sampler fθ to generate N imagined trajectories. At each imagined step h, the policy proposes an action âₕ = πψ(ŝₕ) and the diffusion model produces the next imagined state ŝₕ₊₁ = fθ(ŝₕ, âₕ, εₕ).
- Surrogate return computation – For each imagined trajectory, the immediate rewards are estimated by r̂ₕ = rξ(ŝₕ, âₕ). The surrogate return G is the discounted sum of these imagined rewards plus the discounted terminal value from Qϕ: G = Σₕ₌₀^{H‑1} γ^{h} r̂ₕ + γ^{H} Qϕ(ŝ_H).
- Policy gradient update – The surrogate return serves as a differentiable objective. Because the entire imagined rollout (policy → diffusion → reward → value) is a computation graph, gradients ∂G/∂ψ can be back‑propagated through the policy Jacobians and the diffusion Jacobians. The authors formalize this in Theorem 4.1, showing that the gradient can be expressed recursively via the chain rule. A few gradient steps are taken on ψ, yielding an adapted policy πψ^{new}.
- Execution – The first action of the adapted policy, aₜ = πψ^{new}(sₜ), is executed in the real environment. The process repeats at the next time step with the newly observed state.
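The loop above can be sketched end to end with the same toy linear stand-ins. The sketch uses central finite differences in place of the exact backpropagation through the computation graph that the paper performs; the horizon, sample count, step size, and all weights are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)

# Assumed toy settings: state/action dims, horizon H, N rollouts, discount, step size.
S_DIM, A_DIM, H, N, GAMMA, LR = 4, 2, 3, 10, 0.99, 1e-2
W_mu = rng.normal(size=(S_DIM, S_DIM + A_DIM)) * 0.1   # transition model mean
w_r = rng.normal(size=(S_DIM + A_DIM,)) * 0.1           # reward model r_xi
w_q = rng.normal(size=(S_DIM,)) * 0.1                   # terminal value Q_phi
psi0 = rng.normal(size=(A_DIM, S_DIM)) * 0.1            # linear policy: a = psi @ s

def surrogate_return(psi, s0, noises):
    """G = sum_h gamma^h r_hat_h + gamma^H Q(s_H), averaged over N rollouts."""
    total = 0.0
    for n in range(N):
        s, g = s0, 0.0
        for h in range(H):
            a = psi @ s
            g += GAMMA ** h * float(w_r @ np.concatenate([s, a]))
            s = W_mu @ np.concatenate([s, a]) + 0.05 * noises[n, h]  # f_theta sample
        total += g + GAMMA ** H * float(w_q @ s)
    return total / N

def adapt(psi, s0, noises, steps=3, fd=1e-4):
    """A few ascent steps on G; finite differences stand in for autodiff."""
    for _ in range(steps):
        grad = np.zeros_like(psi)
        for i in range(psi.size):
            d = np.zeros(psi.size)
            d[i] = fd
            d = d.reshape(psi.shape)
            grad.flat[i] = (surrogate_return(psi + d, s0, noises)
                            - surrogate_return(psi - d, s0, noises)) / (2 * fd)
        psi = psi + LR * grad  # ascend the imagined return
    return psi

s_t = rng.normal(size=S_DIM)
eps = rng.normal(size=(N, H, S_DIM))   # fixed noise: common random numbers across steps
psi_new = adapt(psi0, s_t, eps)
a_t = psi_new @ s_t                    # execute only the first adapted action
```

Fixing the noise tensor across the inner gradient steps (common random numbers) keeps the objective deterministic during adaptation, mirroring how the reparameterized sampler makes each rollout a fixed differentiable function of ψ.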
Thus, the method continuously refines the policy on‑the‑fly, using the learned world model to “look ahead” and adjust the policy parameters to maximize the imagined return.
Key technical contributions
- Differentiable world model pipeline – By integrating a conditional diffusion transition model, a reward predictor, and a pretrained critic, the authors construct a fully differentiable simulator that can be used for gradient‑based policy adaptation.
- Inference‑time MPC with policy updates – Unlike prior world‑model offline RL methods that either use the model only during training or sample candidate plans at test time without updating the policy, this work back‑propagates through imagined rollouts to directly improve the policy parameters at deployment.
- Analytical gradient derivation – Theorem 4.1 provides a closed‑form recursive expression for the gradient of the surrogate return with respect to policy parameters, explicitly involving the Jacobians of both the policy and the diffusion transition model.
- Empirical validation – Experiments on the D4RL benchmark suite (18 MuJoCo locomotion datasets and 6 AntMaze datasets) show consistent improvements over strong baselines such as CQL, IQL, Decision Diffuser, and Flow Q‑learning. Gains are especially pronounced on the AntMaze tasks, where data is sparse and the optimal behavior requires long‑horizon planning.
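The recursive chain-rule structure behind the gradient can be written out explicitly. The following is a plausible expansion consistent with the surrogate return G = Σₕ γʰ r̂ₕ + γᴴ Qϕ(ŝ_H) defined earlier, not the paper's verbatim statement of Theorem 4.1:

```latex
% Gradient of the surrogate return w.r.t. the policy parameters (sketch)
\nabla_\psi G
  = \sum_{h=0}^{H-1} \gamma^{h} \left(
      \frac{\partial r_\xi}{\partial \hat{s}_h} \frac{\partial \hat{s}_h}{\partial \psi}
    + \frac{\partial r_\xi}{\partial \hat{a}_h} \frac{\mathrm{d}\hat{a}_h}{\mathrm{d}\psi} \right)
  + \gamma^{H} \frac{\partial Q_\phi}{\partial \hat{s}_H} \frac{\partial \hat{s}_H}{\partial \psi},
\qquad
\frac{\partial \hat{s}_0}{\partial \psi} = 0,
% with state sensitivities propagated through the diffusion and policy Jacobians:
\frac{\partial \hat{s}_{h+1}}{\partial \psi}
  = \frac{\partial f_\theta}{\partial \hat{s}_h} \frac{\partial \hat{s}_h}{\partial \psi}
  + \frac{\partial f_\theta}{\partial \hat{a}_h} \frac{\mathrm{d}\hat{a}_h}{\mathrm{d}\psi},
\qquad
\frac{\mathrm{d}\hat{a}_h}{\mathrm{d}\psi}
  = \frac{\partial \pi_\psi}{\partial \hat{s}_h} \frac{\partial \hat{s}_h}{\partial \psi}
  + \frac{\partial \pi_\psi}{\partial \psi}.
```

The recursion makes visible why differentiability of the sampler matters: each state sensitivity ∂ŝₕ₊₁/∂ψ multiplies through the Jacobians of both fθ and πψ, so a non-reparameterizable transition model would break the chain.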
Strengths
- Local modeling of dynamics and rewards reduces reliance on accurate long‑term Q‑value estimates, which are notoriously difficult in offline settings.
- Diffusion‑based transition modeling captures multimodal dynamics and provides a smooth, differentiable sampling procedure.
- Inference‑time adaptation leverages the specific state encountered at test time, allowing the policy to specialize on the current context rather than being a one‑size‑fits‑all solution.
Limitations
- Training a high‑quality diffusion model demands substantial data and compute; performance may degrade when the offline dataset is very small or highly noisy.
- The inference loop requires multiple imagined rollouts and gradient steps, increasing latency and computational load, which could be problematic for real‑time control without hardware acceleration.
- The method assumes the learned world model is sufficiently accurate; model bias can lead to misguided policy updates. The authors note that adding a KL‑penalty or uncertainty‑aware weighting can mitigate this but further research is needed.
Future directions
- Incorporating model uncertainty (e.g., ensembles or Bayesian diffusion) to weight imagined trajectories based on confidence.
- Exploring lightweight diffusion architectures or hybrid models (e.g., dynamics ensembles + diffusion) to reduce inference latency.
- Extending the framework to high‑dimensional observations (images, lidar) by coupling visual encoders with the diffusion transition model.
- Applying the approach to safety‑critical domains such as autonomous driving or medical decision making, where offline data is abundant but online exploration is prohibited.
In summary, the paper introduces a novel paradigm for offline RL: model‑based, inference‑time policy adaptation using a differentiable world model and MPC. By turning the offline dataset into a reusable, gradient‑friendly simulator, the method bridges the gap between static offline policies and the dynamic, context‑aware decision making required in real‑world control tasks. The empirical results demonstrate that this approach consistently outperforms existing offline RL baselines, marking a significant step toward practical, safe, and adaptable offline reinforcement learning.