from typing import List, Dict
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from backprop.models import HFModel
import torch
class HFNLIModel(HFModel):
    """
    Class for Hugging Face sequence classification models trained on a NLI dataset

    Attributes:
        model_path: path to HF model
        tokenizer_path: path to HF tokenizer
        name: string identifier for the model. Lowercase letters and numbers.
            No spaces/special characters except dashes.
        description: String description of the model.
        tasks: List of supported task strings
        details: Dictionary of additional details about the model
        model_class: Class used to initialise model
        tokenizer_class: Class used to initialise tokenizer
        device: Device for model. Defaults to "cuda" if available.
    """

    def __init__(self, model_path=None, tokenizer_path=None, name: str = None,
                 description: str = None, tasks: List[str] = None, details: Dict = None,
                 model_class=AutoModelForSequenceClassification,
                 tokenizer_class=AutoTokenizer, device=None):
        # Fall back to the only task that __call__ implements.
        tasks = tasks or ["text-classification"]

        HFModel.__init__(self, model_path, name=name, description=description,
                         tasks=tasks, details=details, tokenizer_path=tokenizer_path,
                         model_class=model_class, tokenizer_class=tokenizer_class,
                         device=device)

    @torch.no_grad()
    def __call__(self, task_input, task="text-classification"):
        """
        Uses the model for the text-classification task

        Args:
            task_input: input dictionary according to the ``text-classification``
                task specification. Needs labels (for zero-shot). Read keys:
                ``text`` (a single item or a list of items), ``labels`` (the
                labels for a single text, or one collection of labels per
                text when ``text`` is a list) and the optional ``top_k``
                (defaults to 10000).
            task: text-classification

        Returns:
            For a list ``text``: a list with one ``{label: probability}``
            dictionary per text. Otherwise a single such dictionary.

        Raises:
            ValueError: if ``task`` is not supported, if ``labels`` is
                missing, or if the number of texts and label collections
                differ.
        """
        if task != "text-classification":
            raise ValueError(f"Unsupported task: {task}")

        text = task_input.get("text")
        labels = task_input.get("labels")
        top_k = task_input.get("top_k", 10000)

        if labels is None:
            raise ValueError("labels must be provided")

        # Normalise single-example input to a batch of one so that classify
        # only ever has to deal with lists.
        is_list = isinstance(text, list)
        if not is_list:
            text = [text]
            labels = [labels]

        # Must have a consistent amount of examples. Raised explicitly (not
        # asserted) so the check survives ``python -O``; otherwise zip() in
        # classify would silently drop the unmatched entries.
        if len(text) != len(labels):
            raise ValueError(
                f"text and labels must have the same number of examples, "
                f"got {len(text)} and {len(labels)}"
            )

        probs = self.classify(text, labels, top_k)

        # Unwrap the batch of one for single-example input.
        return probs if is_list else probs[0]

    @staticmethod
    def list_models():
        """
        Returns the ``models`` object defined in the sibling ``models_list`` module.
        """
        # Imported lazily, at call time rather than at module import time.
        from .models_list import models

        return models

    @torch.no_grad()
    def calculate_probability(self, text, labels):
        """
        Scores every label against a single text.

        Each label is turned into the hypothesis ``This example is <label>.``
        and paired with ``text``; all pairs are run through the model in one
        batch.

        Args:
            text: the text to classify.
            labels: iterable of candidate labels for ``text``.

        Returns:
            List of floats, one per label and in the order of ``labels``.
            Empty list when there are no labels.
        """
        hypotheses = [f"This example is {label}." for label in labels]
        if not hypotheses:
            # Nothing to score; avoid sending an empty batch to the tokenizer/model.
            return []

        features = self.tokenizer([text] * len(hypotheses), hypotheses, return_tensors="pt",
                                  truncation=True, padding=True).to(self._model_device)
        logits = self.model(features["input_ids"], features["attention_mask"])[0]

        # Keep only logit columns 0 and 2 and renormalise over those two.
        # NOTE(review): this presumes the model's class order is
        # (contradiction, neutral, entailment) - confirm against the config
        # of the model being loaded.
        entail_contradiction_logits = logits[:, [0, 2]]
        pair_probs = entail_contradiction_logits.softmax(dim=1)

        # Column 1 of the reduced tensor is original logit column 2.
        return pair_probs[:, 1].tolist()

    def classify(self, text, labels, top_k):
        """
        Classifies text, given a set of labels.

        Args:
            text: list of texts.
            labels: list with one collection of labels per text.
            top_k: maximum number of labels to keep per text.

        Returns:
            List with one ``{label: probability}`` dictionary per text,
            ordered from most to least probable and cut to ``top_k`` entries.
        """
        probabilities = []
        # Distinct loop names: the original rebound the ``text``/``labels``
        # parameters inside the loop, which only worked by accident of
        # zip() being evaluated first.
        for example, example_labels in zip(text, labels):
            scores = self.calculate_probability(example, example_labels)
            # Duplicate labels collapse into a single key (last score wins).
            by_label = dict(zip(example_labels, scores))
            ranked = sorted(by_label.items(), key=lambda item: item[1], reverse=True)
            probabilities.append(dict(ranked[:top_k]))

        return probabilities