123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
-
- import pickle
- import cv2
- from PIL import Image
- from numpy import asarray
-
- from data.data_loader import DatasetLoader
-
-
-
class WeiboDatasetLoader(DatasetLoader):
    """Dataset that attaches an image read from disk to each text item.

    Samples whose label equals ``1`` read their image from
    ``config.rumor_image_path``; every other sample reads from
    ``config.nonrumor_image_path``.
    """

    def __getitem__(self, idx):
        """Return the item for sample ``idx`` with its image attached.

        The item is produced by ``self.set_text(idx)``; the decoded image is
        then handed to ``self.set_image(image, item)``.

        Raises:
            Any exception raised by ``PIL.Image.open`` / ``convert`` when the
            OpenCV read fails and the PIL fallback cannot decode the file
            either (e.g. the file does not exist).
        """
        item = self.set_text(idx)

        # The label only decides which directory the image lives in.
        if self.labels[idx] == 1:
            image_dir = self.config.rumor_image_path
        else:
            image_dir = self.config.nonrumor_image_path
        image_path = f"{image_dir}/{self.image_filenames[idx]}"

        try:
            image = cv2.imread(image_path)
        except Exception:
            # Keep the original best-effort behaviour: any OpenCV error
            # falls through to the PIL reader below.
            image = None

        # cv2.imread signals an unreadable/missing file by returning None
        # rather than raising, so the fallback must be triggered explicitly.
        if image is None:
            # NOTE(review): this path yields RGB while cv2.imread yields BGR;
            # confirm that set_image tolerates both channel orders.
            image = asarray(Image.open(image_path).convert('RGB'))

        self.set_image(image, item)
        return item

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.text)``."""
        return len(self.text)
-
-
class WeiboEmbeddingDatasetLoader(DatasetLoader):
    """Dataset that serves pickled image and text embeddings per sample."""

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        """Store the constructor arguments and load the embedding pickles.

        When ``mode`` is ``'train'`` the ``config.train_*_embedding_path``
        files are loaded; for any other mode the
        ``config.validation_*_embedding_path`` files are used instead.
        """
        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        self.encoded_text = tokenizer
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

        def read_pickle(path):
            # NOTE: pickle.load can execute arbitrary code; these files must
            # come from a trusted source.
            with open(path, 'rb') as handle:
                return pickle.load(handle)

        if mode == 'train':
            self.image_embedding = read_pickle(config.train_image_embedding_path)
            self.text_embedding = read_pickle(config.train_text_embedding_path)
        else:
            self.image_embedding = read_pickle(config.validation_image_embedding_path)
            self.text_embedding = read_pickle(config.validation_text_embedding_path)

    def __getitem__(self, idx):
        """Return a dict with the id, both embeddings and the label of ``idx``."""
        return {
            'id': idx,
            'image_embedding': self.image_embedding[idx],
            'text_embedding': self.text_embedding[idx],
            'label': self.labels[idx],
        }

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.text)``."""
        return len(self.text)
|