模型部署与优化
模型部署
1. 模型打包
将训练好的模型打包成可部署的格式。
# PyTorch模型打包 import torch # 保存模型 def save_model(model, path): """保存模型到文件""" torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch, 'loss': loss, }, path) # 加载模型 def load_model(path): """从文件加载模型""" checkpoint = torch.load(path) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] return model, optimizer, epoch, loss # TensorFlow模型打包 import tensorflow as tf # 保存模型 def save_model(model, path): """保存模型到文件""" model.save(path) # 加载模型 def load_model(path): """从文件加载模型""" model = tf.keras.models.load_model(path) return model # ONNX模型转换 import torch.onnx # 转换为ONNX格式 def convert_to_onnx(model, input_shape, path): """将模型转换为ONNX格式""" dummy_input = torch.randn(input_shape) torch.onnx.export(model, dummy_input, path, input_names=['input'], output_names=['output'], dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
2. 容器化部署
使用Docker容器化部署模型服务。
# Dockerfile FROM python:3.8-slim # 安装依赖 COPY requirements.txt . RUN pip install -r requirements.txt # 复制应用代码 COPY . . # 设置环境变量 ENV MODEL_PATH=/app/models ENV PORT=8000 # 启动服务 CMD ["python", "app.py"] # 构建镜像 docker build -t ai-model:latest . # 运行容器 docker run -d -p 8000:8000 ai-model:latest # 使用GPU docker run -d --gpus all -p 8000:8000 ai-model:latest # 使用数据卷 docker run -d -v /host/path:/container/path -p 8000:8000 ai-model:latest
3. API服务
提供RESTful API服务接口。
# FastAPI服务 from fastapi import FastAPI, File, UploadFile from fastapi.middleware.cors import CORSMiddleware import uvicorn app = FastAPI() # 配置CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 加载模型 model = load_model("model.pth") @app.post("/predict") async def predict(file: UploadFile = File(...)): """预测接口""" # 读取文件 contents = await file.read() # 预处理数据 data = preprocess(contents) # 模型预测 prediction = model.predict(data) return {"prediction": prediction.tolist()} @app.get("/health") async def health_check(): """健康检查""" return {"status": "healthy"} # 启动服务 if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000)