Below is a complete example of loading a model that supports streaming output (e.g. ChatGLM, LLaMA, Mistral) with 🤗 Transformers, and serving it through FastAPI + SSE (Server-Sent Events) as a simplified, concurrency-friendly streaming inference system.
Key ideas:

- use the `transformers` `generate` API together with a `streamer`, so text can be read chunk by chunk as it is produced;
- run the blocking generation in a worker thread via `asyncio.to_thread` (or an executor), so the event loop is never blocked.
Install the dependencies:

```bash
pip install fastapi uvicorn transformers accelerate sse-starlette
```
The three building blocks:

- `TextIteratorStreamer`: exposes the output of `generate()` as a blocking Python iterator of text chunks;
- `asyncio.to_thread()`: runs the blocking calls in a worker thread so FastAPI's event loop stays responsive;
- `sse-starlette`: provides `EventSourceResponse` for streaming Server-Sent Events to the client.
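Before the full app, here is the core `generate` + `TextIteratorStreamer` pattern in isolation. This is a minimal sketch, assuming the same example Mistral checkpoint as the app below; any causal LM works the same way.

```python
# Minimal sketch of the core streaming pattern (no FastAPI yet).
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # example checkpoint, same as the app below
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)

# generate() blocks until the whole sequence is done, so it runs in a thread;
# the streamer is consumed on the main thread and yields text chunks as they arrive.
Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64)).start()
for chunk in streamer:
    print(chunk, end="", flush=True)
```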
The full FastAPI application:

```python
# filename: app.py
import asyncio

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

app = FastAPI()

# Load the model and tokenizer once, globally (inference is thread-safe).
model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # example model; pick one that fits your GPU
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.float16
)
model.eval()


class ChatRequest(BaseModel):
    prompt: str
    stream: bool = True
    max_new_tokens: int = 512
    temperature: float = 0.7


@app.post("/chat")
async def chat(request: ChatRequest):
    if request.stream:
        return EventSourceResponse(predict_stream(request))
    response = await predict_full(request)
    return {"response": response}


# ✅ Streaming inference (streamer + SSE)
async def predict_stream(request: ChatRequest):
    prompt = request.prompt

    # A fresh streamer and fresh inputs for every request.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        do_sample=True,
        streamer=streamer,
    )

    # Run the blocking generate() in a worker thread so the event loop stays free.
    # asyncio.to_thread accepts keyword arguments; loop.run_in_executor does not.
    generate_task = asyncio.create_task(asyncio.to_thread(model.generate, **generation_kwargs))

    # Read the streamer asynchronously and yield chunks as SSE events.
    async for token in streamer_async_wrapper(streamer):
        yield {"data": token}

    await generate_task  # surface any exception raised inside generate()
    yield {"event": "end", "data": "[DONE]"}


# ✅ Wrap the blocking TextIteratorStreamer as an async generator
_SENTINEL = object()


async def streamer_async_wrapper(streamer):
    loop = asyncio.get_running_loop()
    while True:
        # Use next(streamer, _SENTINEL) instead of catching StopIteration:
        # StopIteration cannot be raised into an asyncio Future.
        value = await loop.run_in_executor(None, next, streamer, _SENTINEL)
        if value is _SENTINEL:
            break
        yield value


# ✅ Return the whole response at once (non-streaming)
async def predict_full(request: ChatRequest):
    prompt = request.prompt
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    output_ids = await asyncio.to_thread(
        model.generate,
        input_ids,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        do_sample=True,
    )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output[len(prompt):] if output.startswith(prompt) else output
```
Test the streaming endpoint with `curl` (`-N` disables output buffering so chunks appear as they arrive):

```bash
curl -N -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Hello, how are you?", "stream": true}'
```
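For a programmatic client, a minimal sketch using `httpx` might look like this. Note that `httpx` is an extra dependency not in the install command above, and the frame parsing below assumes the simple `data: ...` lines that sse-starlette emits for this app.

```python
# Hypothetical SSE client sketch (pip install httpx).
import httpx

payload = {"prompt": "Hello, how are you?", "stream": True}
with httpx.stream("POST", "http://localhost:8000/chat", json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        # Each chunk arrives as a "data: <text>" line followed by a blank line.
        if line.startswith("data: "):
            chunk = line[len("data: "):]
            if chunk == "[DONE]":
                break
            print(chunk, end="", flush=True)
```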
Notes on concurrency:

- the `model` and tokenizer are global and loaded only once, shared by all requests;
- `input_ids` and the streamer are created per request, so concurrent streams do not interfere with each other.
You can replace `model.generate()` with ChatGLM's `model.stream_chat()`, but you must:

- keep the blocking call off the event loop, i.e. still drive it through `to_thread` or an executor thread, just like `model.generate()` above;
- adapt the streaming loop, because `model.stream_chat()` is itself a generator and does not go through `TextIteratorStreamer` (see the sketch below).
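A possible adaptation is sketched below. It assumes the ChatGLM-6B style API, where `model.stream_chat(tokenizer, query, history=...)` is a blocking generator yielding cumulative `(response, history)` pairs; the function name `predict_stream_chatglm` is illustrative, exact parameter names may differ per model, and `_SENTINEL` is reused from `app.py`.

```python
# Hypothetical adaptation for ChatGLM-style models (not verbatim ChatGLM code).
async def predict_stream_chatglm(request: ChatRequest):
    generator = model.stream_chat(tokenizer, request.prompt, history=[])
    sent = ""  # text already emitted, used to turn cumulative output into deltas
    loop = asyncio.get_running_loop()
    while True:
        # Advance the blocking generator in a worker thread, same trick as streamer_async_wrapper.
        item = await loop.run_in_executor(None, next, generator, _SENTINEL)
        if item is _SENTINEL:
            break
        response, _history = item
        delta = response[len(sent):]  # stream_chat returns the full text so far
        sent = response
        if delta:
            yield {"data": delta}
    yield {"event": "end", "data": "[DONE]"}
```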
One more note: `device_map="auto"` relies on the `accelerate` package (already included in the install command above) to place the model weights across the available GPUs/CPU automatically.