# model.py: BERT encoder with a token-level classification head for NER.
import torch
import torch.nn as nn
from transformers import BertModel

from labels import labels


class BERTForNER(nn.Module):
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.last_hidden_state)
        return logits  # [batch, seq_len, num_labels]


if __name__ == "__main__":
    num_labels = len(labels)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BERTForNER("./bert-base-uncased", num_labels).to(device)
    torch.save(model.state_dict(), "empty_ner_model.pt")
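A quick shape check of the model, as a minimal ad-hoc sketch (the sentence and num_labels=9 are placeholder values, not from the project; it only assumes the local ./bert-base-uncased checkpoint used above):

# Smoke test: confirm BERTForNER returns one logit vector per token.
import torch
from transformers import BertTokenizerFast

from model import BERTForNER

tokenizer = BertTokenizerFast.from_pretrained("./bert-base-uncased")
model = BERTForNER("./bert-base-uncased", num_labels=9)  # 9 is an arbitrary example

enc = tokenizer(["John lives in London"], return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(logits.shape)  # torch.Size([1, seq_len, 9])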
# Training script: fine-tune BERTForNER on the NER training set and plot the loss curve.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

from model import BERTForNER
from utils import NERDataset
from dataset import id2label, label2id, train_data

num_labels = len(id2label)

# hf download bert-base-uncased --local-dir bert-base-uncased
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("./bert-base-uncased")

dataset = NERDataset(train_data, tokenizer, label2id)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

model = BERTForNER("./bert-base-uncased", num_labels).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # -100 masks padding/sub-word positions
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

epochs = 10
train_losses = []

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    model.train()
    batch_losses = []
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(logits.view(-1, num_labels), labels.view(-1))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch_losses.append(loss.item())

    avg_train_loss = sum(batch_losses) / len(batch_losses)
    train_losses.append(avg_train_loss)
    print(f"Average Train Loss: {avg_train_loss:.4f}")

torch.save(model.state_dict(), "ner_model.pt")

plt.plot(range(1, epochs + 1), train_losses, label="Train Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss Over Epochs")
plt.legend()
plt.savefig("loss_epochs.png")
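A minimal inference sketch for the saved checkpoint, not part of the original scripts: it assumes id2label maps integer ids to tag strings, the example sentence is made up, and word-piece alignment / special tokens are not handled beyond printing them.

# Inference sketch: load ner_model.pt and tag a single sentence.
import torch
from transformers import BertTokenizerFast

from model import BERTForNER
from dataset import id2label

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("./bert-base-uncased")

model = BERTForNER("./bert-base-uncased", len(id2label)).to(device)
model.load_state_dict(torch.load("ner_model.pt", map_location=device))
model.eval()

sentence = "John works at Google in London"  # illustrative input
encoding = tokenizer(sentence, return_tensors="pt", truncation=True)
encoding = {k: v.to(device) for k, v in encoding.items()}

with torch.no_grad():
    logits = model(input_ids=encoding["input_ids"],
                   attention_mask=encoding["attention_mask"])
pred_ids = logits.argmax(dim=-1)[0].tolist()  # one label id per word piece

tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
for token, pred in zip(tokens, pred_ids):
    print(token, id2label[pred])  # assumes id2label: int -> tag string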
# Toy read classifier: k-mer tokenization + bidirectional LSTM with attention pooling.
import torch
import torch.nn as nn
from itertools import product

# ----------------------------
# 1. k-mer embedding utilities
# ----------------------------
def kmerize(seq, k=3):
    """Split a DNA sequence into overlapping k-mers."""
    kmers = [seq[i:i+k] for i in range(len(seq) - k + 1)]
    return kmers

def build_kmer_vocab(k=3):
    """Build the k-mer vocabulary (all 4^k combinations of A/C/G/T)."""
    bases = ['A', 'C', 'G', 'T']
    kmers = [''.join(p) for p in product(bases, repeat=k)]
    vocab = {kmer: i for i, kmer in enumerate(kmers)}
    return vocab

# ----------------------------
# 2. Model definition
# ----------------------------
class ReadClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        # x: [batch, seq_len] -> k-mer indices
        emb = self.embedding(x)                # [batch, seq_len, embed_dim]
        lstm_out, _ = self.lstm(emb)           # [batch, seq_len, hidden*2]
        # Attention pooling: softmax over sequence positions
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)  # [batch, seq_len, 1]
        context = torch.sum(attn_weights * lstm_out, dim=1)            # [batch, hidden*2]
        logits = self.classifier(context)      # [batch, num_classes]
        return logits

# ----------------------------
# 3. Example data preparation
# ----------------------------
k = 3
vocab = build_kmer_vocab(k)
num_classes = 5   # e.g. 5 species
embed_dim = 16
hidden_dim = 32

# Toy batch of 2 reads
reads = ["ACGTGTCAGT", "TGCAGTACGT"]

def encode_reads(reads, k, vocab):
    encoded = []
    for seq in reads:
        kmers = kmerize(seq, k)
        indices = [vocab.get(kmer, 0) for kmer in kmers]  # unknown k-mer -> 0
        encoded.append(indices)
    # pad to the same length
    max_len = max(len(x) for x in encoded)
    encoded = [x + [0] * (max_len - len(x)) for x in encoded]
    return torch.tensor(encoded, dtype=torch.long)

X = encode_reads(reads, k, vocab)  # [batch, seq_len]

# ----------------------------
# 4. Initialize the model & predict
# ----------------------------
model = ReadClassifier(vocab_size=len(vocab), embed_dim=embed_dim,
                       hidden_dim=hidden_dim, num_classes=num_classes)
logits = model(X)                        # [batch, num_classes]
pred_prob = torch.softmax(logits, dim=1)
y_pred = pred_prob.argmax(dim=1)

print("Logits:\n", logits)
print("Predicted class:\n", y_pred)
Logits:
 tensor([[-0.0808,  0.0110, -0.0984,  0.1560,  0.0818],
        [-0.0931, -0.0096, -0.0811,  0.1254,  0.1258]], grad_fn=<AddmmBackward0>)
Predicted class:
 tensor([3, 4])
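The logits above come from randomly initialized weights, so the predicted classes are arbitrary. A minimal training sketch continuing from the script above (the random labels, learning rate, and step count are invented for illustration):

# Training sketch for ReadClassifier on the toy batch (reuses model, X, num_classes).
import torch
import torch.nn as nn

y = torch.randint(0, num_classes, (X.size(0),))   # fake species labels for the 2 reads

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(50):            # a few optimization steps on the toy batch
    optimizer.zero_grad()
    logits = model(X)             # [batch, num_classes]
    loss = loss_fn(logits, y)
    loss.backward()
    optimizer.step()

print("final loss:", loss.item())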