import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import numpy as np
import pandas as pd
import tqdm
# Custom dataset class for loading and preprocessing the car Q&A data
class CarQADataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        """
        Initialize the dataset
        params:
            df: pandas DataFrame containing question-answer pairs
            tokenizer: BERT tokenizer
            max_len: maximum sequence length
        """
        self.tokenizer = tokenizer
        self.questions = df['提问'].values
        self.answers = df['回答'].values
        self.max_len = max_len

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        """
        Fetch a single sample
        params:
            idx: sample index
        returns:
            dict with input IDs, attention mask, and the answer text
        """
        question = str(self.questions[idx])
        answer = str(self.answers[idx])
        # Encode the question with the BERT tokenizer
        encoding = self.tokenizer.encode_plus(
            question,
            add_special_tokens=True,   # add special tokens such as [CLS] and [SEP]
            max_length=self.max_len,
            padding='max_length',      # pad to the maximum length
            truncation=True,           # truncate overly long sequences
            return_tensors='pt'        # return PyTorch tensors
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'answer': answer
        }
# Q&A model: a BERT encoder with a classification head on top
class CarQAModel(nn.Module):
    def __init__(self, bert_model, hidden_size=768, num_classes=1000):
        """
        Initialize the model
        params:
            bert_model: pre-trained BERT model
            hidden_size: BERT hidden size
            num_classes: number of output classes
        """
        super(CarQAModel, self).__init__()
        self.bert = bert_model                         # BERT encoder
        self.dropout = nn.Dropout(0.1)                 # regularization against overfitting
        self.fc = nn.Linear(hidden_size, num_classes)  # fully connected output layer

    def forward(self, input_ids, attention_mask):
        """
        Forward pass
        params:
            input_ids: token IDs of the input sequence
            attention_mask: attention mask
        returns:
            logits produced by the model
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # pooled output corresponding to the [CLS] token
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits
# Data preparation
print("Loading data...")
df = pd.read_excel("car_data.xlsx")
# Split into training and test sets (80% train, 20% test)
train_df = df.sample(frac=0.8, random_state=42)
test_df = df.drop(train_df.index)
# Initialize the BERT tokenizer and model
print("Initializing model...")
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
bert_model = BertModel.from_pretrained('bert-base-chinese')
model = CarQAModel(bert_model)
# Create data loaders
train_dataset = CarQADataset(train_df, tokenizer)
test_dataset = CarQADataset(test_df, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)
# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
num_epochs = 5
# Training function
def train():
    """Main training loop"""
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        print(f"\nStarting Epoch {epoch+1}/{num_epochs}")
        for batch in tqdm.tqdm(train_loader, desc="Training"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Zero out the gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(input_ids, attention_mask)
            # Compute the loss (all-zero dummy labels are used here as a placeholder;
            # replace them with real labels for the actual task)
            loss = criterion(outputs, torch.zeros(outputs.size(0)).long().to(device))
            # Backward pass and optimization step
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{num_epochs}, average loss: {avg_loss:.4f}')
# Evaluation function
def evaluate():
    """Evaluate model performance"""
    model.eval()
    total_loss = 0
    print("\nStarting evaluation...")
    with torch.no_grad():
        for batch in tqdm.tqdm(test_loader, desc="Testing"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, torch.zeros(outputs.size(0)).long().to(device))
            total_loss += loss.item()
    avg_loss = total_loss / len(test_loader)
    print(f'Test set loss: {avg_loss:.4f}')
# Run training and evaluation
if __name__ == "__main__":
    print("Starting model training...")
    train()
    print("\nStarting model evaluation...")
    evaluate()
    # Save the model
    print("\nSaving model...")
    torch.save(model.state_dict(), 'car_qa_model.pth')
    print("Model saved to car_qa_model.pth")
A conversational dataset is a collection of data built specifically for training dialogue systems or chatbots, and it contains a large number of dialogue examples. These examples are typically drawn from real conversation logs or from simulated dialogue scenarios, and their purpose is to teach machine learning models to understand and generate natural, fluent conversation.
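To make the expected input format concrete, the following is a minimal sketch of how a toy version of such a Q&A dataset could be assembled for the pipeline above (assuming the '提问'/'回答' column names and the car_data.xlsx path used in the code; the placeholder rows are hypothetical and for illustration only):

# Minimal sketch: build a toy Q&A DataFrame with the columns the pipeline expects.
# The placeholder rows are hypothetical and for illustration only.
import pandas as pd

toy_df = pd.DataFrame({
    '提问': ['示例问题1', '示例问题2'],
    '回答': ['示例回答1', '示例回答2'],
})
toy_df.to_excel('car_data.xlsx', index=False)  # same path that pd.read_excel loads above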