模型部署与优化
模型部署
1. 模型打包
将训练好的模型打包成可部署的格式。
# PyTorch模型打包
import torch
# 保存模型
def save_model(model, path):
"""保存模型到文件"""
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'loss': loss,
}, path)
# 加载模型
def load_model(path):
"""从文件加载模型"""
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
return model, optimizer, epoch, loss
# TensorFlow模型打包
import tensorflow as tf
# 保存模型
def save_model(model, path):
"""保存模型到文件"""
model.save(path)
# 加载模型
def load_model(path):
"""从文件加载模型"""
model = tf.keras.models.load_model(path)
return model
# ONNX模型转换
import torch.onnx
# 转换为ONNX格式
def convert_to_onnx(model, input_shape, path):
"""将模型转换为ONNX格式"""
dummy_input = torch.randn(input_shape)
torch.onnx.export(model, dummy_input, path,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})2. 容器化部署
使用Docker容器化部署模型服务。
# Dockerfile FROM python:3.8-slim # 安装依赖 COPY requirements.txt . RUN pip install -r requirements.txt # 复制应用代码 COPY . . # 设置环境变量 ENV MODEL_PATH=/app/models ENV PORT=8000 # 启动服务 CMD ["python", "app.py"] # 构建镜像 docker build -t ai-model:latest . # 运行容器 docker run -d -p 8000:8000 ai-model:latest # 使用GPU docker run -d --gpus all -p 8000:8000 ai-model:latest # 使用数据卷 docker run -d -v /host/path:/container/path -p 8000:8000 ai-model:latest
3. API服务
提供RESTful API服务接口。
# FastAPI服务
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
app = FastAPI()
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 加载模型
model = load_model("model.pth")
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
"""预测接口"""
# 读取文件
contents = await file.read()
# 预处理数据
data = preprocess(contents)
# 模型预测
prediction = model.predict(data)
return {"prediction": prediction.tolist()}
@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "healthy"}
# 启动服务
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)