Joint NER and RE Modeling

A minimal end-to-end example of joint named entity recognition (NER) and relation extraction (RE): a shared BERT encoder feeds a per-token classifier for BIO entity tags and a relation classifier over concatenated entity start-token vectors.

import torch
import torch.nn as nn
from transformers import BertTokenizerFast, BertModel
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

# Toy data examples
examples = [
    {
        "tokens": ["Short-chain", "fatty", "acids", "are", "important"],
        "ner_tags": [1, 2, 2, 0, 0],  # 0=O, 1=B-CHEM, 2=I-CHEM
        # Entity list: (start_idx, end_idx, label_id)
        "entities": [(0, 2, 1)],
        # Relation triples: (head_entity_idx, tail_entity_idx, relation_label)
        "relations": []
    },
    {
        "tokens": ["gut", "microbiota", "regulate", "brain", "function"],
        "ner_tags": [3, 4, 0, 5, 6],  # 3=B-MICRO,4=I-MICRO,5=B-ORG,6=I-ORG
        "entities": [(0, 1, 3), (3, 4, 5)],
        "relations": [(0, 1, 1)]  # relation label 1 表示 "AFFECTS"
    }
]

label_list = ["O", "B-CHEM", "I-CHEM", "B-MICRO", "I-MICRO", "B-ORG", "I-ORG"]
label_to_id = {l: i for i, l in enumerate(label_list)}

relation_list = ["no_relation", "AFFECTS"]
relation_to_id = {r: i for i, r in enumerate(relation_list)}
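
Entity spans are word-level and end-inclusive; a quick check against the toy data above makes the encoding concrete:

# Decode each entity span back to text (spans index words, end-inclusive).
ex = examples[0]
for start, end, label_id in ex["entities"]:
    print(" ".join(ex["tokens"][start:end + 1]), "->", label_list[label_id])
# prints: Short-chain fatty acids -> B-CHEM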

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
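
Because BERT splits words into word pieces, the fast tokenizer's word_ids() is what lets us align word-level BIO tags to tokens. A quick illustration:

# word_ids() maps each token back to its source word (None for special tokens);
# the Dataset below uses it to align word-level tags to word pieces.
enc = tokenizer(["Short-chain", "fatty", "acids"], is_split_into_words=True)
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))
print(enc.word_ids())  # e.g. [None, 0, 0, 0, 1, 2, None]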

class NERREDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        # Tokenize and align word-level NER tags to word pieces
        encoding = tokenizer(ex["tokens"], is_split_into_words=True, return_tensors=None, truncation=True, padding='max_length', max_length=32)
        word_ids = encoding.word_ids()
        labels = []
        for word_idx in word_ids:
            if word_idx is None:
                labels.append(-100)
            else:
                labels.append(ex["ner_tags"][word_idx])
        encoding["labels"] = labels

        # Map each entity's start word to its first token index
        entities_token_pos = []
        for (start, end, label) in ex["entities"]:
            # First index in word_ids equal to the entity's start word
            token_start = None
            for i, wid in enumerate(word_ids):
                if wid == start:
                    token_start = i
                    break
            entities_token_pos.append((token_start, label))  # token_start stays None if the entity was truncated away

        # Build relation pairs (token positions) and their labels
        rel_pairs = []
        rel_labels = []
        for (h_idx, t_idx, r_label) in ex["relations"]:
            if h_idx < len(entities_token_pos) and t_idx < len(entities_token_pos):
                h_pos, _ = entities_token_pos[h_idx]
                t_pos, _ = entities_token_pos[t_idx]
                rel_pairs.append((h_pos, t_pos))
                rel_labels.append(r_label)
        # Examples with no relations keep empty lists; the forward pass skips them
        encoding["entity_positions"] = entities_token_pos
        encoding["rel_pairs"] = rel_pairs
        encoding["rel_labels"] = rel_labels

        return encoding
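
A quick sanity check on one encoded item (a sketch; exact token indices depend on the word-piece tokenizer):

# entity_positions holds (first token index, label id) per entity;
# rel_pairs holds token-index pairs and rel_labels their relation ids.
sample = NERREDataset(examples)[1]
print(sample["entity_positions"])  # e.g. [(1, 3), (6, 5)]
print(sample["rel_pairs"], sample["rel_labels"])  # e.g. [(1, 6)] [1]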

# Model definition
class JointNERREModel(nn.Module):
    def __init__(self, model_name, num_ner_labels, num_rel_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.ner_classifier = nn.Linear(hidden_size, num_ner_labels)
        self.rel_classifier = nn.Linear(hidden_size * 2, num_rel_labels)

    def forward(self, input_ids, attention_mask, labels=None, entity_positions=None, rel_pairs=None, rel_labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # [B, L, H]

        ner_logits = self.ner_classifier(sequence_output)  # [B, L, num_ner_labels]

        loss_fct = nn.CrossEntropyLoss()  # default ignore_index=-100 skips special/padded positions
        loss_ner = None
        loss_rel = None

        if labels is not None:
            loss_ner = loss_fct(ner_logits.view(-1, ner_logits.shape[-1]), labels.view(-1))

        rel_logits = None
        if entity_positions is not None and rel_pairs is not None:
            batch_rel_logits = []
            batch_rel_labels = []
            for bidx in range(input_ids.size(0)):
                if len(rel_pairs[bidx]) == 0:
                    continue
                pairs = rel_pairs[bidx]
                rel_lab = rel_labels[bidx]

                h_vecs = []
                t_vecs = []
                for (h_pos, t_pos) in pairs:
                    h_vecs.append(sequence_output[bidx, h_pos])
                    t_vecs.append(sequence_output[bidx, t_pos])
                h_vecs = torch.stack(h_vecs)
                t_vecs = torch.stack(t_vecs)
                pair_vec = torch.cat([h_vecs, t_vecs], dim=1)
                logits = self.rel_classifier(pair_vec)
                batch_rel_logits.append(logits)
                batch_rel_labels.append(torch.tensor(rel_lab, device=logits.device))
            if len(batch_rel_logits) > 0:
                rel_logits = torch.cat(batch_rel_logits, dim=0)
                rel_labels_tensor = torch.cat(batch_rel_labels, dim=0)
                loss_rel = loss_fct(rel_logits, rel_labels_tensor)

        loss = None
        if loss_ner is not None and loss_rel is not None:
            loss = loss_ner + loss_rel
        elif loss_ner is not None:
            loss = loss_ner

        return (loss, ner_logits, rel_logits)
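
The relation head above represents each entity by its first token's hidden state. A common alternative (not used here) is to mean-pool over the whole entity span; a minimal sketch, assuming token-level start and end indices are available:

def span_mean(sequence_output, bidx, tok_start, tok_end):
    # Mean-pool hidden states over tokens tok_start..tok_end (inclusive)
    # as a span representation, instead of the start token alone.
    return sequence_output[bidx, tok_start:tok_end + 1].mean(dim=0)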

def collate_fn(batch):
    # batch is a list of encoding dicts
    input_ids = [torch.tensor(x["input_ids"]) for x in batch]
    attention_mask = [torch.tensor(x["attention_mask"]) for x in batch]
    labels = [torch.tensor(x["labels"]) for x in batch]

    max_len = max([len(ids) for ids in input_ids])
    # pad
    input_ids = torch.stack([torch.nn.functional.pad(ids, (0, max_len - len(ids))) for ids in input_ids])
    attention_mask = torch.stack([torch.nn.functional.pad(mask, (0, max_len - len(mask))) for mask in attention_mask])
    labels = torch.stack([torch.nn.functional.pad(lbl, (0, max_len - len(lbl)), value=-100) for lbl in labels])

    # entity_positions, rel_pairs and rel_labels stay nested Python lists and are passed to the model as-is
    entity_positions = [x["entity_positions"] for x in batch]
    rel_pairs = [x["rel_pairs"] for x in batch]
    rel_labels = [x["rel_labels"] for x in batch]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "entity_positions": entity_positions,
        "rel_pairs": rel_pairs,
        "rel_labels": rel_labels,
    }

# Instantiate dataset and dataloader
dataset = NERREDataset(examples)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
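
A quick shape check on one batch (a sketch):

batch = next(iter(dataloader))
print(batch["input_ids"].shape, batch["labels"].shape)  # torch.Size([2, 32]) each
print(batch["rel_pairs"])  # nested Python lists, one per example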

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JointNERREModel("bert-base-uncased", num_ner_labels=len(label_list), num_rel_labels=len(relation_list)).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(3):
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Relation data stays as plain Python lists, not tensors
        entity_positions = batch["entity_positions"]
        rel_pairs = batch["rel_pairs"]
        rel_labels = batch["rel_labels"]

        optimizer.zero_grad()
        loss, ner_logits, rel_logits = model(input_ids, attention_mask, labels, entity_positions, rel_pairs, rel_labels)
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch} loss: {loss.item():.4f}")

def inference(model, tokenizer, sentence, label_list, relation_list, device):
    model.eval()
    tokens = sentence.split()  # simple whitespace split; you could also run the tokenizer and realign

    encoding = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", truncation=True, padding="max_length", max_length=32)
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        # NER-only forward pass; relation inputs stay None
        loss, ner_logits, rel_logits = model(input_ids, attention_mask)

    ner_preds = ner_logits.argmax(dim=-1)[0].cpu().tolist()  # predicted label ids per token
    word_ids = encoding.word_ids(batch_index=0)

    # Word-level alignment: skip special tokens and keep each word's first-token prediction
    aligned_preds = []
    previous_word_idx = None
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            continue
        if word_idx != previous_word_idx:
            aligned_preds.append(ner_preds[idx])
            previous_word_idx = word_idx

    id2label = {i: l for i, l in enumerate(label_list)}
    decoded_ner = [(tokens[i], id2label[aligned_preds[i]]) for i in range(len(aligned_preds))]

    # Simple entity extraction (merge BIO spans)
    entities = []
    start = None
    current_label = None
    for i, (token, label) in enumerate(decoded_ner):
        if label.startswith("B-"):
            if start is not None:
                entities.append((start, i - 1, current_label))
            start = i
            current_label = label[2:]
        elif label.startswith("I-") and current_label == label[2:]:
            continue
        else:
            if start is not None:
                entities.append((start, i - 1, current_label))
                start = None
                current_label = None
    if start is not None:
        entities.append((start, len(decoded_ner) - 1, current_label))

    print("Entities extracted:")
    for s, e, l in entities:
        print(f"  {' '.join(tokens[s:e+1])} - {l} (tokens {s}-{e})")

    # Map each entity's start word index back to its first token index, for relation prediction
    entity_positions = []
    for (start, _, label) in entities:
        # find the token-level start index of the entity's first word
        token_start_idx = None
        for idx, widx in enumerate(word_ids):
            if widx == start:
                token_start_idx = idx
                break
        entity_positions.append((token_start_idx, label))

    # Predict a relation for every ordered entity pair
    rel_preds = []
    for i, e1 in enumerate(entity_positions):
        for j, e2 in enumerate(entity_positions):
            if i == j:
                continue
            if e1[0] is None or e2[0] is None:
                continue  # entity start fell outside the truncated window

            with torch.no_grad():
                # Batch size is 1, so the entity/pair/label arguments are
                # one-element nested lists. The dummy label [0] only satisfies
                # the forward signature; the returned loss is discarded.
                # Alternatively, extend forward() to support label-free inference.
                _, _, rel_logits = model(input_ids, attention_mask, None,
                                         [[e1]], [[(e1[0], e2[0])]], [[0]])

            if rel_logits is not None:
                pred_rel_id = rel_logits.argmax(dim=-1).item()
                if pred_rel_id != 0:  # 0 means no_relation
                    rel_preds.append((entities[i], entities[j], relation_list[pred_rel_id]))

    print("\nRelations predicted:")
    if len(rel_preds) == 0:
        print("  None")
    else:
        for (e1, e2, rel) in rel_preds:
            e1_text = ' '.join(tokens[e1[0]:e1[1]+1])
            e2_text = ' '.join(tokens[e2[0]:e2[1]+1])
            print(f"  ({e1_text}) --[{rel}]--> ({e2_text})")

    return decoded_ner, entities, rel_preds


# Assume:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
label_list = ["O", "B-CHEM", "I-CHEM", "B-MICRO", "I-MICRO", "B-ORG", "I-ORG"]
relation_list = ["no_relation", "AFFECTS"]

# model and tokenizer are already loaded and moved to device

sentence = "Short-chain fatty acids affect brain function"
decoded_ner, entities, relations = inference(model, tokenizer, sentence, label_list, relation_list, device)
print(entities)
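
Finally, to persist the trained weights (a sketch; the filename is a placeholder):

# Save / reload the joint model's state dict.
torch.save(model.state_dict(), "joint_ner_re.pt")
# model.load_state_dict(torch.load("joint_ner_re.pt", map_location=device))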