Source code for backprop.utils.samplers

import random
import numpy as np
from torch.utils.data.sampler import Sampler


class SameGroupSampler(Sampler):
    """Sampler that yields pairs of dataset indices sharing the same group.

    ``dataset.groups`` is read as a sequence with one hashable group label
    per dataset item; the position of a label is the index of the item it
    belongs to.

    Each draw picks one group uniformly at random and then two item indices
    from that group. The two indices are distinct whenever the group holds
    at least two items; a single-item group yields its only index twice.

    Note:
        ``len(sampler)`` is the number of distinct groups, which is also the
        number of *pairs* drawn per pass. Because every pair is emitted as
        two consecutive indices, one pass yields ``2 * len(sampler)``
        indices in total.

    Args:
        dataset: Object exposing a ``groups`` attribute as described above.

    Attributes:
        groups: Set of distinct group labels.
        item_to_group: Mapping from item index to its group label.
        group_to_items: Mapping from group label to the list of item
            indices in that group, in dataset order.
    """

    def __init__(self, dataset):
        try:
            # Recent torch releases accept (and prefer) a no-argument call;
            # passing ``data_source`` there only triggers a deprecation
            # warning.
            super().__init__()
        except TypeError:
            # Older torch releases declare ``data_source`` as a required
            # positional argument of ``Sampler.__init__``.
            super().__init__(dataset)

        item_to_group = {}
        group_to_items = {}
        for idx, group in enumerate(dataset.groups):
            item_to_group[idx] = group
            group_to_items.setdefault(group, []).append(idx)

        self.groups = set(group_to_items)
        self.item_to_group = item_to_group
        self.group_to_items = group_to_items

    def __len__(self):
        """Return the number of distinct groups (pairs drawn per pass)."""
        return len(self.groups)

    def __iter__(self):
        """Yield ``len(self)`` same-group index pairs, flattened.

        Groups are drawn with the ``random`` module and items with
        ``numpy.random``, so both generators must be seeded for a
        reproducible pass.
        """
        # ``random.sample`` no longer accepts a set (deprecated in Python
        # 3.9, TypeError since 3.11), so draw from a list instead. Dict keys
        # keep first-appearance order, which makes seeded draws repeatable.
        group_keys = list(self.group_to_items)

        for _ in range(len(self)):
            # Sample one group.
            group = random.choice(group_keys)
            items = self.group_to_items[group]

            # A single-item group cannot supply two distinct indices, so
            # fall back to sampling with replacement.
            replace = len(items) < 2

            # Sample two ids; ``tolist`` turns numpy integers into plain
            # Python ints.
            first, second = np.random.choice(items, 2, replace=replace).tolist()
            yield first
            yield second