def classify_mentions(
    this_id: str,
    candidate_pairs: list,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    model: Optional[PreTrainedModel] = None,
    device: Optional[torch.device] = None
) -> list:
    """
    Classify candidate resource mentions with a sequence-classification model.

    Each (alias, sentence) pair is tokenized as a two-segment input, run
    through ``model`` one pair at a time, and the softmax over the logits is
    used to derive a binary prediction and its confidence.

    Args:
        this_id (str): Identifier for the publication or text being classified.
        candidate_pairs (list): List of tuples (sentence, matched_alias, resource_name).
        tokenizer (PreTrainedTokenizer, optional): The tokenizer to use.
            Required whenever ``candidate_pairs`` is non-empty.
        model (PreTrainedModel, optional): The pre-trained SciBERT model.
            Required whenever ``candidate_pairs`` is non-empty.
        device (torch.device, optional): The device to run the model on.
            When omitted, the model's own ``device`` attribute is used if it
            has one.

    Returns:
        List of prediction dictionaries (one per candidate pair) with keys:
            - prediction (int): 1 for positive, 0 for negative.
            - id (str): The provided this_id.
            - resource_name (str): The name of the resource.
            - matched_alias (str): The alias that was matched.
            - sentence (str): The sentence in which the alias was found.
            - confidence (float): Softmax probability of class 1 when the
              prediction is 1, otherwise the softmax probability of class 0.

    Raises:
        TypeError: If ``candidate_pairs`` is non-empty and ``tokenizer`` or
            ``model`` is None.
    """
    # Nothing to classify: return early so that neither a tokenizer nor a
    # model is needed for an empty input.
    if not candidate_pairs:
        return []

    # Fail fast with an explicit message instead of the opaque
    # "'NoneType' object is not callable" raised from inside the loop.
    if tokenizer is None or model is None:
        raise TypeError(
            "classify_mentions() requires both a tokenizer and a model "
            "when candidate_pairs is non-empty; got "
            f"tokenizer={type(tokenizer).__name__}, model={type(model).__name__}"
        )

    # Fall back to the device the model lives on so inputs and weights end up
    # on the same device when the caller did not pass one explicitly.
    if device is None:
        device = getattr(model, "device", None)

    predictions = []
    for sentence, alias, resource in tqdm(candidate_pairs, desc="🔍 Classifying"):
        # NOTE(review): with a single pair per forward pass, padding="max_length"
        # pads every input to 512 tokens; dynamic padding would likely be much
        # faster. Kept as-is to preserve the exact numerical outputs.
        inputs = tokenizer(
            alias,
            sentence,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        ).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            predicted_class = torch.argmax(probs, dim=1).item()

        # Anything other than class 1 is reported as negative (0), and the
        # confidence is read from the column matching the reported label.
        label = 1 if predicted_class == 1 else 0
        predictions.append({
            "prediction": label,
            "id": this_id,
            "resource_name": resource,
            "matched_alias": alias,
            "sentence": sentence,
            "confidence": probs[0, label].item(),
        })
    return predictions