Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models
Background & Motivation
Traffic simulation in autonomous driving serves as a core tool for evaluating the safety of planning algorithms. High-quality traffic simulation requires generating realistic, diverse, and physically compliant behavioral trajectories for traffic participants.
In recent years, Transformer-based tokenized traffic models (such as SMART) have made significant progress. These methods discretize continuous trajectories into token sequences and use autoregressive language model architectures for trajectory prediction and generation. However, these models face a fundamental issue: distribution shift.
Distribution Shift
In standard Behavior Cloning (BC) training, models are trained on expert trajectories in a teacher-forcing manner: the input at each step is the ground-truth historical state. However, during closed-loop inference:
- Model prediction errors accumulate into inputs for subsequent steps.
- Over time, the input distribution of the model gradually deviates from the training distribution.
- The model lacks the capability to handle out-of-distribution inputs, leading to behavioral degradation.
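The first two bullets can be illustrated with a deliberately simple toy, sketched below under strong assumptions: a hypothetical one-dimensional "model" that predicts the expert's step of +1.0 with a constant bias of 0.05. All names and numbers are made up for illustration and are not from the paper. The toy only shows how errors accumulate in the inputs; it does not reproduce the behavioral degradation described in the third bullet.

```python
def one_step_model(x):
    """Hypothetical learned dynamics: the expert step is +1.0, the model adds a 0.05 bias."""
    return x + 1.0 + 0.05


def rollout_errors(steps, closed_loop):
    """Return the absolute error against the expert state at every step."""
    expert = [float(t) for t in range(steps + 1)]  # expert states 0, 1, 2, ...
    errors = []
    state = expert[0]
    for t in range(steps):
        # Open-loop (teacher forcing): the input is always the expert state.
        # Closed-loop: the input is the model's own previous prediction.
        model_input = state if closed_loop else expert[t]
        state = one_step_model(model_input)
        errors.append(abs(state - expert[t + 1]))
    return errors


open_loop = rollout_errors(20, closed_loop=False)
closed_loop = rollout_errors(20, closed_loop=True)
print(f"open-loop error at final step:   {open_loop[-1]:.2f}")
print(f"closed-loop error at final step: {closed_loop[-1]:.2f}")
```

With teacher forcing the per-step error stays at the bias, while in the closed-loop rollout the same bias is added on top of an already shifted input at every step, so the inputs drift further from anything seen during training.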
| Training Method | Input Source | Distribution Shift | Computational Overhead |
|---|---|---|---|
| Behavior Cloning (Open-loop) | Ground Truth | Severe | Low |
| DAgger | Hybrid | Moderate | High (requires interaction) |
| RL fine-tuning | Model Prediction | Low | Extremely High |
| CLSFT (Ours) | CAT-K rollouts | Low | Moderate |
Existing mitigation schemes (DAgger, RL fine-tuning) either require high computational costs for online interaction or necessitate designing complex reward functions.
This paper proposes Closed-Loop Supervised Fine-Tuning (CLSFT), which brings closed-loop information into the supervised learning framework through a rollout strategy called CAT-K. It mitigates distribution shift at moderate additional cost, without online interaction or a hand-designed reward function.
Method
Tokenized Traffic Model Foundation
The workflow of models like SMART:
1. Quantize the trajectory points \((x, y, \theta)\) of each traffic participant into discrete tokens.
2. Use VQ-VAE to learn the token codebook.
3. Predict the next token using a GPT-style autoregressive model.
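The quantization step can be sketched as a nearest-neighbor lookup in a codebook. The example below is a minimal illustration, not SMART's tokenizer: the four-entry `CODEBOOK` is hand-written rather than learned, the per-step motion is computed in the global frame, and positions and headings are mixed in one Euclidean distance. All of these are simplifying assumptions.

```python
import math

# Hypothetical codebook of per-step motions (dx, dy, dtheta); a real codebook is learned.
CODEBOOK = [
    (0.0, 0.0, 0.0),    # token 0: stay still
    (1.0, 0.0, 0.0),    # token 1: go straight
    (1.0, 0.2, 0.1),    # token 2: slight left
    (1.0, -0.2, -0.1),  # token 3: slight right
]


def tokenize(trajectory):
    """Map consecutive (x, y, theta) points to the index of the nearest codebook motion."""
    tokens = []
    for (x0, y0, th0), (x1, y1, th1) in zip(trajectory, trajectory[1:]):
        delta = (x1 - x0, y1 - y0, th1 - th0)
        nearest = min(range(len(CODEBOOK)), key=lambda i: math.dist(CODEBOOK[i], delta))
        tokens.append(nearest)
    return tokens


trajectory = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.15, 0.08), (2.1, 0.15, 0.08)]
tokens = tokenize(trajectory)
print(tokens)  # [1, 2, 0]
```

The resulting integer sequence is what the autoregressive model in step 3 consumes and predicts.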
Standard training employs cross-entropy loss:
\[\mathcal{L}_{CE} = -\sum_{t} \log p_\theta(z_t^* \mid z_{<t}^*)\]
where \(z_t^*\) is the ground-truth token corresponding to the expert trajectory.
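As a rough illustration of the teacher-forced objective, the sketch below sums the negative log-probability of the expert token at each step. The function names and the toy logits are hypothetical; each row of `step_logits` is assumed to have been produced by the model conditioned on the expert prefix.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


def teacher_forced_nll(step_logits, expert_tokens):
    """Sum over time steps of -log p(expert token), one logits row per step."""
    total = 0.0
    for logits, target in zip(step_logits, expert_tokens):
        total -= log_softmax(logits)[target]
    return total


# Toy example: 2 time steps, vocabulary of 3 tokens.
step_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
expert_tokens = [0, 1]
loss = teacher_forced_nll(step_logits, expert_tokens)
print(f"loss = {loss:.4f}")
```

In practice this would be computed with a deep learning framework's built-in cross-entropy; the pure-Python version is only meant to make the summation explicit.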
CAT-K Rollouts
CAT-K (Closest-Among-Top-K) is the core innovation of CLSFT. During training data generation:
Step 1: Top-K Sampling
At each time step, the model predicts the token distribution \(p_\theta(z_t \mid z_{<t})\) and keeps the \(K\) candidate tokens with the highest probabilities. Denoting this candidate set by \(\mathcal{Z}_t^{K}\):
\[\mathcal{Z}_t^{K} = \operatorname{TopK}_{z}\; p_\theta(z \mid z_{<t})\]
Step 2: Selecting the Token Closest to the Expert
Among the \(K\) candidates, select the one closest to the expert token:
\[z_t^{CAT} = \arg\min_{z \in \mathcal{Z}_t^{K}} d(z, z_t^*)\]
where \(d(\cdot, \cdot)\) is a distance metric in the token space (typically the Euclidean distance of the corresponding trajectory points).
Step 3: Closed-Loop Rollout
Use \(z_t^{CAT}\) as the input for the next step (instead of the ground-truth \(z_t^*\)) to continue generating subsequent tokens.
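The three steps can be put together in a short sketch. It is a simplified illustration rather than the authors' implementation: `policy` is a hypothetical callable that returns a probability list for a given token prefix, and `token_points` is an assumed lookup from token index to a 2D point, so that the distance between tokens is the Euclidean distance between their points.

```python
import math


def cat_k_select(probs, expert_token, token_points, k):
    """Steps 1-2: keep the top-K tokens by probability, return the one closest to the expert."""
    top_k = sorted(range(len(probs)), key=lambda z: probs[z], reverse=True)[:k]
    expert_point = token_points[expert_token]
    return min(top_k, key=lambda z: math.dist(token_points[z], expert_point))


def cat_k_rollout(policy, expert_tokens, token_points, k):
    """Step 3: feed each selected token back as input for the next step."""
    context = []
    for expert_token in expert_tokens:
        probs = policy(context)  # conditioned on the rollout so far, not on the expert
        context.append(cat_k_select(probs, expert_token, token_points, k))
    return context


# Toy setup (all values hypothetical): 4 tokens, a policy that ignores its context.
token_points = [(0.0, 0.0), (1.0, 0.0), (1.0, 0.5), (1.0, -0.5)]
policy = lambda context: [0.1, 0.5, 0.3, 0.1]
expert_tokens = [2, 3, 1]

print(cat_k_rollout(policy, expert_tokens, token_points, k=2))  # [2, 1, 1]
```

In this toy, the second expert token (3) is not among the two most likely candidates, so the rollout picks the nearest allowed token (1) and the next step is conditioned on that slightly deviated prefix. With `k=1` the selection is simply the most likely token, and with `k` equal to the vocabulary size it returns the expert token itself.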
| Method | Input at Step \(t\) | Target at Step \(t\) |
|---|---|---|
| BC (Open-loop) | \(z_{<t}^*\) (Ground Truth) | \(z_t^*\) |
| Pure Closed-loop | \(z_{<t}^{model}\) (Model) | \(z_t^*\) |
| CAT-K | \(z_{<t}^{CAT}\) (Approximate Expert) | \(z_t^*\) |
Design Motivation of CAT-K
Key advantages of CAT-K:
- Controlled Deviation: Trajectories generated by CAT-K are close to but not completely identical to expert trajectories, exposing the model to slightly deviated input distributions during training.
- Stability: Since the token closest to the expert is selected at each step, deviations do not accumulate out of control.
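- Diversity: The choice of \(K\) controls the degree of deviation. \(K=1\) degenerates to greedy decoding, a pure closed-loop rollout with the largest deviation from the expert. As \(K\) approaches the vocabulary size, the closest candidate becomes the expert token itself, so the rollout reduces to open-loop teacher forcing.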
CLSFT Training Process
- Perform CAT-K rollouts on the training set using the pre-trained SMART model.
- Collect (CAT-K trajectory, expert label) pairs.
- Fine-tune using standard cross-entropy loss.
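The data side of this process can be sketched as follows. The example builds (CAT-K prefix, expert label) pairs for one trajectory and evaluates the cross-entropy that fine-tuning would minimize. It assumes a hypothetical `policy` callable and a `token_points` lookup, both made up for illustration, and it omits the gradient update itself, which would be done in a deep learning framework.

```python
import math


def cat_k_pairs(policy, expert_tokens, token_points, k):
    """Return (rollout prefix, expert label) pairs for one trajectory."""
    pairs, prefix = [], []
    for expert_token in expert_tokens:
        # Input is the CAT-K prefix; the target is still the expert token.
        pairs.append((tuple(prefix), expert_token))
        probs = policy(prefix)
        top_k = sorted(range(len(probs)), key=lambda z: probs[z], reverse=True)[:k]
        target_point = token_points[expert_token]
        prefix.append(min(top_k, key=lambda z: math.dist(token_points[z], target_point)))
    return pairs


def fine_tuning_loss(policy, pairs):
    """Mean cross-entropy of the expert labels given the CAT-K prefixes."""
    return -sum(math.log(policy(list(p))[label]) for p, label in pairs) / len(pairs)


# Toy setup (all values hypothetical): 4 tokens, a policy that ignores its context.
token_points = [(0.0, 0.0), (1.0, 0.0), (1.0, 0.5), (1.0, -0.5)]
policy = lambda prefix: [0.1, 0.5, 0.3, 0.1]

pairs = cat_k_pairs(policy, expert_tokens=[2, 3, 1], token_points=token_points, k=2)
print(pairs)  # [((), 2), ((2,), 3), ((2, 1), 1)]
print(f"loss = {fine_tuning_loss(policy, pairs):.4f}")
```

The third pair shows the point of the procedure: the prefix `(2, 1)` differs from the expert prefix `(2, 3)`, yet the label is still the expert's next token, so the model is trained to recover from its own slightly deviated context.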
Key Experimental Results
WOSAC Leaderboard
| Method | Parameters | RMM↑ | Kinematic↑ | Interactive↑ | Map↑ |
|---|---|---|---|---|---|
| SMART-7M | 7M | 0.7302 | 0.821 | 0.683 | 0.712 |
| SMART-102M | 102M | 0.7614 | 0.849 | 0.715 | 0.738 |
| MotionLM | 45M | 0.7489 | 0.837 | 0.701 | 0.729 |
| SMART-7M + CLSFT | 7M | 0.7702 | 0.856 | 0.728 | 0.749 |
SMART-7M + CLSFT outperforms SMART-102M (RMM 0.7702 vs 0.7614) with roughly one-fifteenth of the parameters, and improves on its own base model, SMART-7M, by 0.0400 RMM.
Ablation Study on K
| K Value | RMM↑ | Description |
|---|---|---|
| K=1 (Greedy) | 0.7412 | Pure greedy rollout; largest deviation from the expert |
| K=5 | 0.7589 | Substantial deviation |
| K=10 | 0.7702 | Best score among the tested values |
| K=50 | 0.7645 | Rollouts stay close to the expert; limited exposure to deviation |
| K=∞ (All tokens) | 0.7301 | Closest token is the expert token; equivalent to open-loop BC |
K=10 achieves the highest RMM among the tested values, balancing exposure to deviation against closeness to the expert.
Closed-Loop Simulation Quality
| Metric | SMART (BC) | SMART + DAgger | SMART + CLSFT |
|---|---|---|---|
| Collision Rate↓ | 5.23% | 3.87% | 3.21% |
| Off-road Rate↓ | 2.14% | 1.62% | 1.38% |
| Progress↑ | 0.891 | 0.912 | 0.927 |
| Comfort↑ | 0.834 | 0.856 | 0.871 |
CLSFT outperforms DAgger across all closed-loop simulation metrics, without requiring online interaction.
Summary & Future Work
By selecting the token closest to the expert among the Top-K candidates, CAT-K constructs approximate closed-loop rollouts, giving an efficient closed-loop supervised fine-tuning method for tokenized traffic models. SMART-7M + CLSFT, with only 7M parameters, outperforms the 102M-parameter SMART-102M on the WOSAC leaderboard (RMM 0.7702 vs 0.7614), suggesting that a smaller model can surpass a larger one when the training paradigm addresses distribution shift. The core idea of CLSFT, introducing closed-loop information through controlled deviation, may generalize to other sequence prediction tasks.