Complete Chinese Language Model Pretraining from Scratch (Minimal Architecture) + FastAPI SSE Multithreaded Streaming Service


This guide provides a complete pipeline for pretraining a Chinese language model from scratch (with a minimal architecture) plus a FastAPI SSE multithreaded streaming service. It covers:

  • training a minimal Chinese GPT2 model;
  • serving it through an SSE streaming API;
  • everything runnable and easy to modify, with very low resource requirements.

The code below assumes the following packages are installed: transformers, tokenizers, torch, fastapi, sse-starlette, and uvicorn.

🧱 Part 1: Minimal Chinese GPT2 Model Architecture (Trained from Scratch)

✅ 1.1 Model configuration (config.json)

{
  "architectures": ["GPT2LMHeadModel"],
  "vocab_size": 5000,
  "n_positions": 512,
  "n_ctx": 512,
  "n_embd": 256,
  "n_layer": 4,
  "n_head": 4,
  "bos_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 1
}
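
As a quick sanity check (not part of the original recipe), you can build an untrained model from this config and count its parameters; at this size it comes to only a few million, which is why the whole pipeline runs on modest hardware:

from transformers import GPT2Config, GPT2LMHeadModel

# Build an untrained model from the minimal config above
config = GPT2Config.from_pretrained("config.json")
model = GPT2LMHeadModel(config)

# Roughly: embeddings (5000 x 256 + 512 x 256) plus 4 small transformer blocks
print(f"parameters: {model.num_parameters():,}")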

✅ 1.2 Tokenizer training (using tokenizers)

Prepare a corpus file corpus.txt (one sentence per line), then train:

from tokenizers import ByteLevelBPETokenizer

# Train a byte-level BPE tokenizer; the special tokens receive the first
# vocab IDs in order: <pad>=0, <s>=1, </s>=2, <unk>=3
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(["corpus.txt"], vocab_size=5000, min_frequency=2, special_tokens=[
    "<pad>", "<s>", "</s>", "<unk>"
])
tokenizer.save_model("tokenizer")
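
If you don't have real data at hand, a throwaway corpus is enough to smoke-test the pipeline. A minimal sketch (the sentences below are placeholders, not from the original post):

# Write a tiny placeholder corpus.txt, one sentence per line
sentences = [
    "今天天气很好。",
    "我喜欢读书和写代码。",
    "机器学习正在改变世界。",
] * 200  # repeat so BPE merges reach min_frequency
with open("corpus.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(sentences))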

Training the tokenizer generates:

  • vocab.json
  • merges.txt

Use them to build a HuggingFace tokenizer:

from transformers import GPT2TokenizerFast

# GPT2TokenizerFast defaults bos/eos/unk to "<|endoftext|>", which is not
# in our vocab, so set all special tokens explicitly
tokenizer = GPT2TokenizerFast.from_pretrained("tokenizer")
tokenizer.pad_token = "<pad>"
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.unk_token = "<unk>"
tokenizer.save_pretrained("chatgpt2-mini")
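
A quick round-trip check (my addition) confirms the tokenizer works and that its special-token IDs line up with config.json:

ids = tokenizer("你好,世界")["input_ids"]
print(ids)                    # byte-level BPE token IDs
print(tokenizer.decode(ids))  # should round-trip to the input text
print(tokenizer.pad_token_id, tokenizer.bos_token_id, tokenizer.eos_token_id)
# expected: 0 1 2, matching pad/bos/eos in config.json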

✅ 1.3 Training script (using 🤗 Transformers)

from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling

# Load the tokenizer and build an untrained model from the config
tokenizer = GPT2TokenizerFast.from_pretrained("chatgpt2-mini")
config = GPT2Config.from_pretrained("config.json")
model = GPT2LMHeadModel(config)

# Load the training data; note that TextDataset is deprecated in recent
# transformers releases (the datasets library is the modern replacement)
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="corpus.txt",
    block_size=128
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)

training_args = TrainingArguments(
    output_dir="./chatgpt2-mini",
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    save_steps=100,
    save_total_limit=2,
    logging_steps=10
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset
)

trainer.train()
trainer.save_model("chatgpt2-mini")
tokenizer.save_pretrained("chatgpt2-mini")
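
Once training finishes, a minimal generation smoke test (a sketch I'm adding, using the chatgpt2-mini directory saved above) looks like this; don't expect coherent text from a model this small trained on a tiny corpus:

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("chatgpt2-mini")
model = GPT2LMHeadModel.from_pretrained("chatgpt2-mini")
model.eval()

input_ids = tokenizer("你好", return_tensors="pt").input_ids
output_ids = model.generate(
    input_ids,
    max_new_tokens=32,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))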

🚀 Part 2: Multithreaded SSE Streaming Output (Serving the Trained Model)

# filename: server.py

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import asyncio
import torch

app = FastAPI()

# ✅ Load the minimal model trained above
model_path = "./chatgpt2-mini"
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
if tokenizer.pad_token is None:  # our tokenizer already defines <pad>; fallback only
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

class ChatRequest(BaseModel):
    prompt: str
    stream: bool = True
    max_new_tokens: int = 128
    temperature: float = 0.7

@app.post("/chat")
async def chat(request: ChatRequest):
    if request.stream:
        return EventSourceResponse(predict_stream(request))
    else:
        output = await predict_once(request)
        return {"response": output}

async def predict_stream(request: ChatRequest):
    prompt = request.prompt
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        do_sample=True,
        streamer=streamer
    )

    # run_in_executor does not forward keyword arguments, so wrap generate in a
    # lambda; generation then runs in a worker thread without blocking the loop
    loop = asyncio.get_running_loop()
    _ = loop.run_in_executor(None, lambda: model.generate(**generation_kwargs))

    async for token in streamer_async_wrapper(streamer):
        yield {"data": token}

    yield {"event": "end", "data": "[DONE]"}

_SENTINEL = object()

async def streamer_async_wrapper(streamer):
    # next() raising StopIteration inside an executor does not propagate
    # cleanly across the await boundary, so use a sentinel default instead
    loop = asyncio.get_running_loop()
    while True:
        value = await loop.run_in_executor(None, next, streamer, _SENTINEL)
        if value is _SENTINEL:
            break
        yield value

async def predict_once(request: ChatRequest):
    input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids.to(model.device)
    output_ids = await asyncio.to_thread(
        model.generate,
        input_ids,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        do_sample=True
    )
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response[len(request.prompt):] if response.startswith(request.prompt) else response

🧪 How to test:

# terminal 1: start the server
uvicorn server:app --port 8000

# terminal 2: stream a request (-N disables curl's output buffering)
curl -N -X POST http://localhost:8000/chat \
  -H "Content-Type: application/json" \
  -d '{"prompt": "你好", "stream": true}'

🧠 Summary

Component overview:

  • Tokenizer: BPE trained with tokenizers, exported as a fast tokenizer
  • Model architecture: 4-layer GPT2 with a 256-dim hidden size (GPT2LMHeadModel)
  • Data: any Chinese corpus, one sample per line
  • Inference service: FastAPI + TextIteratorStreamer + SSE
  • Non-blocking multithreading: loop.run_in_executor()