Tohle je to, co jsme koksovali posledních 9 měsíců: udělat výcvik MoE ~2x rychlejší a ~2x méně paměti! Hlavní body: - MoE obvykle zabírá nejvíce času a paměti v moderních modelech. Ukázalo se, že lze matematicky přepsat zpětný průchod MoE tak, aby se snížila aktivační paměť, kterou je třeba uložit v předním pohonu, o ~2x, což vede ke stejným gradientům bez dalšího přepočítání matmulu. Tento výsledek se mi opravdu líbí, protože kombinuje algoritmické i systémové poznatky. - Analýza úzkých míst ve vrstvě MoE vede k přirozené optimalizační strategii: co nejvíce snížit počet čtení/zápisů paměti! Shromáždění vstupů pro přední pohon a výstupní gradování pro BWD může někdy trvat stejně dlouho jako seskupené GEMM. Fúzujeme gather se skupinovým GEMM + překrýváme přístup k mem a počítáme, aby celá vrstva byla ~2x rychlejší. - Výpočet top-k pro expertní směrování může trvat překvapivě dlouho, ~15–20 % celé vrstvy MoE! Standardní top-k impl používá radix top-k algoritmus, což je skvělé pro velké k, ale suboptimální pro malé k. Top-k jsme přepsali pomocí bitonic top-k algoritmu a někdy je to 20-30krát rychlejší než pytorchovy top-k! Všechna hlavní jádra jsou napsána v Cute-DSL, takže by měla být snadno rozšiřitelná (a instalovatelná :D). Hopper jádra jsou venku, Blackwell jádra jsou téměř připravená. Modely MoE bývaly dvakrát méně hardwarově efektivní na trénování, doufejme, že to Sonic-MOE změní.