Dataset: rajpurkar/squad_v2
This model is a fine-tuned version of google-bert/bert-base-uncased on the SQuAD v2 dataset. It has been trained to perform extractive question answering with the ability to detect unanswerable questions.
This model is based on the BERT base (uncased) architecture and has been fine-tuned on SQuAD v2, which extends the original SQuAD dataset with questions that cannot be answered from the provided context. For each question, the model either extracts an answer span from the context or indicates that the question cannot be answered.
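The answer-or-abstain decision described above can be illustrated with a small toy sketch. It assumes the common SQuAD v2 convention, also used in the usage code in this card, that "no answer" is scored as a span starting and ending at the [CLS] position (index 0). That null score is compared with the best valid context span, offset by a threshold. The tokens, logits, and the `best_span`/`decide` helpers are hypothetical and made up for illustration. They are not outputs or APIs of this model.

```python
# Toy illustration of the SQuAD v2 "answer or abstain" rule.
# All tokens and logits below are made up; they are NOT outputs of this model.

def best_span(start_logits, end_logits, context_positions, max_len=30):
    """Return (score, start, end) of the best span whose start and end are
    context positions with start <= end < start + max_len."""
    best = (float("-inf"), None, None)
    for s in context_positions:
        for e in context_positions:
            if s <= e < s + max_len:
                score = start_logits[s] + end_logits[e]
                if score > best[0]:
                    best = (score, s, e)
    return best


def decide(tokens, start_logits, end_logits, context_positions, threshold=0.0):
    """Hypothetical helper: return the answer text, or None to abstain."""
    null_score = start_logits[0] + end_logits[0]  # position 0 is [CLS]
    score, s, e = best_span(start_logits, end_logits, context_positions)
    if null_score - score > threshold:
        return None
    return " ".join(tokens[s:e + 1])


tokens = ["[CLS]", "who", "built", "it", "[SEP]", "nasa", "built", "apollo", "[SEP]"]
context = range(5, 8)  # positions of "nasa", "built", "apollo"

# Answerable case: the span "nasa" outscores the null span.
start = [1.0, 0, 0, 0, 0, 5.0, 0.5, 0.2, 0]
end = [1.0, 0, 0, 0, 0, 4.0, 0.3, 0.1, 0]
print(decide(tokens, start, end, context))  # nasa

# Unanswerable case: the null span dominates, so the rule abstains.
start_na = [6.0, 0, 0, 0, 0, 1.0, 0.5, 0.2, 0]
end_na = [6.0, 0, 0, 0, 0, 0.5, 0.3, 0.1, 0]
print(decide(tokens, start_na, end_na, context))  # None

# A larger threshold makes the rule more willing to answer.
print(decide(tokens, start_na, end_na, context, threshold=11.0))  # nasa
```

In the unanswerable case the null score (12.0) beats the best span (1.5) by 10.5. The rule abstains at the default threshold of 0.0 but answers once the threshold exceeds that margin. This is why the `threshold` argument in the usage code trades abstentions against answers.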
Key features:
The model was trained with the following hyperparameters:
The model achieved the following performance metrics:
Additional training statistics:
This model is intended for:
Limitations:
```python
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Load model & tokenizer
model_name = "real-jiakai/bert-base-uncased-finetuned-squadv2"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

NO_ANSWER = "Question cannot be answered based on the given context."

def get_answer_v2(question, context, threshold=0.0, max_answer_len=30):
    # Tokenize the (question, context) pair; only the context is truncated,
    # so the whole input fits in 384 tokens
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        max_length=384,
        truncation="only_second",
    )
    with torch.no_grad():
        outputs = model(**inputs)
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Null score: score for predicting "no answer" (start = end = [CLS] at index 0)
    null_score = start_logits[0].item() + end_logits[0].item()

    # Restrict answers to context tokens (sequence id 1); this excludes
    # [CLS], the question tokens, and [SEP] (requires a fast tokenizer)
    is_context = torch.tensor([sid == 1 for sid in inputs.sequence_ids(0)])
    start_logits = start_logits.masked_fill(~is_context, float("-inf"))
    end_logits = end_logits.masked_fill(~is_context, float("-inf"))

    # Score all (start, end) pairs jointly and keep only valid spans:
    # start <= end and at most max_answer_len tokens long
    scores = start_logits[:, None] + end_logits[None, :]
    n = scores.size(0)
    length = torch.arange(n)[None, :] - torch.arange(n)[:, None]
    scores = scores.masked_fill((length < 0) | (length >= max_answer_len), float("-inf"))
    start_idx, end_idx = divmod(torch.argmax(scores).item(), n)
    answer_score = scores[start_idx, end_idx].item()

    # If the null score beats the best span by more than threshold, return "no answer"
    if null_score - answer_score > threshold:
        return NO_ANSWER

    # Otherwise, decode the extracted span back to text
    answer_ids = inputs["input_ids"][0][start_idx:end_idx + 1]
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    return answer or NO_ANSWER

# Example usage
context = "The Apollo program was designed to land humans on the Moon and bring them safely back to Earth."
questions = [
    "What was the goal of the Apollo program?",
    "Who was the first person to walk on Mars?",  # Unanswerable question
    "What was the Apollo program designed to do?",
]
for question in questions:
    answer = get_answer_v2(question, context, threshold=1.0)
    print(f"Question: {question}")
    print(f"Answer: {answer}")
    print("-" * 50)
```
Base model: google-bert/bert-base-uncased