Isto é o que temos estado a cozinhar nos últimos 9 meses: fazer o treinamento de MoEs ser ~2x mais rápido e ~2x menos memória! Destaques: - MoE normalmente leva mais tempo e memória em modelos modernos. Acontece que se pode reescrever matematicamente o retrocesso do MoE para reduzir a memória de ativação que precisa ser armazenada no avanço em ~2x, resultando nos mesmos gradientes sem recomputação extra de matmul. Gosto muito deste resultado, pois combina tanto insights algorítmicos quanto de sistemas. - Analisar os gargalos na camada MoE leva a uma estratégia de otimização natural: reduzir leituras/escritas de memória tanto quanto possível! Reunir a entrada para o avanço e o gradiente de saída para o retrocesso pode às vezes levar tanto tempo quanto os GEMMs agrupados. Fundimos a coleta com GEMM agrupado + sobreposição de acesso à memória e computação para fazer toda a camada ser ~2x mais rápida. - Calcular o top-k para roteamento de especialistas pode levar surpreendentemente muito tempo, ~15-20% de toda a camada MoE! A implementação padrão de top-k usa o algoritmo radix top-k, ótimo para k grande, mas subótimo para k pequeno. Reescrevemos o top-k usando o algoritmo bitônico top-k, e às vezes é 20-30x mais rápido que o top-k do pytorch! Todos os principais núcleos estão escritos em Cute-DSL, então devem ser fáceis de estender (e instalar :D). Os núcleos Hopper estão prontos, os núcleos Blackwell estão quase prontos. Os modelos MoE costumavam ser 2x menos eficientes em hardware para treinar, espero que o Sonic-MOE mude isso.