from typing import List, Tuple, Dict
from transformers import AutoModelForPreTraining, AutoTokenizer, \
AutoModelForSequenceClassification, AdamW
from torch.utils.data import DataLoader, Subset
from sentence_transformers import SentenceTransformer
from functools import partial
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities.memory import garbage_collection_cuda

class BaseModel(torch.nn.Module):
    """
    The base class for a model.

    Attributes:
        model: The underlying model. Must be callable: it takes arbitrary
            args/kwargs and returns an output.
        name: String identifier for the model. Lowercase letters and numbers;
            no spaces or special characters except dashes.
        description: String description of the model.
        tasks: List of supported task strings.
        details: Dictionary of additional details about the model.
    """
def __init__(self, model, name: str = None, description: str = None, tasks: List[str] = None,
details: Dict = None):
torch.nn.Module.__init__(self)
self.model = model
self.name = name or "base-model"
self.description = description or "This is the base description. Change me."
self.tasks = tasks or [] # Supports no tasks
self.details = details or {}
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
    def finetune(self, *args, **kwargs):
        raise NotImplementedError("This model does not support finetuning")
    def to(self, device):
        self.model.to(device)
        self._model_device = device
        return self
    def train(self, mode: bool = True):
        self.model.train(mode)
        return self
    def eval(self):
        self.model.eval()
        return self
    def num_parameters(self):
        return sum(p.numel() for p in self.parameters())
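
# Usage sketch (hedged): BaseModel can wrap any callable; the lambda below is an
# illustrative stand-in, not something this module ships with.
#
#     wrapped = BaseModel(lambda x: x.upper(), name="upper-model",
#                         tasks=["text-transformation"])
#     wrapped("hello")          # -> "HELLO"
#     wrapped.num_parameters()  # -> 0, since a plain callable has no parameters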

class PathModel(BaseModel):
    """
    Class for models which are initialised from a path.

    Attributes:
        model_path: Path to the model.
        name: String identifier for the model. Lowercase letters and numbers;
            no spaces or special characters except dashes.
        description: String description of the model.
        tasks: List of supported task strings.
        details: Dictionary of additional details about the model.
        init_model: Callable that initialises the model from model_path.
        tokenizer_path (optional): Path to the tokenizer.
        init_tokenizer (optional): Callable that initialises the tokenizer
            from tokenizer_path.
        device (optional): Device for inference. Defaults to "cuda" if available.
    """
def __init__(self, model_path, init_model, name: str = None,
description: str = None, tasks: List[str] = None,
details: Dict = None, tokenizer_path=None,
init_tokenizer=None, device=None):
BaseModel.__init__(self, None, name=name, description=description, tasks=tasks, details=details)
self.init_model = init_model
self.init_tokenizer = init_tokenizer
self.model_path = model_path
self.tokenizer_path = tokenizer_path
self._model_device = device
if self._model_device is None:
self._model_device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialise
self.model = self.init_model(model_path).eval().to(self._model_device)
# Not all models need tokenizers
if self.tokenizer_path:
self.tokenizer = self.init_tokenizer(self.tokenizer_path)
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
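
# Usage sketch (hedged): PathModel defers loading to the supplied callables.
# AutoModelForPreTraining.from_pretrained is one valid choice already imported
# above; "gpt2" is an illustrative checkpoint, not one this module depends on.
#
#     pm = PathModel("gpt2",
#                    init_model=AutoModelForPreTraining.from_pretrained,
#                    tokenizer_path="gpt2",
#                    init_tokenizer=AutoTokenizer.from_pretrained)
#     # __init__ loads the model, puts it in eval mode and moves it to the device.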

class HFModel(PathModel):
    """
    Class for Hugging Face models.

    Attributes:
        model_path: Local path or huggingface.co identifier of the model.
        tokenizer_path (optional): Path to the tokenizer. Defaults to model_path.
        name: String identifier for the model. Lowercase letters and numbers;
            no spaces or special characters except dashes.
        description: String description of the model.
        tasks: List of supported task strings.
        details: Dictionary of additional details about the model.
        model_class: Class whose .from_pretrained callable initialises the model.
            Defaults to AutoModelForPreTraining from transformers.
        tokenizer_class: Class whose .from_pretrained callable initialises the
            tokenizer. Defaults to AutoTokenizer from transformers.
        device (optional): Device for inference. Defaults to "cuda" if available.
    """
def __init__(self, model_path, tokenizer_path=None, name: str = None,
description: str = None, tasks: List[str] = None, details: Dict = None,
model_class=AutoModelForPreTraining,
tokenizer_class=AutoTokenizer, device=None):
# Usually the same
if not tokenizer_path:
tokenizer_path = model_path
        # Object was created with init=False: reuse the stored paths and callables
        if hasattr(self, "initialised"):
            model_path = self.model_path
            tokenizer_path = self.tokenizer_path
            init_model = self.init_model
            init_tokenizer = self.init_tokenizer
            device = self._model_device
else:
init_model = model_class.from_pretrained
init_tokenizer = tokenizer_class.from_pretrained
        PathModel.__init__(self, model_path, name=name, description=description,
                           tasks=tasks,
                           details=details,
                           tokenizer_path=tokenizer_path,
                           init_model=init_model,
                           init_tokenizer=init_tokenizer,
                           device=device)
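
# Usage sketch (hedged): HFModel fills in the .from_pretrained callables itself;
# "bert-base-uncased" is an illustrative hub checkpoint.
#
#     hf = HFModel("bert-base-uncased",
#                  model_class=AutoModelForSequenceClassification)
#     # tokenizer_path defaults to model_path, so the matching tokenizer is loaded.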

class HFTextGenerationModel(HFModel):
    """
    Class for Hugging Face models that implement the .generate method.

    Attributes:
        *args and **kwargs are passed to HFModel's __init__.
    """
    def generate(self, text, variant="seq2seq", **kwargs):
        """
        Generate text with the model's .generate method.

        Args:
            text: Prompt string, or list of prompt strings.
            variant: "seq2seq" for encoder-decoder models, or "causal_lm" for
                decoder-only models whose output echoes the prompt tokens.
            **kwargs: Generation parameters forwarded to model.generate.
        """
DEFAULT_MIN_LENGTH = 10
DEFAULT_MAX_LENGTH = 20
        if not text:
            raise ValueError("Some text must be provided")
        # Pop do_sample, defaulting to False
        do_sample = kwargs.pop("do_sample", None) or False
        params = ["temperature", "top_k", "top_p", "repetition_penalty",
                  "length_penalty", "num_beams", "num_return_sequences", "num_generations"]
        # If any of these parameters is set, switch sampling on
        for param in params:
            if kwargs.get(param) is not None:
                do_sample = True
                break
if "temperature" in kwargs:
# No sampling
if kwargs["temperature"] == 0.0:
do_sample = False
del kwargs["temperature"]
        # "num_generations" is an alias: rename it to transformers' num_return_sequences
        num_generations = kwargs.pop("num_generations", None)
        if num_generations is not None:
            kwargs["num_return_sequences"] = num_generations
min_length = kwargs.pop("min_length", DEFAULT_MIN_LENGTH)
max_length = kwargs.pop("max_length", DEFAULT_MAX_LENGTH)
if min_length is None:
min_length = DEFAULT_MIN_LENGTH
if max_length is None:
max_length = DEFAULT_MAX_LENGTH
        is_list = isinstance(text, list)
        if not is_list:
            text = [text]
        output = []
        for prompt in text:
            features = self.tokenizer(prompt, return_tensors="pt")
            input_length = len(features["input_ids"][0])
            for k, v in features.items():
                features[k] = v.to(self._model_device)
            # Causal LMs echo the prompt, so budget the lengths relative to the
            # input; compute per prompt rather than mutating min/max_length,
            # which would accumulate across a list of prompts
            if variant == "causal_lm":
                gen_min_length = min_length + input_length
                gen_max_length = max_length + input_length
            else:
                gen_min_length = min_length
                gen_max_length = max_length
            with torch.no_grad():
                tokens = self.model.generate(do_sample=do_sample,
                                             min_length=gen_min_length,
                                             max_length=gen_max_length,
                                             **features, **kwargs)
            if variant == "causal_lm":
                # Strip the echoed prompt tokens before decoding
                output.append([self.tokenizer.decode(token_seq[input_length:], skip_special_tokens=True)
                               for token_seq in tokens])
            else:
                output.append([self.tokenizer.decode(token_seq, skip_special_tokens=True)
                               for token_seq in tokens])
        # Unwrap the inner generation list when there is one sequence per prompt
        if kwargs.get("num_return_sequences", 1) == 1:
            output = [generations[0] for generations in output]
        # Return a single item when a single string was passed in
        if not is_list:
            output = output[0]
        return output
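
# Usage sketch (hedged): "gpt2" is an illustrative checkpoint. With a decoder-only
# model, variant="causal_lm" budgets min/max_length relative to the prompt and
# strips the echoed prompt from the decoded output. Setting temperature (or any
# other parameter in the list above) switches sampling on.
#
#     gen = HFTextGenerationModel("gpt2")
#     gen.generate("The weather today is", variant="causal_lm",
#                  max_length=15, temperature=0.7, num_generations=2)
#     # -> a list of 2 continuations; a list of prompts returns one list per prompt.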