认识Flash-KMeans:一个IO感知的精确K-Means,在GPU上比FAISS快200倍以上
Flash-KMeans是一个开源的、IO感知的精确K-Means实现,使用Triton GPU内核优化数据流,在不改变数学或近似的情况下实现显著加速。在NVIDIA H200上,它比最佳基线快17.9倍,比cuML快33倍,比FAISS快200倍以上。
Flash-KMeans 是由加州大学伯克利分校和德克萨斯大学奥斯汀分校的研究团队发布的开源库,旨在解决现代AI流水线中K-Means频繁调用带来的延迟瓶颈。传统的K-Means通常作为离线工具使用,但如今在训练和推理循环中频繁调用,因此每次调用的延迟比理论FLOPs更重要。
Flash-KMeans 是标准Lloyd K-Means的IO感知实现,它不改变数学或近似,仅重构算法在GPU上的数据流。在NVIDIA H200上,研究团队报告了高达17.9倍的端到端加速,对比NVIDIA cuML快33倍,对比FAISS快200倍以上。
Flash-KMeans 的核心创新包括两个关键内核:FlashAssign 和 Sort-Inverse Update。FlashAssign 受FlashAttention启发,将点和质心分块从HBM流式传输到片上SRAM,融合距离计算与在线argmin,避免物化完整的N×K距离矩阵。这使分配阶段的IO复杂度从O(NK)降至O(Nd+Kd),分配内核速度提升高达21.2倍。
Sort-Inverse Update 针对质心更新阶段的原子竞争问题,通过argsort将分配向量按簇ID排序,形成连续段。每个线程块在片上规约一段,然后每段仅执行一次原子加操作。这使更新内核速度提升高达6.3倍。
基准测试在H200上进行,FP16数据,维度d=128,对比fast_pytorch_kmeans、fastkmeans、cuML和FAISS。端到端速度提升最高17.9倍(N=8M,K=1024),分配内核21.2倍,更新内核6.3倍。外核模式下,对10亿个点(K=32768,d=128)每次迭代仅需41.4秒,而基线需261.8秒。
Flash-KMeans 还包含缓存感知编译启发式,将调优开销降低175倍,同时保持接近最佳性能。该库已开源,采用Apache 2.0许可,可通过 pip install flash-kmeans 安装。
使用场景包括:向量搜索索引构建、稀疏注意力路由、KV缓存压缩、低比特KV量化、扩散Transformer等。API与faiss和sklearn类似,支持批处理和多GPU。
总之,Flash-KMeans 通过GPU数据流优化实现了精确K-Means的显著加速,为需要频繁聚类的新兴应用提供了实用工具。