Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models

Сегодня разберём статью о том, как ребята из NVIDIA заняли первое место в лидерборде WOSAC от Waymo. Речь пойдёт о цикле SFT, а не о способах токенизации, старых слоях архитектуры или внутреннем cross attention.

CAT-K — стратегия файнтюнинга, основанная на top-k-подходе. Её авторы поднимают проблему миссматча распределений во время обучения и на инференсе.

Для обучения в open-loop используются траектории водителей как условия (обуславливание на историю) в режиме behavior cloning. Но при симуляциях на инференсе агенты двигаются уже не по таким же хорошим траекториям в closed-loop, а по своим собственным: с ошибками, которые накапливаются при последовательной генерации движения. Так могут возникать состояния, неучтённые в обучении.

В качестве бейзлайна авторы используют авторегрессионный подход SMART с дельта-токенами:

1. Фиксируют сетку по времени с шагом 0,5–2 секунды прошлого и 8 секунд будущего.
2. На каждом шаге по времени предсказывают для каждого агента токен с собственным сдвигом в координатах.

Обычно авторегрессионные модели для Traffic Motion тренируют с помощью teacher-forcing как LLM модели: формулируют Traffic Motion как Next-Token-Prediction. Но для того, чтобы уменьшить миссматч авторы адаптируют Cross-Entropy Method (или модный SFT из LLM).

Как устроен CEM:

1. Генерирация набора траекторий (в closed-loop)
2. Отбор лучших кандидатов по метрике элиты.
3. Дообучение в режиме teacher-forcing на элитах.

Элиты — моды в распределении, индуцируемом обученной моделью. Они близки к GT-тракеториям. То есть, если дообучаться на хороших траекториях из симуляций в closed-loop, миссматч между обучением и инференсом уменьшится.

Остаётся только адаптировать дельта-токены для CEM:

1. Выбрать K самых вероятных токенов на текущем шаге генерации.
2. Из K самых вероятных токенов выбрать тот, что лучше всего аппроксимирует GT.
3. Использовать выбранный токен для пересчёта следующего состояния.

Контроль количества элит при генерации помогает избежать лишних симуляций и их фильтрации: дискретизация дельта-токенов — дискретизация первого порядка.

Внедрение CAT-K помогло небольшой политике моделирования токенизированного трафика с 7 миллионами параметров превзойти модель с 102 миллионами параметров из того же семейства моделей и занять первое место в таблице лидеров Waymo Sim Agent Challenge на момент подачи заявки.

Разбор подготовил ❣️ Тингир Бадмаев
404 driver not found