导航菜单

模型部署与优化

模型部署

1. 模型打包

将训练好的模型打包成可部署的格式。

# PyTorch模型打包
import torch

# 保存模型
def save_model(model, path):
    """保存模型到文件"""
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    }, path)

# 加载模型
def load_model(path):
    """从文件加载模型"""
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, epoch, loss

# TensorFlow模型打包
import tensorflow as tf

# 保存模型
def save_model(model, path):
    """保存模型到文件"""
    model.save(path)

# 加载模型
def load_model(path):
    """从文件加载模型"""
    model = tf.keras.models.load_model(path)
    return model

# ONNX模型转换
import torch.onnx

# 转换为ONNX格式
def convert_to_onnx(model, input_shape, path):
    """将模型转换为ONNX格式"""
    dummy_input = torch.randn(input_shape)
    torch.onnx.export(model, dummy_input, path,
                     input_names=['input'],
                     output_names=['output'],
                     dynamic_axes={'input': {0: 'batch_size'},
                                 'output': {0: 'batch_size'}})

2. 容器化部署

使用Docker容器化部署模型服务。

# Dockerfile
FROM python:3.8-slim

# 安装依赖
COPY requirements.txt .
RUN pip install -r requirements.txt

# 复制应用代码
COPY . .

# 设置环境变量
ENV MODEL_PATH=/app/models
ENV PORT=8000

# 启动服务
CMD ["python", "app.py"]

# 构建镜像
docker build -t ai-model:latest .

# 运行容器
docker run -d -p 8000:8000 ai-model:latest

# 使用GPU
docker run -d --gpus all -p 8000:8000 ai-model:latest

# 使用数据卷
docker run -d -v /host/path:/container/path -p 8000:8000 ai-model:latest

3. API服务

提供RESTful API服务接口。

# FastAPI服务
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

app = FastAPI()

# 配置CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 加载模型
model = load_model("model.pth")

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """预测接口"""
    # 读取文件
    contents = await file.read()
    
    # 预处理数据
    data = preprocess(contents)
    
    # 模型预测
    prediction = model.predict(data)
    
    return {"prediction": prediction.tolist()}

@app.get("/health")
async def health_check():
    """健康检查"""
    return {"status": "healthy"}

# 启动服务
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)