É isso que temos cocaído nos últimos 9 meses: fazer o treinamento do MoEs ir ~2x mais rápido e ~2x menos memória! Destaques: - MoE normalmente consome mais tempo e memória em modelos modernos. Acontece que é possível reescrever matematicamente o passe retroativo do MoE para reduzir o mem de ativação que você precisa armazenar no fwd em ~2x, resultando nos mesmos gradientes sem recomputação extra de matmul. Gosto muito desse resultado, pois combina insights algorítmicos e de sistemas. - Analisar gargalos na camada de MoE leva a uma estratégia natural de otimização: reduzir leituras/escritas de mems o máximo possível! Reunir a entrada para tração dianteira e o grau de saída para bwd às vezes pode levar tanto tempo quanto os GEMMs agrupados. Nós fundimos o gather com GEMM agrupado + sobreposição acesso ao mem, e calculamos para fazer toda a camada ir ~2x mais rápido. - Calcular o top-k para roteamento especialista pode levar surpreendentemente tempo, ~15-20% de toda a camada MoE! O padrão top-k impl usa o algoritmo radix top-k, ótimo para k grande, mas subótimo para k pequeno. Reescrevemos o top-k usando o algoritmo bitonic top-k, e às vezes é 20-30 vezes mais rápido que o top-k do pytorch! Todos os kernels principais são escritos em Cute-DSL, então devem ser fáceis de estender (e instalar :D). Os grãos do hopper já saíram, os grãos de Blackwell estão quase prontos. Os modelos MoE costumavam ser duas vezes menos eficientes em hardware para treinar, espero que o Sonic-MOE mude isso.