Build the Neural Network

import torch.nn as nn
from transformers import BertModel
from labels import labels, label2id, id2label
import torch

class BERTForNER(nn.Module):
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.last_hidden_state)
        return logits  # [batch, seq_len, num_labels]



if __name__ == "__main__":
    num_labels = len(labels)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BERTForNER("./bert-base-uncased", num_labels).to(device)
    torch.save(model.state_dict(), "empty_ner_model.pt")
from transformers import BertTokenizerFast
from utils import NERDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from model import BERTForNER
import json
import matplotlib.pyplot as plt

# from labels import labels,label2id,id2label
# num_labels = len(labels)
# with open("data/train.json") as f:
#     input_data = json.load(f)

from dataset import id2label, label2id, train_data
num_labels = len(id2label)

# hf download bert-base-uncased --local-dir ./bert-base-uncased


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("./bert-base-uncased")
dataset = NERDataset(train_data, tokenizer, label2id)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
model = BERTForNER("./bert-base-uncased", num_labels).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # targets equal to -100 are ignored by the loss
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
epochs = 10
train_losses = []

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    model.train()
    batch_losses = []
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(logits.view(-1, num_labels), labels.view(-1))  # flatten to [batch*seq_len, num_labels] vs. [batch*seq_len]

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch_losses.append(loss.item())

    avg_train_loss = sum(batch_losses) / len(batch_losses)
    train_losses.append(avg_train_loss)
    print(f"Average Train Loss: {avg_train_loss:.4f}")

torch.save(model.state_dict(), "ner_model.pt")

plt.plot(range(1, epochs+1), train_losses, label="Train Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss Over Epochs")
plt.legend()
plt.savefig("loss_epochs.png")
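
After training, a minimal inference sketch (reusing `tokenizer`, `device`, `num_labels`, and `id2label` from the script above; the sentence is arbitrary and `id2label` is assumed to map integer ids to label strings):

# Reload the trained weights and tag one sentence token by token.
model = BERTForNER("./bert-base-uncased", num_labels).to(device)
model.load_state_dict(torch.load("ner_model.pt", map_location=device))
model.eval()

text = "Barack Obama visited Paris last year."
enc = tokenizer(text, return_tensors="pt", truncation=True).to(device)

with torch.no_grad():
    logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
pred_ids = logits.argmax(dim=-1)[0].tolist()

tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
for token, pred in zip(tokens, pred_ids):
    print(f"{token}\t{id2label[pred]}")  # id2label assumed keyed by int ids
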
import torch
import torch.nn as nn

# ----------------------------
# 1. k-mer embedding utilities
# ----------------------------
def kmerize(seq, k=3):
    """将 DNA 序列切成 k-mer"""
    kmers = [seq[i:i+k] for i in range(len(seq)-k+1)]
    return kmers

def build_kmer_vocab(k=3):
    """构建 k-mer 字典"""
    from itertools import product
    bases = ['A','C','G','T']
    kmers = [''.join(p) for p in product(bases, repeat=k)]
    vocab = {kmer:i for i,kmer in enumerate(kmers)}
    return vocab
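
# Quick sanity check of the two helpers above (illustrative only, not needed
# by the rest of the script): build_kmer_vocab(3) has 4**3 = 64 entries, and
# a 10-nt read yields 10 - 3 + 1 = 8 overlapping 3-mers.
assert len(build_kmer_vocab(k=3)) == 64
assert kmerize("ACGTGTCAGT", k=3) == ['ACG', 'CGT', 'GTG', 'TGT',
                                      'GTC', 'TCA', 'CAG', 'AGT']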

# ----------------------------
# 2. Model definition
# ----------------------------
class ReadClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.attention = nn.Linear(hidden_dim*2, 1)
        self.classifier = nn.Linear(hidden_dim*2, num_classes)
    
    def forward(self, x):
        # x: [batch, seq_len] -> k-mer indices
        emb = self.embedding(x)                 # [batch, seq_len, embed_dim]
        lstm_out, _ = self.lstm(emb)           # [batch, seq_len, hidden*2]
        
        # Self-attention
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)  # [batch, seq_len, 1]
        context = torch.sum(attn_weights * lstm_out, dim=1)            # [batch, hidden*2]
        
        logits = self.classifier(context)        # [batch, num_classes]
        return logits
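
# Attention pooling in ReadClassifier: self.attention scores every position,
# softmax over dim=1 normalizes the scores along the sequence (they sum to 1
# per read), and the weighted sum collapses [batch, seq_len, hidden*2] into a
# single [batch, hidden*2] context vector that the classifier sees.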

# ----------------------------
# 3. Example data processing
# ----------------------------
k = 3
vocab = build_kmer_vocab(k)
num_classes = 5  # assume 5 species
embed_dim = 16
hidden_dim = 32

# assume a batch of 2 reads
reads = ["ACGTGTCAGT", "TGCAGTACGT"]

def encode_reads(reads, k, vocab):
    encoded = []
    for seq in reads:
        kmers = kmerize(seq, k)
        indices = [vocab.get(kmer, 0) for kmer in kmers]  # unknown k-mer -> 0
        encoded.append(indices)
    # pad to same length
    max_len = max(len(x) for x in encoded)
    encoded = [x + [0]*(max_len - len(x)) for x in encoded]
    return torch.tensor(encoded, dtype=torch.long)
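
# Note: with k = 3 each 10-nt read above yields 8 k-mers, so X below has shape
# [2, 8]. Unknown k-mers and padding both map to index 0, which is also the
# index of 'AAA'; a dedicated <PAD>/<UNK> entry would avoid that collision.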

X = encode_reads(reads, k, vocab)  # [batch, seq_len]

# ----------------------------
# 4. Initialize the model & predict
# ----------------------------
model = ReadClassifier(vocab_size=len(vocab), embed_dim=embed_dim,
                       hidden_dim=hidden_dim, num_classes=num_classes)

logits = model(X)                        # [batch, num_classes]
pred_prob = torch.softmax(logits, dim=1)
y_pred = pred_prob.argmax(dim=1)

print("Logits:\n", logits)
print("Predicted class:\n", y_pred)
Example output (the weights are randomly initialized, so exact values will vary):

Logits:
 tensor([[-0.0808,  0.0110, -0.0984,  0.1560,  0.0818],
        [-0.0931, -0.0096, -0.0811,  0.1254,  0.1258]],
       grad_fn=<AddmmBackward0>)
Predicted class:
 tensor([3, 4])
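
To turn the class indices into readable predictions, map them through a species lookup; the names below are purely hypothetical placeholders for whatever labels your own dataset uses:

# Hypothetical species names -- replace with the labels of your own dataset.
species = ["species_0", "species_1", "species_2", "species_3", "species_4"]
for read, idx in zip(reads, y_pred.tolist()):
    print(f"{read} -> {species[idx]}")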