Source code for backprop.models.clip.module

import torch
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler
import pytorch_lightning as pl
import numpy as np

from PIL import Image
from typing import Union, List, Dict
from functools import partial
from . import clip, simple_tokenizer
from backprop.models import BaseModel
from backprop.utils import ImageTextGroupDataset, base64_to_img
from backprop.utils.losses import TripletLoss

from io import BytesIO
import base64
import random
from torch.utils.data.dataloader import DataLoader
import os

class CLIP(BaseModel):
    """
    CLIP is a vision-language model by OpenAI that scores images against text
    labels and produces image and text embeddings.

    Attributes:
        model_path: one of ViT-B/32, RN50, RN101, RN50x4
        init_model: initialises the model from model_path
        init_tokenizer: initialises the tokenizer
        name: string identifier for the model. Lowercase letters and numbers.
            No spaces/special characters except dashes.
        description: String description of the model.
        tasks: List of supported task strings
        details: Dictionary of additional details about the model
        device: Device for model. Defaults to "cuda" if available.
    """
    def __init__(self, model_path="ViT-B/32", init_model=clip.load,
                 init_tokenizer=simple_tokenizer.SimpleTokenizer,
                 name: str = None, description: str = None,
                 tasks: List[str] = None, details: Dict = None,
                 device=None):
        BaseModel.__init__(self, None, name=name, description=description,
                           tasks=tasks, details=details)

        self.init_model = init_model
        self.init_tokenizer = init_tokenizer
        self.model_path = model_path
        self._model_device = device

        if self._model_device is None:
            self._model_device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialise model, preprocessing transform and tokenizer
        self.model, self.transform = self.init_model(model_path, device=self._model_device)
        tokenizer = self.init_tokenizer()
        self.tokenizer = partial(clip.tokenize, tokenizer)

        self.process_image = self.transform
        self.optimal_batch_size = 128

        # Can't specify max_length
        self.max_length = None

        self.pre_finetuning = self.model.float
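    # Usage sketch (not part of the module; module path per this file, model
    # names per the docstring above):
    #
    #   from backprop.models.clip.module import CLIP
    #   model = CLIP(model_path="ViT-B/32")   # or "RN50", "RN101", "RN50x4"
    #   model.optimal_batch_size              # 128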
    @staticmethod
    def list_models():
        from .models_list import models

        return models
    @torch.no_grad()
    def __call__(self, task_input, task="image-classification", return_tensor=False):
        """
        Do inference with the model.

        Args:
            task_input: input dictionary according to task
            task: one of supported tasks
            return_tensor: return a tensor instead of list for vectorisation output
        """
        output = None
        is_list = False

        if task == "image-classification":
            image = task_input.get("image")
            labels = task_input.get("labels")
            top_k = task_input.get("top_k", 10000)

            image = base64_to_img(image)

            if labels is None:
                raise ValueError("labels must be provided")

            is_list = type(image) == list

            if not is_list:
                image = [image]
                labels = [labels]

            image = [self.process_image(img).unsqueeze(0).to(self._model_device) for img in image]
            text = [self.tokenizer(l).to(self._model_device) for l in labels]

            output = self.image_classification(image=image, text=text,
                                               labels=labels, top_k=top_k)
        elif task == "image-vectorisation":
            image = task_input.get("image")
            image = base64_to_img(image)

            is_list = type(image) == list

            if not is_list:
                image = [image]

            image = [self.process_image(img) for img in image]
            image = torch.stack(image).to(self._model_device)

            img_vecs = self.image_vectorisation(image=image)

            if not return_tensor:
                img_vecs = img_vecs.tolist()

            output = img_vecs
        elif task == "text-vectorisation":
            text = task_input.get("text")

            is_list = type(text) == list

            if not is_list:
                text = [text]

            text = self.tokenizer(text).to(self._model_device)

            text_vecs = self.text_vectorisation(text=text)

            if not return_tensor:
                text_vecs = text_vecs.tolist()

            output = text_vecs
        elif task == "image-text-vectorisation":
            image = task_input.get("image")
            text = task_input.get("text")

            image = base64_to_img(image)

            is_list = type(image) == list

            if not is_list:
                image = [image]
                text = [text]

            text = self.tokenizer(text).to(self._model_device)
            image = [self.process_image(img) for img in image]
            image = torch.stack(image).to(self._model_device)

            img_text_vecs = self.image_text_vectorisation(image, text)

            if not return_tensor:
                img_text_vecs = img_text_vecs.tolist()

            output = img_text_vecs

        if not is_list:
            output = output[0]

        return output
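    # Inference sketch (hypothetical inputs; `img_b64` stands for a
    # base64-encoded image string accepted by backprop.utils.base64_to_img):
    #
    #   probs = model({"image": img_b64,
    #                  "labels": ["a photo of a cat", "a photo of a dog"]},
    #                 task="image-classification")
    #   text_vec = model({"text": "a red bicycle"}, task="text-vectorisation")
    #   img_vec = model({"image": img_b64}, task="image-vectorisation",
    #                   return_tensor=True)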
    def training_step(self, params, task):
        if task == "image-vectorisation":
            image = params["image"]
            return self.image_vectorisation(image)
        elif task == "text-vectorisation":
            text = params["text"]
            return self.text_vectorisation(text)
        elif task == "image-text-vectorisation":
            image = params["image"]
            text = params["text"]
            return self.image_text_vectorisation(image, text)
    def process_batch(self, params, task):
        if task == "image-vectorisation":
            image = params["image"]
            return self.process_image(Image.open(image)).squeeze(0)
        elif task == "text-vectorisation":
            text = params["text"]
            return self.tokenizer(text)
    def image_classification(self, image: torch.TensorType, text: torch.TensorType,
                             labels, top_k=10000):
        # Runs CLIP once per image against its tokenized labels and returns a
        # label -> probability dict for each image, sorted by probability.
        probabilities = []

        inputs = zip(image, text, labels)
        for image, text, labels in inputs:
            logits_per_image, logits_per_text = self.model(image, text)
            probs = logits_per_image.softmax(dim=-1)
            probs = probs.tolist()[0]

            label_probs = zip(labels, probs)
            label_probs = {lp[0]: lp[1] for lp in label_probs}
            label_probs = sorted(label_probs.items(), key=lambda x: x[1], reverse=True)
            label_probs = {k: v for k, v in label_probs[:top_k]}

            probabilities.append(label_probs)

        return probabilities
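    # For a single image with labels ["cat", "dog"], the returned list looks
    # roughly like [{"cat": 0.93, "dog": 0.07}] (illustrative numbers only).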
    def image_vectorisation(self, image: torch.TensorType):
        image_features = self.model.encode_image(image)

        return image_features
    def text_vectorisation(self, text: torch.TensorType):
        text = self.model.encode_text(text)

        return text
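    # Note: image_vectorisation and text_vectorisation return raw encoder
    # features; only image_text_vectorisation below L2-normalises its output.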
    def image_text_vectorisation(self, image: torch.TensorType, text: torch.TensorType):
        image_vecs = self.model.encode_image(image)
        text_vecs = self.model.encode_text(text)

        img_text_vecs = torch.cat([image_vecs, text_vecs], 1)
        img_text_vecs_norm = img_text_vecs / img_text_vecs.norm(dim=-1, keepdim=True)

        return img_text_vecs_norm
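    # Because the concatenated image-text vectors are L2-normalised, a dot
    # product between two outputs gives their cosine similarity (sketch):
    #
    #   sim = vec_a @ vec_b.T   # vec_a, vec_b from image_text_vectorisation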