Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models

Background & Motivation

Traffic simulation in autonomous driving serves as a core tool for evaluating the safety of planning algorithms. High-quality traffic simulation requires generating realistic, diverse, and physically compliant behavioral trajectories for traffic participants.

In recent years, Transformer-based tokenized traffic models such as SMART have made significant progress. These methods discretize continuous trajectories into token sequences and use autoregressive language-model architectures to predict and generate trajectories. However, they face a fundamental issue: distribution shift.

Distribution Shift

In standard Behavior Cloning (BC) training, models are trained on expert trajectories in a teacher-forcing manner: the input at each step is the ground-truth historical state. However, during closed-loop inference:

  1. Model prediction errors accumulate into inputs for subsequent steps.
  2. Over time, the input distribution of the model gradually deviates from the training distribution.
  3. The model lacks the capability to handle out-of-distribution inputs, leading to behavioral degradation.

| Training Method | Input Source | Distribution Shift | Computational Overhead |
|---|---|---|---|
| Behavior Cloning (Open-loop) | Ground Truth | Severe | Low |
| DAgger | Hybrid | Moderate | High (requires interaction) |
| RL fine-tuning | Model Prediction | Low | Extremely High |
| CLSFT (Ours) | CAT-K rollouts | Low | Moderate |

Existing mitigation schemes (DAgger, RL fine-tuning) either require high computational costs for online interaction or necessitate designing complex reward functions.

This paper proposes Closed-Loop Supervised Fine-Tuning (CLSFT), which brings closed-loop information into the supervised learning framework through the CAT-K rollout strategy, mitigating distribution shift at moderate extra cost and without online interaction or reward design.

Method

Tokenized Traffic Model Foundation

The workflow of models like SMART:

  1. Quantize the trajectory points \((x, y, \theta)\) of each traffic participant into discrete tokens.
  2. Use VQ-VAE to learn the token codebook.
  3. Predict the next token using a GPT-style autoregressive model.
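The quantization step can be illustrated with a minimal sketch. It assumes that each token stands for one codebook entry \((x, y, \theta)\) and that a point is mapped to its nearest entry. The codebook values, the heading weight, and the names `CODEBOOK`, `point_distance`, and `quantize` are hypothetical and are not taken from SMART.

```python
import math

# Hypothetical codebook: token id -> (x, y, heading). Values are illustrative only.
CODEBOOK = [
    (0.0, 0.0, 0.0),    # 0: stay
    (1.0, 0.0, 0.0),    # 1: straight
    (1.0, 0.2, 0.2),    # 2: slight left
    (1.0, -0.2, -0.2),  # 3: slight right
]

def angle_diff(a, b):
    """Smallest signed difference between two angles, in radians."""
    return math.atan2(math.sin(a - b), math.cos(a - b))

def point_distance(p, q, heading_weight=1.0):
    """Euclidean position distance plus a weighted absolute heading difference."""
    position = math.hypot(p[0] - q[0], p[1] - q[1])
    return position + heading_weight * abs(angle_diff(p[2], q[2]))

def quantize(point, codebook=CODEBOOK):
    """Return the id of the codebook entry closest to `point`."""
    return min(range(len(codebook)), key=lambda i: point_distance(point, codebook[i]))

# Three continuous points become three discrete tokens.
points = [(0.9, 0.05, 0.0), (1.0, 0.25, 0.15), (0.1, 0.0, 0.0)]
tokens = [quantize(p) for p in points]
print(tokens)
```

A learned codebook would replace the hand-written one here; the nearest-entry lookup is the part the sketch is meant to show.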

Standard training employs cross-entropy loss:

\[\mathcal{L}_{BC} = -\sum_t \log p_\theta(z_t^* | z_{<t}^*)\]

where \(z_t^*\) is the ground-truth token corresponding to the expert trajectory.
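As a minimal sketch of this loss, the snippet below evaluates \(\mathcal{L}_{BC}\) from per-step logits. The logits are made-up numbers standing in for the model output under teacher forcing, and the function names are illustrative rather than part of any SMART codebase.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def bc_loss(logits_per_step, expert_tokens):
    """Sum over t of -log p(z_t* | z_<t*), given the logits for each step."""
    return -sum(
        log_softmax(logits)[z_star]
        for logits, z_star in zip(logits_per_step, expert_tokens)
    )

# Toy example: 3 steps, vocabulary of 4 tokens. Under teacher forcing, the
# logits at step t would be computed from the ground-truth prefix z_<t*.
logits_per_step = [
    [2.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],   # uniform prediction: contributes log(4)
    [0.0, 1.0, 3.0, 0.0],
]
expert_tokens = [0, 1, 2]
loss = bc_loss(logits_per_step, expert_tokens)
print(f"BC loss: {loss:.4f}")
```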

CAT-K Rollouts

CAT-K (Closest-Among-Top-K) is the core innovation of CLSFT. During training data generation:

Step 1: Top-K Sampling

At each time step, the model predicts the probability distribution of tokens \(p_\theta(z_t | z_{<t})\), selecting the \(K\) candidate tokens with the highest probabilities:

\[\text{Top-K}(p_\theta) = \{z_t^{(1)}, z_t^{(2)}, ..., z_t^{(K)}\}\]

Step 2: Selecting the Token Closest to the Expert

Among the \(K\) candidates, select the one closest to the expert token:

\[z_t^{CAT} = \arg\min_{z \in \text{Top-K}(p_\theta)} d(z, z_t^*)\]

where \(d(\cdot, \cdot)\) is a distance metric in the token space (typically the Euclidean distance between the corresponding trajectory points).
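Steps 1 and 2 can be sketched for a single time step as follows. The sketch assumes each token maps to one \((x, y)\) point and uses Euclidean distance for \(d\); the vocabulary, the probabilities, and the name `cat_k_select` are hypothetical.

```python
import math

def cat_k_select(probs, expert_token, token_xy, k):
    """Closest-Among-Top-K: among the k most probable tokens, return the one
    whose (x, y) point is nearest to the expert token's point."""
    top_k = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    ex, ey = token_xy[expert_token]
    return min(top_k, key=lambda i: math.hypot(token_xy[i][0] - ex, token_xy[i][1] - ey))

# Hypothetical 5-token vocabulary with one (x, y) point per token.
token_xy = [(0.0, 0.0), (1.0, 0.0), (1.0, 0.5), (1.0, -0.5), (2.0, 0.0)]
probs = [0.30, 0.05, 0.10, 0.15, 0.40]   # model distribution at this step
expert = 2                               # ground-truth token z_t*

chosen_k1 = cat_k_select(probs, expert, token_xy, k=1)   # candidates: {4}
chosen_k3 = cat_k_select(probs, expert, token_xy, k=3)   # candidates: {4, 0, 3}
chosen_all = cat_k_select(probs, expert, token_xy, k=5)  # candidates: all tokens
print(chosen_k1, chosen_k3, chosen_all)
```

With `k=1` the function returns the most probable token regardless of the expert. With `k` equal to the vocabulary size the expert token is always among the candidates and has distance zero, so it is returned.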

Step 3: Closed-Loop Rollout

Use \(z_t^{CAT}\) as the input for the next step (instead of the ground-truth \(z_t^*\)) to continue generating subsequent tokens.
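The three steps together give a rollout loop, sketched below under simplifying assumptions: tokens live on a 1-D axis, and `toy_model` is a hypothetical stand-in for the pre-trained model that only looks at the last token. None of these names come from the paper.

```python
def closest_among_top_k(probs, expert_token, token_pos, k):
    """CAT-K selection on a 1-D token space (token_pos[i] is a scalar)."""
    top_k = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return min(top_k, key=lambda i: abs(token_pos[i] - token_pos[expert_token]))

def cat_k_rollout(model, expert_tokens, token_pos, k):
    """Roll the model forward, feeding back the CAT-K token at each step.

    Returns the rolled-out tokens and a list of (context, target) pairs, where
    the context is the CAT-K prefix and the target is the expert token.
    """
    prefix, pairs = [], []
    for z_star in expert_tokens:
        probs = model(prefix)                  # p(z_t | CAT-K prefix)
        pairs.append((list(prefix), z_star))   # label stays the expert token
        prefix.append(closest_among_top_k(probs, z_star, token_pos, k))
    return prefix, pairs

def toy_model(prefix):
    """Hypothetical model: strongly prefers repeating the last token."""
    if not prefix:
        return [0.10, 0.40, 0.30, 0.20]
    probs = [0.20, 0.15, 0.10, 0.05]
    probs[prefix[-1]] += 0.50
    return probs

token_pos = [0.0, 1.0, 2.0, 3.0]   # hypothetical 1-D position per token
expert_tokens = [3, 3, 2]
rollout, pairs = cat_k_rollout(toy_model, expert_tokens, token_pos, k=2)
print(rollout)   # tokens actually fed back as inputs
print(pairs)     # training pairs: deviated context, expert label
```

In this toy run the expert token 3 is not among the top-2 candidates at the first step, so the rollout feeds back the nearest candidate instead, and the later contexts differ from the expert prefix while the labels do not.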

| Method | Input at Step \(t\) | Target at Step \(t\) |
|---|---|---|
| BC (Open-loop) | \(z_{<t}^*\) (Ground Truth) | \(z_t^*\) |
| Pure Closed-loop | \(z_{<t}^{model}\) (Model) | \(z_t^*\) |
| CAT-K | \(z_{<t}^{CAT}\) (Approximate Expert) | \(z_t^*\) |

Design Motivation of CAT-K

Key advantages of CAT-K:

  1. Controlled Deviation: Trajectories generated by CAT-K are close to but not completely identical to expert trajectories, exposing the model to slightly deviated input distributions during training.
  2. Stability: Since the token closest to the expert is selected at each step, deviations do not accumulate out of control.
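  3. Diversity: The choice of \(K\) controls the degree of deviation. \(K=1\) reduces to greedy decoding, a pure model rollout with the largest deviation from the expert. When \(K\) equals the vocabulary size, the expert token is always among the candidates and is selected, so the rollout reduces to teacher forcing.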

CLSFT Training Process

  1. Perform CAT-K rollouts on the training set using the pre-trained SMART model.
  2. Collect (CAT-K trajectory, expert label) pairs.
  3. Fine-tune using standard cross-entropy loss:

\[\mathcal{L}_{CLSFT} = -\sum_t \log p_\theta(z_t^* | z_{<t}^{CAT})\]
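A minimal sketch of the fine-tuning step, assuming a drastically simplified model: a table of logits indexed by the last token of the CAT-K context instead of a Transformer over the full prefix. The training pairs, learning rate, and function names are hypothetical; only the loss form follows the equation above.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def clsft_loss(logit_table, pairs):
    """-sum log p(z_t* | context), with the context reduced to its last token."""
    return -sum(math.log(softmax(logit_table[ctx])[target]) for ctx, target in pairs)

def sgd_step(logit_table, pairs, lr=0.5):
    """One pass of per-example gradient descent on the cross-entropy loss.

    For one example, d(-log p[target]) / d(logits) = softmax(logits) - one_hot(target).
    """
    for ctx, target in pairs:
        p = softmax(logit_table[ctx])
        for j in range(len(p)):
            grad = p[j] - (1.0 if j == target else 0.0)
            logit_table[ctx][j] -= lr * grad

# Hypothetical (last CAT-K token, expert label) pairs from rollouts.
pairs = [(2, 3), (2, 3), (1, 2)]
logit_table = [[0.0] * 4 for _ in range(4)]   # 4 context tokens x 4 output tokens

loss_before = clsft_loss(logit_table, pairs)
for _ in range(20):
    sgd_step(logit_table, pairs)
loss_after = clsft_loss(logit_table, pairs)
print(f"loss before: {loss_before:.4f}, after: {loss_after:.4f}")
```

The only difference from the behavior cloning loss is where the context comes from: the labels are unchanged, so the same cross-entropy code applies.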

Key Experimental Results

WOSAC Leaderboard

| Method | Parameters | RMM↑ | Kinematic↑ | Interactive↑ | Map↑ |
|---|---|---|---|---|---|
| SMART-7M | 7M | 0.7302 | 0.821 | 0.683 | 0.712 |
| SMART-102M | 102M | 0.7614 | 0.849 | 0.715 | 0.738 |
| MotionLM | 45M | 0.7489 | 0.837 | 0.701 | 0.729 |
| SMART-7M + CLSFT | 7M | 0.7702 | 0.856 | 0.728 | 0.749 |

With only 7M parameters, SMART-7M + CLSFT outperforms SMART-102M on RMM (0.7702 vs 0.7614) and on the Kinematic, Interactive, and Map metrics, and it improves on the SMART-7M baseline RMM of 0.7302 without adding parameters.

Ablation Study on K

| K Value | RMM↑ | Description |
|---|---|---|
| K=1 (Greedy) | 0.7412 | Pure model rollout; largest deviation from the expert |
| K=5 | 0.7589 | Large deviation |
| K=10 | 0.7702 | Best balance point |
| K=50 | 0.7645 | Small deviation; rollouts stay close to the expert |
| K=∞ (vocabulary size) | 0.7301 | Expert token always selected; equivalent to open-loop training |

Among the tested values, K=10 gives the highest RMM (0.7702), balancing deviation and stability.

Closed-Loop Simulation Quality

| Metric | SMART (BC) | SMART + DAgger | SMART + CLSFT |
|---|---|---|---|
| Collision Rate↓ | 5.23% | 3.87% | 3.21% |
| Off-road Rate↓ | 2.14% | 1.62% | 1.38% |
| Progress↑ | 0.891 | 0.912 | 0.927 |
| Comfort↑ | 0.834 | 0.856 | 0.871 |

CLSFT outperforms DAgger across all closed-loop simulation metrics, without requiring online interaction.
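To put the two rate metrics in relative terms, the short calculation below uses only the numbers reported in the closed-loop simulation table. The helper name `relative_reduction` is illustrative.

```python
# Values copied from the closed-loop simulation table (percent, lower is better).
metrics = {
    "collision_rate": {"bc": 5.23, "dagger": 3.87, "clsft": 3.21},
    "off_road_rate": {"bc": 2.14, "dagger": 1.62, "clsft": 1.38},
}

def relative_reduction(baseline, value):
    """Fractional reduction of `value` with respect to `baseline`."""
    return (baseline - value) / baseline

for name, row in metrics.items():
    vs_bc = relative_reduction(row["bc"], row["clsft"])
    vs_dagger = relative_reduction(row["dagger"], row["clsft"])
    print(f"{name}: {vs_bc:.1%} lower than BC, {vs_dagger:.1%} lower than DAgger")
```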

Summary & Future Work

By selecting the token closest to the expert among the Top-K candidates, CAT-K constructs approximate closed-loop rollouts, offering an efficient and effective closed-loop supervised fine-tuning method for tokenized traffic models. SMART-7M + CLSFT, with only 7M parameters, outperforms the 102M-parameter SMART-102M on the WOSAC leaderboard (RMM 0.7702 vs 0.7614), suggesting that a smaller model can surpass a larger one when the training paradigm addresses distribution shift. The core idea of CLSFT, introducing closed-loop information through controlled deviation, can be generalized to other sequence prediction tasks.