Demons in the Detail: On Implementing Load Balancing Loss for Training Specialized Mixture-of-Expert Models

Сегодня разберём статью от команды Qwen о том, как они придумали новый LBL-лосс для обучения MoE.

В MoE-моделях токены по экспертам распределяет роутер. LBL — вспомогательный лосс, который делает распределение равномерным, чтобы избежать перегрузки одних экспертов и голода других.

Обычно LBL считают на уровне отдельного микробатча каждого DP-ранка, а потом усредняют полученные LBL по всем микробатчам. Но заставлять роутер распределять токены равномерно в рамках одного микро-батча — довольно строгое ограничение. Пара длинных семплов может заполнить весь микро-батч, и тогда, если эти семплы пришли из одного домена, роутер всë равно будет вынужден разослать эти токены равномерно по всем экспертам. Так теряется логика специализации экспертов.

Для того чтобы избежать потери специализации, авторы предлагают считать LBL на уровне глобального батча (global-batch), где больше разнообразия данных. Как? Добавляют шаг коммуникации: синхронизируют нужные для подсчёта LBL статистики роутера по выбору экспертов со всей DP-группы, то есть со всех микробатчей. Рассмотрим пример:

1. Вообразим 2 карты и обучение с DP.
2. А к ним — 4 эксперта и 16 токенов (после пермьюта).
На первой карте токены распределятся по экспертам так: [0, 0, 8, 8]. На второй — [8, 8, 0, 0].
3. Для micro-lbl этот лосс будет на каждой карте ругать роутер за неравномерное распределение токенов.
5. Но если мы соберём глобальную статистику (то есть, сложим вектора распределений со всех карт), то получим [8, 8, 8, 8]. Это идеальная равномерность и macro-lbl на такое не обижается.
6. macro-lbl даёт роутеру больше свободы, что конвертируется в прирост качества.

Авторы отмечают значительный рост производительности при обучении новым методом: модели с глобальной балансировкой показывают лучшие результаты как по лоссам, так и на различных бенчах. А ещё у экспертов появляется настоящая специализация: чёткая и интерпретируемая на доменах (код, математика, разные языки).

Предложенный метод при эффективной реализации совсем не замедляет обучение. Можно собрать статистики каждого слоя и сделать лишь одну незначительную коммуникацию в конце.

Разбор подготовил Даниил Сухой

Душный NLP