发布时间:2024-10-15 13:57:39
本内容由, 集智数据集收集发布,仅供参考学习,不代表集智官方赞同其观点或证实其内容的真实性,请勿用于商业用途。
SQuAD(StanfordQuestionAnsweringDataset)是一个常用的机器阅读理解数据集,包含了问题、上下文以及答案。我们将使用HuggingFace的transformers库加载预训练的RoBERTa模型,并在SQuAD2.0数据集上进行预测。
你可以从以下链接下载SQuAD 2.0数据集:
首先,确保你已经安装了transformers
、datasets
和torch
库:
pip install transformers torch datasets
import torch
from transformers import RobertaTokenizer, RobertaForQuestionAnswering
from datasets import load_dataset
# 加载RoBERTa模型和Tokenizer
model_name = "deepset/roberta-base-squad2" # RoBERTa的SQuAD 2.0版本
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForQuestionAnswering.from_pretrained(model_name)
# 加载SQuAD 2.0数据集
dataset = load_dataset("squad_v2")
# 选择一个示例(SQuAD 2.0包含问题、上下文和答案)
example = dataset['validation'][0]
context = example['context']
question = example['question']
# 打印示例问题和上下文
print(f"问题: {question}")
print(f"上下文: {context[:200]}...") # 打印上下文的一部分
# Tokenize输入内容
inputs = tokenizer(question, context, return_tensors="pt", truncation=True, padding=True)
# 模型预测答案
with torch.no_grad():
outputs = model(**inputs)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
# 获取答案的起始和结束位置
start_idx = torch.argmax(start_logits)
end_idx = torch.argmax(end_logits)
# 解码答案
answer = tokenizer.decode(inputs['input_ids'][0][start_idx:end_idx + 1])
print(f"预测的答案: {answer}")
transformers
、torch
、datasets
。代码问题: What is the purpose of the Dnieper River?
上下文: The Dnieper River is one of the major rivers of Europe, rising in the Valdai Hills near Smolensk, Russia...
预测的答案: major rivers of Europe
这个代码基于通用的SQuAD 2.0数据集,易于复现,并且是基于RoBERTa的问答任务模型。
本站将定期更新分享一些python机器学习的精选代码