# BERT NER

# 最后发布时间：2025-08-05 10:56:53
import torch
import torch.nn as nn
from transformers import BertTokenizerFast, BertModel

# === 1. Label setup ===
# BIO tag set with a single entity type (PER); list order fixes the class ids.
labels = ["O", "B-PER", "I-PER"]
label2id = dict(zip(labels, range(len(labels))))
# Inverse lookup (class id -> tag), built from label2id.
id2label = {class_id: tag for tag, class_id in label2id.items()}
num_labels = len(labels)

# === 2. Input data ===
# One pre-split sentence with exactly one BIO tag per word.
sentence = ["Hello", "John", "Doe"]
ner_tags = ["O", "B-PER", "I-PER"]

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# is_split_into_words=True: `sentence` is already a list of words, so the
# tokenizer only splits words into subwords and keeps the word mapping
# that `encoding.word_ids()` exposes below.
encoding = tokenizer(
    sentence,
    is_split_into_words=True,
    return_tensors="pt",
    padding=True,
    truncation=True,
)

input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]  # input_ids: [1, seq_len]

# Align word-level tags to subword tokens.
# Only the FIRST subword of each word carries that word's tag. Both kinds of
# position below get -100, which CrossEntropyLoss(ignore_index=-100) skips:
#   - special tokens ([CLS], [SEP]), whose word id is None;
#   - continuation subwords, which repeat the previous word id.
# Copying a "B-" tag onto every subword would yield several entity starts
# for a single word, which is not a valid BIO sequence.
word_ids = encoding.word_ids(batch_index=0)
label_ids = []
previous_word_idx = None
for word_idx in word_ids:
    if word_idx is None or word_idx == previous_word_idx:
        label_ids.append(-100)
    else:
        label_ids.append(label2id[ner_tags[word_idx]])
    previous_word_idx = word_idx

labels_tensor = torch.tensor([label_ids])  # shape: [1, seq_len]

# === 3. Model definition (BERT encoder + linear classifier) ===
class BERTForNER(nn.Module):
    """Per-token classifier built on a pretrained BERT encoder.

    The encoder's final hidden state for every token is mapped to
    `num_labels` scores by a single linear layer.
    """

    def __init__(self, model_name, num_labels):
        """Load the pretrained encoder and create the classification layer.

        Args:
            model_name: name or path passed to `BertModel.from_pretrained`.
            num_labels: number of output classes per token.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return per-token logits of shape [batch, seq_len, num_labels]."""
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoder_out.last_hidden_state  # [batch, seq_len, hidden_size]
        return self.classifier(token_states)

# === 4. Model, loss function and optimizer ===
model = BERTForNER("bert-base-uncased", num_labels)
# Positions labelled -100 during alignment contribute nothing to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# === 5. One training step (forward -> loss -> backward -> update) ===
model.train()  # training mode (affects layers such as dropout)

# Forward pass
logits = model(input_ids=input_ids, attention_mask=attention_mask)  # [1, seq_len, num_labels]

# CrossEntropyLoss takes [N, C] scores and [N] targets, so merge the batch and
# sequence dimensions. New names are used on purpose: rebinding `labels` here
# would clobber the BIO tag list defined in step 1.
flat_logits = logits.view(-1, num_labels)  # [batch * seq_len, num_labels]
flat_labels = labels_tensor.view(-1)       # [batch * seq_len]

loss = loss_fn(flat_logits, flat_labels)
print("Loss:", loss.item())

# Backward pass
loss.backward()

# Parameter update
optimizer.step()

# Clear gradients so the next step starts from zero
optimizer.zero_grad()