Asta am încercat în ultimele 9 luni: să facem antrenamentul MoEs să meargă ~de 2 ori mai repede și ~de 2 ori mai puțină memorie! Momente importante: - MoE consumă de obicei cel mai mult timp și memorie la modelele moderne. Se pare că se poate rescrie matematic MoE backward pass pentru a reduce mem-ul de activare pe care trebuie să-l stochezi în fwd cu ~2x, rezultând aceleași gradiente fără recalcul suplimentar matmul. Îmi place foarte mult acest rezultat, deoarece combină atât perspective algoritmice, cât și cele de sistem. - Analiza blocajelor în stratul MoE duce la o strategie naturală de optimizare: reduci cât mai mult citirile/scrierile de meme-uri! Adunarea intrărilor pentru tracțiunea din față și a gradului de ieșire pentru bwd poate dura uneori la fel de mult timp ca GEMM-urile grupate. Fuziunăm colectarea cu GEMM grupat + acces la memuri suprapuși și calculăm pentru a face ca întregul strat să meargă ~de 2 ori mai repede. - Calcularea top-k pentru rutarea expertă poate dura surprinzător de mult, ~15-20% din întregul strat MoE! Implicitul standard top-k folosește algoritmul radix top-k, excelent pentru k mari, dar suboptim pentru k mic. Am rescris top-k folosind algoritmul bitonic top-k și uneori este de 20-30 de ori mai rapid decât top-k de la pytorch! Toate nucleele principale sunt scrise în Cute-DSL, deci ar trebui să fie ușor de extins (și instalat :D). Boabele de hopper sunt scoase, boabele Blackwell sunt aproape gata. Modelele MoE erau de două ori mai puțin eficiente hardware de antrenat, sper ca Sonic-MOE să schimbe asta.