HuggingFace
希望参考 OpenAI 官方的 GPT-2 GitHub 仓库,训练一个最小可运行、资源占用低的 demo 模型。我可以为你提供一个完整的、可运行的 GPT-2 最小训练方案,在单张普通显卡(如 RTX 3060)上也能运行。
训练一个 最小版 GPT-2 模型(例如:117M 或更小的微型版本),在本地完成 end-to-end 流程,包括:
虽然 OpenAI/gpt-2 是原始实现,但:
pip install transformers datasets accelerate bitsandbytes
如果你没有 GPU,建议用 transformers 的 Trainer 启动 CPU 训练(虽然慢)。
transformers
Trainer
创建一个小的文本文件(比如 data/train.txt):
data/train.txt
Once upon a time, there was a small fox. It lived in the forest and loved to play. GPT models can generate text. Hello world! 你好,世界!
这就是你训练 GPT 模型的起点语料。
下面是完整训练代码(保存为 train_gpt2_mini.py):
train_gpt2_mini.py
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling import os # ✅ 新建最小 GPT-2 配置(非常小,适合快速测试) config = GPT2Config( vocab_size=50257, n_positions=128, n_ctx=128, n_embd=128, n_layer=2, n_head=2 ) # ✅ 初始化模型和 tokenizer(使用 GPT-2 的 tokenizer) tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") model = GPT2LMHeadModel(config) # ✅ 准备训练数据(huggingface TextDataset 会自动切块) train_file = "data/train.txt" dataset = TextDataset( tokenizer=tokenizer, file_path=train_file, block_size=128 ) data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False ) # ✅ 训练参数 training_args = TrainingArguments( output_dir="./gpt2-mini-model", overwrite_output_dir=True, per_device_train_batch_size=2, num_train_epochs=3, save_steps=50, save_total_limit=2, logging_steps=10, logging_dir="./logs" ) # ✅ 训练器 trainer = Trainer( model=model, args=training_args, data_collator=data_collator, train_dataset=dataset ) # ✅ 启动训练 trainer.train() # ✅ 保存模型 trainer.save_model("./gpt2-mini-model") tokenizer.save_pretrained("./gpt2-mini-model")
from transformers import GPT2LMHeadModel, GPT2Tokenizer model = GPT2LMHeadModel.from_pretrained("./gpt2-mini-model") tokenizer = GPT2Tokenizer.from_pretrained("./gpt2-mini-model") prompt = "Once upon a time" input_ids = tokenizer(prompt, return_tensors="pt").input_ids output = model.generate(input_ids, max_new_tokens=50) print(tokenizer.decode(output[0]))
. ├── data/ │ └── train.txt # 小语料 ├── train_gpt2_mini.py # 训练脚本 ├── gpt2-mini-model/ # 训练输出模型 ├── logs/ # tensorboard日志(可选)
fp16=True
TrainingArguments
peft
GPT2Config(n_layer=2, n_embd=128)