使用 🤗 Transformers 加载支持流式输出的模型(如 ChatGLM、LLaMA、Mistral 等),并用 FastAPI + SSE (Server-Sent Events) 实现简化并行的流式推理系统

最后发布时间 : 2025-07-13 14:14:03 浏览量 :

下面是一个完整的示例,展示如何使用 🤗 Transformers 加载支持流式输出的模型(如 ChatGLM、LLaMA、Mistral 等),并用 FastAPI + SSE (Server-Sent Events) 实现简化并行的流式推理系统

✅ 实现目标

  • 使用 transformers 中的 generate + streamer
  • 并行推理支持:使用 asyncio.to_thread 执行模型(避免阻塞)
  • SSE 实现客户端流式接收响应
  • 并发安全:每个请求独立上下文,避免全局状态冲突

📦 安装依赖

pip install fastapi uvicorn transformers accelerate sse-starlette

🧠 示例逻辑结构

  • 使用 TextIteratorStreamer 来实现 流式生成
  • asyncio.to_thread() 包装阻塞的 PyTorch 推理代码
  • sse-starlette 实现流式响应

🧪 完整示例代码

# filename: app.py

from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import asyncio
import torch

app = FastAPI()

# 模型与 tokenizer 全局只加载一次(线程安全的推理)
model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # 示例模型,需适配你本地显卡
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
model.eval()

class ChatRequest(BaseModel):
    prompt: str
    stream: bool = True
    max_new_tokens: int = 512
    temperature: float = 0.7

@app.post("/chat")
async def chat(request: ChatRequest):
    if request.stream:
        return EventSourceResponse(predict_stream(request))
    else:
        response = await predict_full(request)
        return {"response": response}


# ✅ 流式推理(使用 streamer + SSE)
async def predict_stream(request: ChatRequest):
    prompt = request.prompt

    # 为每个请求创建新的 streamer 和 input
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        do_sample=True,
        streamer=streamer
    )

    # 用子线程跑 generate,避免阻塞事件循环
    loop = asyncio.get_event_loop()
    generate_task = loop.run_in_executor(None, model.generate, **generation_kwargs)

    # 异步读取 streamer 输出,逐步 yield
    async for token in streamer_async_wrapper(streamer):
        yield {"data": token}

    yield {"event": "end", "data": "[DONE]"}


# ✅ 封装 TextIteratorStreamer 为异步生成器
async def streamer_async_wrapper(streamer):
    loop = asyncio.get_event_loop()
    while True:
        try:
            value = await loop.run_in_executor(None, next, streamer)
            yield value
        except StopIteration:
            break


# ✅ 一次性返回全部响应(非流式)
async def predict_full(request: ChatRequest):
    prompt = request.prompt
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    output_ids = await asyncio.to_thread(model.generate,
        input_ids,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        do_sample=True
    )

    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output[len(prompt):] if output.startswith(prompt) else output

🌐 测试方法(curl 或前端)

curl -N -X POST http://localhost:8000/chat \
-H "Content-Type: application/json" \
-d '{"prompt": "Hello, how are you?", "stream": true}'

✅ 并发安全说明

  • model 是只读的、推理线程安全 ✅
  • 每次请求单独创建 streamer + input_ids
  • 使用 asyncio.to_thread() 隔离阻塞调用 ✅
  • SSE 保证多个用户连接互不影响 ✅

🔧 兼容其他模型(如 ChatGLM)

你可以把 model.generate() 替换成 ChatGLM 的 model.stream_chat(),但你必须:

  • 确保每个请求构造一个独立的上下文(对话历史);
  • 避免多个请求同时调用共享状态。

✅ 小结

功能点状态
PyTorch 推理并发✅ 安全,使用 to_thread 包装
流式输出✅ 使用 TextIteratorStreamer
SSE 推送响应✅ 使用 sse-starlette
多用户并发✅ 每请求独立上下文
多 GPU 模型支持✅ 使用 device_map="auto"