Source code for backprop.utils.datasets

import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class ImageTextPairDataset(Dataset):
    """Pairs of ``(image, text)`` samples, each pair labelled with a similarity score.

    Args:
        img_text_pairs1: Iterable of ``(image, text)`` tuples for the first side
            of every pair.
        img_text_pairs2: Iterable of ``(image, text)`` tuples for the second side.
        similarity_scores: Indexable of per-pair scores; its length is the
            dataset length.
        process_batch: Callable invoked as ``process_batch(params, task=...)``
            to encode a raw text or image.
    """

    def __init__(self, img_text_pairs1, img_text_pairs2, similarity_scores, process_batch):
        super().__init__()
        # Each pairs argument is walked twice (texts first, then images), so it
        # must be re-iterable; a one-shot generator would leave the image list empty.
        self.texts1 = [text for _image, text in img_text_pairs1]
        self.texts2 = [text for _image, text in img_text_pairs2]
        self.imgs1 = [image for image, _text in img_text_pairs1]
        self.imgs2 = [image for image, _text in img_text_pairs2]
        self.similarity_scores = similarity_scores
        self.process_batch = process_batch

    def __len__(self):
        """Return the number of scored pairs."""
        return len(self.similarity_scores)

    def __getitem__(self, idx):
        """Return ``(text1, image1, text2, image2, score)`` for pair *idx*.

        Text encodings get ``squeeze(0)`` applied (to the tensor itself, or to
        every value of a mapping); image encodings are returned exactly as
        ``process_batch`` produced them.
        """
        def without_leading_dim(encoded):
            # squeeze(0) only drops dim 0 when it has size 1.
            if isinstance(encoded, torch.Tensor):
                return encoded.squeeze(0)
            # Presumably a mapping of tensors (e.g. tokenizer output) - confirm.
            return {key: value.squeeze(0) for key, value in encoded.items()}

        raw_text1 = self.process_batch({"text": self.texts1[idx]}, task="text-vectorisation")
        raw_text2 = self.process_batch({"text": self.texts2[idx]}, task="text-vectorisation")
        text1 = without_leading_dim(raw_text1)
        text2 = without_leading_dim(raw_text2)
        image1 = self.process_batch({"image": self.imgs1[idx]}, task="image-vectorisation")
        image2 = self.process_batch({"image": self.imgs2[idx]}, task="image-vectorisation")
        score = torch.tensor(self.similarity_scores[idx])
        return text1, image1, text2, image2, score
class ImagePairDataset(Dataset):
    """Pairs of images, each pair labelled with a similarity score.

    Args:
        imgs1: Indexable of images for the first side of every pair.
        imgs2: Indexable of images for the second side of every pair.
        similarity_scores: Indexable of per-pair scores; its length is the
            dataset length.
        process_batch: Callable invoked as ``process_batch(params, task=...)``
            to encode a raw image.
    """

    def __init__(self, imgs1, imgs2, similarity_scores, process_batch):
        super().__init__()
        self.imgs1, self.imgs2 = imgs1, imgs2
        self.similarity_scores = similarity_scores
        self.process_batch = process_batch

    def __len__(self):
        """Return the number of scored pairs."""
        return len(self.similarity_scores)

    def __getitem__(self, idx):
        """Return ``(image1, image2, score)`` for pair *idx*.

        Image encodings are returned unmodified (no squeezing).
        """
        encode = self.process_batch
        left = encode({"image": self.imgs1[idx]}, task="image-vectorisation")
        right = encode({"image": self.imgs2[idx]}, task="image-vectorisation")
        return left, right, torch.tensor(self.similarity_scores[idx])
class TextPairDataset(Dataset):
    """Pairs of texts, each pair labelled with a similarity score.

    Args:
        texts1: Indexable of texts for the first side of every pair.
        texts2: Indexable of texts for the second side of every pair.
        similarity_scores: Indexable of per-pair scores; its length is the
            dataset length.
        process_batch: Callable invoked as ``process_batch(params, task=...)``
            to encode a raw text.
        max_length: Forwarded to ``process_batch`` under the ``"max_length"`` key.
    """

    def __init__(self, texts1, texts2, similarity_scores, process_batch, max_length=None):
        super().__init__()
        self.texts1 = texts1
        self.texts2 = texts2
        self.similarity_scores = similarity_scores
        self.process_batch = process_batch
        self.max_length = max_length

    def __len__(self):
        """Return the number of scored pairs."""
        return len(self.similarity_scores)

    def __getitem__(self, idx):
        """Return ``(text1, text2, score)`` for pair *idx*.

        Both encodings get ``squeeze(0)`` applied, either to the tensor itself
        or to every value of the returned mapping.
        """
        # Encode both sides first, then strip the leading dimension from each.
        encoded = [
            self.process_batch(
                {"text": texts[idx], "max_length": self.max_length},
                task="text-vectorisation",
            )
            for texts in (self.texts1, self.texts2)
        ]
        first, second = (
            item.squeeze(0)
            if isinstance(item, torch.Tensor)
            else {key: value.squeeze(0) for key, value in item.items()}
            for item in encoded
        )
        return first, second, torch.tensor(self.similarity_scores[idx])
class ImageTextGroupDataset(Dataset):
    """Image/text samples, each tagged with a group identifier.

    Args:
        images: Indexable of raw images; its length is the dataset length.
        texts: Indexable of raw texts aligned with ``images``.
        groups: Indexable of group ids aligned with ``images``.
        process_batch: Callable invoked as ``process_batch(params, task=...)``
            to encode a raw image or text.
    """

    def __init__(self, images, texts, groups, process_batch):
        super().__init__()
        self.images = images
        self.texts = texts
        self.groups = groups
        self.process_batch = process_batch

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, text, group)`` for sample *idx*.

        The image encoding is returned unmodified; the text encoding gets
        ``squeeze(0)`` applied (to the tensor, or to each mapping value).
        """
        image = self.process_batch({"image": self.images[idx]}, task="image-vectorisation")
        encoded_text = self.process_batch({"text": self.texts[idx]}, task="text-vectorisation")
        text = (
            encoded_text.squeeze(0)
            if isinstance(encoded_text, torch.Tensor)
            else {name: tensor.squeeze(0) for name, tensor in encoded_text.items()}
        )
        return image, text, torch.tensor(self.groups[idx])
class ImageGroupDataset(Dataset):
    """Images, each tagged with a group identifier.

    Args:
        images: Indexable of raw images; its length is the dataset length.
        groups: Indexable of group ids aligned with ``images``.
        process_batch: Callable invoked as ``process_batch(params, task=...)``
            to encode a raw image.
    """

    def __init__(self, images, groups, process_batch):
        super().__init__()
        self.images, self.groups = images, groups
        self.process_batch = process_batch

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, group)`` for sample *idx*; the image is not squeezed."""
        return (
            self.process_batch({"image": self.images[idx]}, task="image-vectorisation"),
            torch.tensor(self.groups[idx]),
        )
class TextGroupDataset(Dataset):
    """Texts, each tagged with a group identifier.

    Args:
        texts: Indexable of raw texts; its length is the dataset length.
        groups: Indexable of group ids aligned with ``texts``.
        process_batch: Callable invoked as ``process_batch(params, task=...)``
            to encode a raw text.
        max_length: Forwarded to ``process_batch`` under the ``"max_length"`` key.
    """

    def __init__(self, texts, groups, process_batch, max_length=None):
        super().__init__()
        self.texts = texts
        self.groups = groups
        self.process_batch = process_batch
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(text, group)`` for sample *idx*.

        The text encoding gets ``squeeze(0)`` applied, either to the tensor
        itself or to every value of the returned mapping.
        """
        # Pass the task explicitly, as every other text dataset in this module
        # does. Omitting it raises TypeError when process_batch has no default
        # for `task`, and otherwise silently depends on whatever that default is.
        encoded = self.process_batch(
            {"text": self.texts[idx], "max_length": self.max_length},
            task="text-vectorisation",
        )
        if isinstance(encoded, torch.Tensor):
            text = encoded.squeeze(0)
        else:
            text = {k: v.squeeze(0) for k, v in encoded.items()}
        group = torch.tensor(self.groups[idx])
        return text, group
class SingleLabelImageClassificationDataset(Dataset):
    """Images, each with exactly one class label, returned as an integer class index.

    Args:
        images: Indexable of raw images; its length is the dataset length.
        labels: Sequence of hashable labels aligned with ``images``.
        process_batch: Callable invoked as ``process_batch(params, task=...)``
            to encode a raw image.
    """

    def __init__(self, images, labels, process_batch):
        super().__init__()
        self.images = images
        self.labels = labels
        # NOTE(review): indices follow set iteration order. For str labels that
        # order can differ between interpreter runs (hash randomisation). It is
        # kept as-is because other code may build a matching mapping from
        # set(labels) the same way - confirm before changing.
        self.label_to_idx = dict(
            (label, position) for position, label in enumerate(set(labels))
        )
        self.process_batch = process_batch

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, target)`` where target is a scalar class-index tensor."""
        image = self.process_batch({"image": self.images[idx]}, task="image-classification")
        class_index = self.label_to_idx[self.labels[idx]]
        return image, torch.tensor(class_index)
class MultiLabelImageClassificationDataset(Dataset):
    """Images that each carry a list of labels, returned as multi-hot targets.

    Args:
        images: Indexable of raw images; its length is the dataset length.
        labels: Sequence of per-image label lists aligned with ``images``.
        process_batch: Callable invoked as ``process_batch(params, task=...)``
            to encode a raw image.
    """

    def __init__(self, images, labels, process_batch):
        super().__init__()
        self.images = images
        self.labels = labels
        # Flatten every per-image label list into one pool of distinct labels.
        all_labels = list(np.concatenate(labels).flat)
        self.all_labels = set(all_labels)
        self.label_to_idx = {label: i for i, label in enumerate(self.all_labels)}
        self.process_batch = process_batch

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, target)``.

        ``target`` is a float tensor of length ``len(self.all_labels)``. It holds
        1.0 at the index of each label of sample *idx* and 0.0 elsewhere.
        """
        # Read from self.images: the previous bare `images` is not a name in
        # this method's scope, so every access raised NameError.
        image = self.process_batch({"image": self.images[idx]}, task="image-classification")
        target = torch.zeros(len(self.all_labels))
        for label in self.labels[idx]:
            target[self.label_to_idx[label]] = 1.0
        return image, target
class TextToTextDataset(Dataset):
    """Generic dataset that slices per-example entries out of a parameter dict.

    Args:
        params: Dict whose ``list`` values hold one entry per example (e.g.
            inputs and outputs). All other values are shared across every
            example (e.g. a maximum input length).
        task: Task name forwarded to ``process_batch``.
        process_batch: Callable invoked as ``process_batch(params, task=task)``.
        length: Number of examples; returned by ``__len__``.
    """

    def __init__(self, params, task, process_batch, length):
        super().__init__()
        self.params = params
        self.task = task
        self.process_batch = process_batch
        self.length = length

    def __len__(self):
        """Return the configured number of examples."""
        return self.length

    def __getitem__(self, idx):
        """Encode example *idx*.

        Returns a new dict of the encoded fields, each with ``squeeze(0)``
        applied. If ``process_batch`` returns a single tensor, that tensor is
        returned with ``squeeze(0)`` applied.
        """
        # Only exact `list` values are per-example and get indexed; every other
        # value (e.g. max_input_length) is passed through unchanged.
        params = {k: (v[idx] if type(v) is list else v) for k, v in self.params.items()}
        encoded = self.process_batch(params, task=self.task)
        if isinstance(encoded, torch.Tensor):
            # A tensor is not a mapping, so the previous `{**inp}` raised
            # TypeError on this path; return the squeezed tensor itself.
            return encoded.squeeze(0)
        return {k: v.squeeze(0) for k, v in encoded.items()}
class SingleLabelTextClassificationDataset(Dataset):
    """Single-label text classification examples built from a parameter dict.

    Args:
        params: Dict whose ``list`` values hold one entry per example and whose
            other values are shared across every example.
        process_batch: Callable invoked as
            ``process_batch(params, task="text-classification")``.
        length: Number of examples; returned by ``__len__``.
    """

    def __init__(self, params, process_batch, length):
        super().__init__()
        self.params = params
        self.process_batch = process_batch
        self.length = length

    def __len__(self):
        """Return the configured number of examples."""
        return self.length

    def __getitem__(self, idx):
        """Encode example *idx* and return a shallow copy of the encoded mapping.

        Unlike the vectorisation datasets, no ``squeeze(0)`` is applied here.
        """
        example = {}
        for key, value in self.params.items():
            # Exact lists are per-example; anything else is shared as-is.
            example[key] = value[idx] if type(value) is list else value
        encoded = self.process_batch(example, task="text-classification")
        return {**encoded}