AI News HubLIVE
站内改写2 分钟阅读

使用LLM嵌入和HDBSCAN对非结构化文本进行聚类

本文介绍了如何结合大语言模型嵌入和HDBSCAN密度聚类算法,构建文本聚类管道,自动发现未标注文本数据中的主题。包括使用预训练模型生成嵌入、UMAP降维、HDBSCAN聚类及可视化。

来源Machine Learning Mastery作者: Iván Palomares Carrascosa

在当前的生成式AI热潮中,人们往往关注聊天界面和提示词,但大语言模型(LLM)的应用远不止于此。其最强大的能力之一是将原始、杂乱的非结构化文本转化为语义丰富的数学表示——嵌入。随后,这些文本表示可用于各种机器学习任务,聚类便是其中之一。

特别是,嵌入可以与HDBSCAN等高级密度聚类技术结合,从而发现文本集合中的隐藏主题、模式或类别,无需预先标注。本文将从零开始构建一个基于文本的聚类管道,使用免费数据集和开源嵌入模型,并借助Python库实现。

逐步实现

首先,安装必要的Python库:sentence-transformers用于加载预训练LLM生成嵌入,umap-learn用于降维。此外,还需要scikit-learnpandas

!pip install sentence-transformers umap-learn

接着,获取数据。使用fetch_20newsgroups函数加载新闻组数据集,并筛选出三个类别(sci.space、sci.med、rec.autos)的150个样本。注意,尽管数据集包含标签,但我们将其忽略,以模拟真实聚类场景。

import pandas as pd
from sklearn.datasets import fetch_20newsgroups

categories = ['sci.space', 'sci.med', 'rec.autos']
newsgroups = fetch_20newsgroups(subset='train', categories=categories, remove=('headers', 'footers', 'quotes'))
df = pd.DataFrame({'text': newsgroups.data, 'true_label': newsgroups.target})
df = df[df['text'].str.strip().str.len() > 100].sample(150, random_state=42).reset_index(drop=True)
print(f"Loaded {len(df)} text documents.")

然后,使用all-MiniLM-L6-v2模型生成嵌入。这是一个轻量级但高效的模型。

from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = model.encode(df['text'].tolist(), show_progress_bar=True)
print(f"Embedding matrix shape: {embeddings.shape}")

嵌入维度较高,需要降维。采用UMAP将维度降至5,同时保留足够的信息。

import umap
reducer = umap.UMAP(n_neighbors=15, n_components=5, min_dist=0.0, random_state=42)
reduced_embeddings = reducer.fit_transform(embeddings)
print(f"Reduced matrix shape: {reduced_embeddings.shape}")

降维后,应用HDBSCAN聚类。设置最小簇大小为8,最小样本数为3。

from sklearn.cluster import HDBSCAN
clusterer = HDBSCAN(min_cluster_size=8, min_samples=3, store_centers='centroid')
df['cluster'] = clusterer.fit_predict(reduced_embeddings)
cluster_counts = df['cluster'].value_counts()
print("Cluster Distribution:")
print(cluster_counts)

结果发现两个簇,分别对应不同主题。通过检查样本文本,可知簇0主要涉及空间和医学话题,簇1主要涉及汽车话题。

为了进一步可视化,可以绘制各维度组合的散点图。

import matplotlib.pyplot as plt
import seaborn as sns
import itertools

reduced_df = pd.DataFrame(reduced_embeddings, columns=[f'UMAP_D{i+1}' for i in range(reduced_embeddings.shape[1])])
reduced_df['cluster'] = df['cluster']
dim_pairs = list(itertools.combinations(reduced_df.columns[:-1], 2))
num_plots = len(dim_pairs)
num_cols = 3
num_rows = (num_plots + num_cols - 1) // num_cols
plt.figure(figsize=(num_cols * 5, num_rows * 4))
for i, (dim1, dim2) in enumerate(dim_pairs):
    plt.subplot(num_rows, num_cols, i + 1)
    sns.scatterplot(x=dim1, y=dim2, hue='cluster', data=reduced_df, palette='viridis', s=70, alpha=0.7)
    plt.title(f'{dim1} vs {dim2}')
    plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()

总结

通过将LLM嵌入与HDBSCAN结合,我们可以有效捕捉文本的语义信息,自动确定聚类数量,并识别噪声点。这种方法适用于多种无监督文本分析任务。