import albumentations as A
import numpy as np
import pandas as pd
from sklearn.utils import compute_class_weight
from torch.utils.data import DataLoader
from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer


def get_transforms(config, mode="train"):
    """Return the albumentations pipeline used to preprocess images.

    Args:
        config: Object exposing ``size``; images are resized to
            ``size x size``.
        mode: Accepted for interface compatibility. The pipeline is
            currently the same for every mode (resize, then normalize).

    Returns:
        An ``A.Compose`` that resizes and then normalizes the image.
    """
    return A.Compose(
        [
            A.Resize(config.size, config.size, always_apply=True),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ]
    )


def _load_labeled_dataframe(csv_path, classes):
    """Read a labelled CSV, drop rows without text, shuffle, and encode labels.

    Labels are replaced by their position in ``classes``; a label that is not
    in ``classes`` raises ``ValueError`` (from ``classes.index``).
    """
    frame = pd.read_csv(csv_path)
    frame = frame.dropna(subset=["text"])
    frame = frame.sample(frac=1).reset_index(drop=True)
    frame["label"] = frame["label"].apply(classes.index)
    return frame


def _split_tail(frame, keep_fraction):
    """Split ``frame`` into its leading ``keep_fraction`` rows and the rest."""
    offset = int(frame.shape[0] * keep_fraction)
    return frame[:offset], frame[offset:]


def make_dfs(config):
    """Load the train / test / validation / unlabeled dataframes.

    Side effect: sets ``config.class_weights`` from the labels of the full
    training CSV, computed before any hold-out rows are split off.

    When ``config.test_text_path`` is ``None`` the last 20% of the shuffled
    training rows become the test set; when ``config.validation_text_path``
    is ``None`` the last 10% of the remaining training rows become the
    validation set. Otherwise the respective CSV is loaded.

    Args:
        config: Object exposing ``train_text_path``, ``test_text_path``,
            ``validation_text_path``, ``classes``, ``has_unlabeled_data``
            and (when that flag is truthy) ``unlabeled_text_path``.

    Returns:
        Tuple ``(train, test, validation, unlabeled)``; ``unlabeled`` is
        ``None`` unless ``config.has_unlabeled_data`` is truthy.
    """
    train_dataframe = _load_labeled_dataframe(config.train_text_path, config.classes)
    # NOTE(review): weights are derived from every training-CSV row, including
    # rows that may be moved to the test/validation splits below.
    config.class_weights = get_class_weights(train_dataframe["label"].values)

    if config.test_text_path is None:
        train_dataframe, test_dataframe = _split_tail(train_dataframe, 0.80)
    else:
        test_dataframe = _load_labeled_dataframe(config.test_text_path, config.classes)

    if config.validation_text_path is None:
        train_dataframe, validation_dataframe = _split_tail(train_dataframe, 0.90)
    else:
        validation_dataframe = _load_labeled_dataframe(
            config.validation_text_path, config.classes
        )

    unlabeled_dataframe = None
    if config.has_unlabeled_data:
        # Unlabeled rows are only shuffled: no NaN filtering, no label encoding.
        unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path)
        unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True)

    return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe


def get_tokenizer(config):
    """Instantiate the tokenizer matching ``config.text_encoder_model``.

    The tokenizer class is chosen by substring match on
    ``config.text_encoder_model`` ('bigbird', then 'xlnet', otherwise BERT);
    the weights are loaded from ``config.text_tokenizer``.
    """
    if 'bigbird' in config.text_encoder_model:
        tokenizer_cls = BigBirdTokenizer
    elif 'xlnet' in config.text_encoder_model:
        tokenizer_cls = XLNetTokenizer
    else:
        tokenizer_cls = BertTokenizer
    return tokenizer_cls.from_pretrained(config.text_tokenizer)


def build_loaders(config, dataframe, mode):
    """Wrap ``dataframe`` in ``config.DatasetLoader`` and a torch ``DataLoader``.

    Args:
        config: Object exposing ``image_model_name``, ``DatasetLoader``,
            ``batch_size``, ``num_workers`` and the tokenizer settings read
            by ``get_tokenizer``.
        dataframe: Frame with ``image``, ``text`` and ``label`` columns.
        mode: ``'train'`` enables shuffling and pinned memory; any other
            value disables both.

    Returns:
        The configured ``DataLoader``.
    """
    if 'resnet' in config.image_model_name:
        transforms = get_transforms(config, mode=mode)
    else:
        transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)

    dataset = config.DatasetLoader(
        config,
        dataframe["image"].values,
        dataframe["text"].values,
        dataframe["label"].values,
        tokenizer=get_tokenizer(config),
        transforms=transforms,
        mode=mode,
    )

    is_train = mode == 'train'
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=is_train,
        shuffle=is_train,
    )


def get_class_weights(y):
    """Return scikit-learn 'balanced' class weights for the label array ``y``.

    The result has one weight per distinct value of ``y``, ordered as
    ``np.unique(y)``.
    """
    # `classes` and `y` are keyword-only in current scikit-learn; passing them
    # by keyword works on older releases as well.
    return compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)