123456789101112131415161718192021222324252627282930313233 |
- import torch
- from torch.utils.data import Dataset
-
class DatasetLoader(Dataset):
    """Dataset holding tokenized texts, labels and image preprocessing.

    The visible helpers build per-sample dictionaries:
    ``set_text`` returns the tokenizer outputs, raw text, label and index for
    one sample, and ``set_image`` stores a preprocessed ``(3, H, W)`` image
    tensor into such a dictionary under ``'image'``.
    """

    # Spatial size the preprocessed image is reshaped to in ``set_image``
    # (the original implementation hard-coded 224). Override on a subclass or
    # instance if the transforms produce a different resolution.
    image_size = 224

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        """Store the dataset inputs and tokenize all texts up front.

        Args:
            config: object providing at least ``max_length`` (used for
                tokenization here) and ``image_model_name`` (used by
                ``set_image``).
            image_filenames: image file names; stored as-is.
            texts: iterable of strings; materialized into a list once.
            labels: per-sample labels, indexed by ``idx`` in ``set_text``.
            tokenizer: callable invoked as ``tokenizer(list_of_texts,
                padding=True, truncation=True, max_length=...,
                return_tensors='pt')``; its result must support ``.items()``
                yielding indexable tensors.
            transforms: image transform/processor used by ``set_image``.
            mode: stored as-is (not used by the methods shown here).
        """
        self.config = config
        self.image_filenames = image_filenames
        # Materialize once and reuse: calling list(texts) a second time would
        # tokenize an empty list if `texts` is a one-shot iterator/generator.
        self.text = list(texts)
        self.encoded_text = tokenizer(
            self.text,
            padding=True,
            truncation=True,
            max_length=config.max_length,
            return_tensors='pt',
        )
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

    def set_text(self, idx):
        """Return a new dict with the text-side fields of sample ``idx``.

        Contains one detached copy per tokenizer output key, plus ``'text'``
        (raw string), ``'label'`` and ``'id'`` (the index itself).
        """
        item = {
            # Copy so callers can't mutate the cached batch encoding.
            key: values[idx].clone().detach()
            for key, values in self.encoded_text.items()
        }
        item['text'] = self.text[idx]
        # NOTE(review): `labels` is stored unconverted; if it is a pandas
        # Series with a non-default index this is a label-based lookup,
        # not positional like `self.text[idx]` — confirm against callers.
        item['label'] = self.labels[idx]
        item['id'] = idx
        return item

    def set_image(self, image, item):
        """Preprocess ``image`` and store it in ``item['image']`` as (3, H, W).

        ResNet-style models: ``self.transforms(image=...)['image']`` is used
        (albumentations-style call). Its output may be channels-last
        ``(H, W, 3)`` (e.g. no ``ToTensorV2`` in the pipeline) or already
        channels-first ``(3, H, W)``; channels-last is permuted, because
        reshaping HWC data straight into CHW would interleave the pixels.

        Other models: ``self.transforms(images=..., return_tensors='pt')`` is
        used (processor-style call) and its ``'pixel_values'`` (with a
        leading batch dimension of 1) are reshaped to ``(3, H, W)``.

        Mutates ``item`` in place; returns ``None``.
        """
        size = self.image_size
        if 'resnet' in self.config.image_model_name:
            image = self.transforms(image=image)['image']
            tensor = torch.as_tensor(image)
            # Channels-last (H, W, 3) -> channels-first (3, H, W).
            if tensor.ndim == 3 and tensor.shape[-1] == 3 and tensor.shape[0] != 3:
                tensor = tensor.permute(2, 0, 1).contiguous()
            item['image'] = tensor.reshape((3, size, size))
        else:
            image = self.transforms(images=image, return_tensors='pt')
            image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
            # Drops the leading batch dimension of 1.
            item['image'] = image.reshape((3, size, size))
|