import torch
import torch.nn as nn
from transformers import BertTokenizerFast, BertModel
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

# Toy example data.
examples = [
    {
        "tokens": ["Short-chain", "fatty", "acids", "are", "important"],
        "ner_tags": [1, 2, 2, 0, 0],  # 0=O, 1=B-CHEM, 2=I-CHEM
        # Entity list: (start_word_idx, end_word_idx, label_id)
        "entities": [(0, 2, 1)],
        # Relation triples: (head_entity_idx, tail_entity_idx, relation_label)
        "relations": [],
    },
    {
        "tokens": ["gut", "microbiota", "regulate", "brain", "function"],
        "ner_tags": [3, 4, 0, 5, 6],  # 3=B-MICRO, 4=I-MICRO, 5=B-ORG, 6=I-ORG
        "entities": [(0, 1, 3), (3, 4, 5)],
        "relations": [(0, 1, 1)],  # relation label 1 means "AFFECTS"
    },
]

label_list = ["O", "B-CHEM", "I-CHEM", "B-MICRO", "I-MICRO", "B-ORG", "I-ORG"]
label_to_id = {label: i for i, label in enumerate(label_list)}
relation_list = ["no_relation", "AFFECTS"]
relation_to_id = {r: i for i, r in enumerate(relation_list)}

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")


class NERREDataset(Dataset):
    """Encode word-level NER/relation examples for a BERT token classifier.

    Each item is the tokenizer's encoding (padded/truncated to ``max_length``)
    extended with:

    * ``labels``: one NER tag id per sub-token. Only the FIRST sub-token of each
      word carries the word's tag; special tokens, padding and continuation
      sub-tokens get -100 so the loss ignores them.
    * ``entity_positions``: ``(first_subtoken_index, label_id)`` per entity, in
      ``ex["entities"]`` order; the index is None when the entity's first word
      was truncated away.
    * ``rel_pairs`` / ``rel_labels``: parallel lists of ``(head_pos, tail_pos)``
      and relation ids. Annotated relations are included, and (when
      ``add_negative_relations`` is true) every other ordered pair of distinct
      entities is added with the ``no_relation`` id so the relation head learns
      that class too. Pairs touching a truncated entity are skipped.
    """

    def __init__(self, examples, max_length=32, add_negative_relations=True):
        self.examples = examples
        self.max_length = max_length
        self.add_negative_relations = add_negative_relations

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]

        # Tokenize pre-split words and align NER labels to sub-tokens.
        encoding = tokenizer(
            ex["tokens"],
            is_split_into_words=True,
            return_tensors=None,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )
        word_ids = encoding.word_ids()

        # Label only the first sub-token of each word. This matches inference,
        # which reads the prediction at each word's first sub-token, and avoids
        # training continuation pieces (e.g. of a hyphenated word) as B- tags.
        labels = []
        previous_word = None
        for word_idx in word_ids:
            if word_idx is None or word_idx == previous_word:
                labels.append(-100)
            else:
                labels.append(ex["ner_tags"][word_idx])
            previous_word = word_idx
        encoding["labels"] = labels

        # Map each word index to the index of its first sub-token.
        first_token_of_word = {}
        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is not None and word_idx not in first_token_of_word:
                first_token_of_word[word_idx] = token_idx

        entities_token_pos = [
            (first_token_of_word.get(start), label)
            for (start, _end, label) in ex["entities"]
        ]
        num_entities = len(entities_token_pos)

        # Build entity pairs (token positions) and their relation labels.
        rel_pairs = []
        rel_labels = []
        annotated = set()
        for (h_idx, t_idx, r_label) in ex["relations"]:
            if not (0 <= h_idx < num_entities and 0 <= t_idx < num_entities):
                continue
            annotated.add((h_idx, t_idx))
            h_pos = entities_token_pos[h_idx][0]
            t_pos = entities_token_pos[t_idx][0]
            if h_pos is None or t_pos is None:
                continue  # entity was truncated away; it has no token vector
            rel_pairs.append((h_pos, t_pos))
            rel_labels.append(r_label)

        if self.add_negative_relations:
            no_relation_id = relation_to_id["no_relation"]
            for h_idx in range(num_entities):
                for t_idx in range(num_entities):
                    if h_idx == t_idx or (h_idx, t_idx) in annotated:
                        continue
                    h_pos = entities_token_pos[h_idx][0]
                    t_pos = entities_token_pos[t_idx][0]
                    if h_pos is None or t_pos is None:
                        continue
                    rel_pairs.append((h_pos, t_pos))
                    rel_labels.append(no_relation_id)

        # Examples without pairs keep empty lists; forward() skips them.
        encoding["entity_positions"] = entities_token_pos
        encoding["rel_pairs"] = rel_pairs
        encoding["rel_labels"] = rel_labels
        return encoding


class JointNERREModel(nn.Module):
    """BERT encoder with a token-level NER head and an entity-pair relation head.

    The relation head classifies the concatenation of the encoder outputs at
    the head and tail entities' first sub-tokens.
    """

    def __init__(self, model_name, num_ner_labels, num_rel_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.ner_classifier = nn.Linear(hidden_size, num_ner_labels)
        self.rel_classifier = nn.Linear(hidden_size * 2, num_rel_labels)

    def forward(self, input_ids, attention_mask, labels=None,
                entity_positions=None, rel_pairs=None, rel_labels=None):
        """Score NER tags for every token and, optionally, relations for entity pairs.

        Args:
            input_ids, attention_mask: [B, L] tensors fed to BERT.
            labels: optional [B, L] NER tag ids; -100 entries are ignored.
            entity_positions: accepted for call compatibility; not read here
                (pair positions come from ``rel_pairs``).
            rel_pairs: optional per-example lists of ``(head_pos, tail_pos)``
                token indices; empty lists are skipped.
            rel_labels: optional per-example lists of relation ids parallel to
                ``rel_pairs``. When None, relation logits are still computed
                but no relation loss is produced (inference).

        Returns:
            ``(loss, ner_logits, rel_logits)``. ``loss`` is the sum of the
            available NER/relation losses, or None. ``rel_logits`` is
            ``[total_pairs, num_rel_labels]`` with pairs concatenated in batch
            order, or None when no pairs were given.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # [B, L, H]
        ner_logits = self.ner_classifier(sequence_output)  # [B, L, num_ner_labels]

        loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
        loss_ner = None
        if labels is not None:
            loss_ner = loss_fct(ner_logits.view(-1, ner_logits.shape[-1]), labels.view(-1))

        rel_logits = None
        loss_rel = None
        if rel_pairs is not None:
            logits_per_example = []
            labels_per_example = []
            for bidx in range(input_ids.size(0)):
                pairs = rel_pairs[bidx]
                if len(pairs) == 0:
                    continue
                heads = torch.tensor([h for h, _ in pairs], device=sequence_output.device)
                tails = torch.tensor([t for _, t in pairs], device=sequence_output.device)
                # Gather head/tail vectors for all pairs at once: [P, 2H].
                pair_vec = torch.cat(
                    [sequence_output[bidx, heads], sequence_output[bidx, tails]], dim=-1
                )
                logits_per_example.append(self.rel_classifier(pair_vec))
                if rel_labels is not None:
                    labels_per_example.append(
                        torch.tensor(rel_labels[bidx], dtype=torch.long,
                                     device=sequence_output.device)
                    )
            if logits_per_example:
                rel_logits = torch.cat(logits_per_example, dim=0)
                if rel_labels is not None:
                    loss_rel = loss_fct(rel_logits, torch.cat(labels_per_example, dim=0))

        if loss_ner is not None and loss_rel is not None:
            loss = loss_ner + loss_rel
        elif loss_ner is not None:
            loss = loss_ner
        else:
            loss = loss_rel  # may be None
        return (loss, ner_logits, rel_logits)


def collate_fn(batch):
    """Stack encoded items into a batch dict.

    Tensor fields are right-padded to the longest item (labels with -100);
    entity/relation fields stay as per-example Python lists for the model.
    """
    input_ids = [torch.tensor(x["input_ids"]) for x in batch]
    attention_mask = [torch.tensor(x["attention_mask"]) for x in batch]
    labels = [torch.tensor(x["labels"]) for x in batch]
    max_len = max(len(ids) for ids in input_ids)

    input_ids = torch.stack(
        [torch.nn.functional.pad(ids, (0, max_len - len(ids))) for ids in input_ids])
    attention_mask = torch.stack(
        [torch.nn.functional.pad(mask, (0, max_len - len(mask))) for mask in attention_mask])
    labels = torch.stack(
        [torch.nn.functional.pad(lbl, (0, max_len - len(lbl)), value=-100) for lbl in labels])

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "entity_positions": [x["entity_positions"] for x in batch],
        "rel_pairs": [x["rel_pairs"] for x in batch],
        "rel_labels": [x["rel_labels"] for x in batch],
    }


# Build the dataset and loader.
dataset = NERREDataset(examples)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JointNERREModel("bert-base-uncased", num_ner_labels=len(label_list),
                        num_rel_labels=len(relation_list)).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(3):
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        # Relation data stays as Python lists and is passed through as-is.
        entity_positions = batch["entity_positions"]
        rel_pairs = batch["rel_pairs"]
        rel_labels = batch["rel_labels"]

        optimizer.zero_grad()
        loss, ner_logits, rel_logits = model(input_ids, attention_mask, labels,
                                             entity_positions, rel_pairs, rel_labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} loss: {loss.item():.4f}")
import torch


def inference(model, tokenizer, sentence, label_list, relation_list, device):
    """Run joint NER + relation extraction on one whitespace-split sentence.

    NER is read at the first sub-token of each word and merged into entity spans
    with BIO rules. Every ordered pair of distinct entities is then scored by the
    relation head in a single model call. Relation id 0 is treated as
    "no relation".

    Args:
        model: model whose call ``model(input_ids, attention_mask, labels,
            entity_positions, rel_pairs, rel_labels)`` returns
            ``(loss, ner_logits, rel_logits)``.
        tokenizer: fast tokenizer (``word_ids`` is required).
        sentence: input text; split on whitespace into words.
        label_list: NER tag names indexed by tag id (BIO format).
        relation_list: relation names indexed by relation id.
        device: device to run the model on.

    Returns:
        ``(decoded_ner, entities, rel_preds)``:
        * ``decoded_ner``: ``[(word, tag), ...]``.
        * ``entities``: ``[(start_word, end_word, type), ...]``, inclusive.
        * ``rel_preds``: ``[(head_entity, tail_entity, relation_name), ...]``.
    """
    model.eval()
    tokens = sentence.split()  # simple whitespace split
    if not tokens:
        return [], [], []

    encoding = tokenizer(tokens, is_split_into_words=True, return_tensors="pt",
                         truncation=True, padding="max_length", max_length=32)
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        # NER only; no relation inputs on this pass.
        _, ner_logits, _ = model(input_ids, attention_mask)
    ner_preds = ner_logits.argmax(dim=-1)[0].cpu().tolist()
    word_ids = encoding.word_ids(batch_index=0)

    # Word-level alignment: skip special tokens and keep the prediction at the
    # first sub-token of each word (remembering that token's index).
    aligned_preds = []
    first_token_of_word = {}
    previous_word_idx = None
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            continue
        if word_idx != previous_word_idx:
            aligned_preds.append(ner_preds[idx])
            first_token_of_word.setdefault(word_idx, idx)
        previous_word_idx = word_idx

    id2label = dict(enumerate(label_list))
    decoded_ner = [(tokens[i], id2label[pred]) for i, pred in enumerate(aligned_preds)]

    # Simple BIO merge into (start, end, type) spans.
    entities = []
    start = None
    current_label = None
    for i, (_token, label) in enumerate(decoded_ner):
        if label.startswith("B-"):
            if start is not None:
                entities.append((start, i - 1, current_label))
            start = i
            current_label = label[2:]
        elif label.startswith("I-") and current_label == label[2:]:
            continue
        else:
            if start is not None:
                entities.append((start, i - 1, current_label))
            start = None
            current_label = None
    if start is not None:
        entities.append((start, len(decoded_ner) - 1, current_label))

    print("Entities extracted:")
    for s, e, l in entities:
        print(f"  {' '.join(tokens[s:e+1])} - {l} (tokens {s}-{e})")

    # First sub-token index of each entity's start word, used as its vector.
    entity_positions = [(first_token_of_word.get(s), label) for (s, _, label) in entities]

    # Score all ordered pairs of distinct entities in ONE forward pass.
    pair_indices = [
        (i, j)
        for i in range(len(entity_positions))
        for j in range(len(entity_positions))
        if i != j
        and entity_positions[i][0] is not None
        and entity_positions[j][0] is not None
    ]
    rel_preds = []
    if pair_indices:
        pairs = [(entity_positions[i][0], entity_positions[j][0]) for i, j in pair_indices]
        with torch.no_grad():
            # rel_pairs / rel_labels are per-example lists (batch of one). The
            # dummy zero labels only satisfy forward() signatures that require
            # them; the returned loss is discarded.
            _, _, rel_logits = model(input_ids, attention_mask, None,
                                     [entity_positions], [pairs], [[0] * len(pairs)])
        if rel_logits is not None:
            pred_ids = rel_logits.argmax(dim=-1).cpu().tolist()
            for (i, j), pred_rel_id in zip(pair_indices, pred_ids):
                if pred_rel_id != 0:  # 0 means no_relation
                    rel_preds.append((entities[i], entities[j], relation_list[pred_rel_id]))

    print("\nRelations predicted:")
    if len(rel_preds) == 0:
        print("  None")
    else:
        for (e1, e2, rel) in rel_preds:
            e1_text = ' '.join(tokens[e1[0]:e1[1]+1])
            e2_text = ' '.join(tokens[e2[0]:e2[1]+1])
            print(f"  ({e1_text}) --[{rel}]--> ({e2_text})")

    return decoded_ner, entities, rel_preds


# Example usage. Assumes `model` and `tokenizer` are already loaded and on `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
label_list = ["O", "B-CHEM", "I-CHEM", "B-MICRO", "I-MICRO", "B-ORG", "I-ORG"]
relation_list = ["no_relation", "AFFECTS"]

sentence = "Short-chain fatty acids affect brain function"
decoded_ner, entities, relations = inference(model, tokenizer, sentence, label_list, relation_list, device)
entities  # notebook-style display of the extracted entities