import torch
from torch.utils.data import Dataset


class DatasetLoader(Dataset):
    """Dataset of paired text/image samples with pre-tokenized text.

    All texts are tokenized once, up front, in ``__init__``. The per-sample
    helpers then slice the pre-computed encodings.

    Args:
        config: Settings object. This class reads ``config.max_length``
            (passed to the tokenizer as ``max_length``) and
            ``config.image_model_name`` (selects the preprocessing path in
            ``set_image``).
        image_filenames: Image file names, stored as ``self.image_filenames``
            and not read by ``set_text``/``set_image``.
        texts: Iterable of text strings, one per sample. Any iterable is
            accepted, including a one-shot generator.
        labels: Per-sample labels, looked up with ``labels[idx]``.
        tokenizer: Callable invoked as ``tokenizer(list_of_str,
            padding=True, truncation=True, max_length=...,
            return_tensors='pt')``. It must return a mapping of
            name -> tensor (presumably a Hugging Face tokenizer; confirm
            against the caller).
        transforms: Image preprocessing callable; see ``set_image``.
        mode: Stored as ``self.mode``; not read by ``set_text``/``set_image``.
    """

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        self.config = config
        self.image_filenames = image_filenames
        # Materialize once and tokenize that same list. Calling list(texts)
        # a second time would produce [] when `texts` is a one-shot
        # iterator, leaving the encodings out of sync with self.text.
        self.text = list(texts)
        self.encoded_text = tokenizer(
            self.text,
            padding=True,
            truncation=True,
            max_length=config.max_length,
            return_tensors='pt',
        )
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

    def set_text(self, idx):
        """Build a sample dict holding the text fields for index ``idx``.

        Args:
            idx: Sample index into the tokenized batch, ``self.text`` and
                ``self.labels``.

        Returns:
            A dict with these entries:

            * One entry per key of the tokenizer output, holding row ``idx``.
              The row is cloned and detached, so it does not share storage
              with the batch encoding.
            * ``'text'``: the raw string.
            * ``'label'``: ``labels[idx]``.
            * ``'id'``: ``idx`` itself.
        """
        item = {
            key: values[idx].clone().detach()
            for key, values in self.encoded_text.items()
        }
        item['text'] = self.text[idx]
        item['label'] = self.labels[idx]
        item['id'] = idx
        return item

    def set_image(self, image, item):
        """Preprocess ``image`` and store the result in ``item['image']``.

        Mutates ``item`` in place and returns None. The stored tensor is
        reshaped to ``(3, 224, 224)``, so the transform output must contain
        exactly 3 * 224 * 224 elements.

        There are two preprocessing paths, chosen by
        ``config.image_model_name``:

        * Name contains ``'resnet'``: uses ``transforms(image=image)['image']``
          (an albumentations-style call signature; confirm), converted with
          ``torch.as_tensor``.
        * Otherwise: calls ``transforms(images=image, return_tensors='pt')``
          and takes its ``'pixel_values'`` entry (presumably a Hugging Face
          image processor; confirm).

        Args:
            image: Raw image as expected by ``self.transforms``.
            item: Sample dict to update, e.g. the result of ``set_text``.
        """
        if 'resnet' in self.config.image_model_name:
            image = self.transforms(image=image)['image']
            # NOTE(review): reshape (not permute) assumes the transform
            # already yields channel-first data. A (224, 224, 3) output
            # would be scrambled rather than transposed; confirm.
            item['image'] = torch.as_tensor(image).reshape((3, 224, 224))
        else:
            image = self.transforms(images=image, return_tensors='pt')
            image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
            # Collapses the processor's leading batch dimension (assumed to
            # be 1 for a single image; confirm).
            item['image'] = image.reshape((3, 224, 224))