Det här är vad vi har kokat fram de senaste 9 månaderna: få MoEs-träningen att gå ~2 gånger snabbare och ~2 gånger mindre minne! Höjdpunkter: - MoE tar vanligtvis mest tid och minne i moderna modeller. Det visar sig att man matematiskt kan skriva om MoE-bakåtpasset för att minska aktiveringsminnet du behöver lagra i framströmsenheten med ~2x, vilket resulterar i samma gradienter utan extra matmul-omberäkning. Jag gillar verkligen detta resultat, eftersom det kombinerar både algoritmiska och systeminsikter. - Analys av flaskhalsar i MoE-lagret leder till en naturlig optimeringsstragegi: minska memläsningar/skrivningar så mycket som möjligt! Att samla in input för framhjulsdrift och utdata för bwd kan ibland ta lika lång tid som de grupperade GEMM:erna. Vi fusionerar gather med grupperad GEMM + överlappning, minnesåtkomst och beräkning för att få hela lagret att gå ~2 gånger snabbare. - Att beräkna top-k för expert-routing kan ta förvånansvärt lång tid, ~15–20 % av hela MoE-lagret! Standard top-k impl använder radix top-k algoritm, utmärkt för stora k men suboptimala för små k. Vi skrev om top-k med bitonisk top-k-algoritm, och ibland är det 20-30 gånger snabbare än pytorchs top-k! Alla huvudkärnor är skrivna i Cute-DSL så de borde vara lätta att utöka (och installera :D). Hopperkärnorna är ute, Blackwell-kärnorna är nästan klara. MoE-modeller brukade vara dubbelt så hårdvarueffektiva att träna, förhoppningsvis kommer Sonic-MOE att ändra på det.