使用 FastAPI 训练、部署和使用 Scikit-learn 模型
本文详细介绍了如何使用 FastAPI 构建一个 Scikit-learn 机器学习模型的推理 API。从项目设置、模型训练、本地测试到云端部署,完整地演示了将乳腺癌分类器转化为可调用 API 的过程。
本文旨在指导读者完成从训练一个 Scikit-learn 模型到将其部署为 FastAPI 服务的完整流程。我们将以乳腺癌数据集为例,使用随机森林分类器,并最终将 API 部署到 FastAPI Cloud。
1. 项目设置
首先创建一个新文件夹并组织项目结构。运行以下命令:
mkdir sklearn-fastapi-app
cd sklearn-fastapi-app
mkdir app artifacts
touch app/__init__.py项目结构如下:
sklearn-fastapi-app/
├── app/
│ ├── __init__.py
│ └── main.py
├── artifacts/
├── train.py
├── pyproject.toml
└── requirements.txt在 requirements.txt 中添加依赖:
fastapi[standard]
scikit-learn
joblib
numpy然后执行 pip install -r requirements.txt 安装。
2. 训练机器学习模型
创建 train.py,加载乳腺癌数据集,分割训练/测试集,训练 RandomForestClassifier(200棵树),评估准确率,并将模型和元数据保存为 joblib 文件。
# train.py 内容(略去重复代码,简要概述)
from pathlib import Path
import joblib
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
def main():
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
model = RandomForestClassifier(n_estimators=200, random_state=42)
model.fit(X_train, y_train)
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
artifact = {
"model": model,
"target_names": data.target_names.tolist(),
"feature_names": data.feature_names,
}
output_path = Path("artifacts/breast_cancer_model.joblib")
output_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(artifact, output_path)
print(f"Model saved to: {output_path}")
print(f"Test accuracy: {accuracy:.4f}")
if __name__ == "__main__":
main()运行 python train.py 后,模型文件保存到 artifacts/ 目录,测试准确率约为 0.9561。
3. 构建 FastAPI 服务器
在 app/main.py 中创建 FastAPI 应用:
- 在启动时加载模型。
- 提供
/health健康检查端点。 - 提供
/predictPOST 端点,接收 30 个特征值,返回预测的类别 ID、标签以及每个类别的概率。
使用 Pydantic 的 BaseModel 定义请求体的模式。/predict 内部将输入特征转换为 NumPy 数组,调用模型的 predict 和 predict_proba 方法,并以 JSON 格式返回结果。
4. 本地测试
启动开发服务器:
fastapi dev app/main.py访问 http://127.0.0.1:8000/docs 可看到交互式 API 文档,直接测试 /predict 端点。也可以使用 curl 发送 POST 请求。
成功响应示例:
{
"prediction_id": 0,
"prediction_label": "malignant",
"probabilities": {
"malignant": 0.99,
"benign": 0.01
}
}这表明 API 运行正常。
5. 部署到云端
使用 FastAPI CLI 部署:
fastapi login # 登录
fastapi deploy # 部署首次部署时,CLI 会引导选择团队和创建应用。部署成功后,会提供类似 https://sklearn-fastapi-app.fastapicloud.dev 的 URL。
然后就可以通过云端 URL 访问 API,并检查日志以监控运行状态。
通过本文的步骤,读者可以轻松将自己的 Scikit-learn 模型转化为生产可用的 API。