123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- import albumentations as A
- import numpy as np
- import pandas as pd
-
- from sklearn.utils import compute_class_weight
- from torch.utils.data import DataLoader
- from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer
-
-
def get_transforms(config, mode="train"):
    """Build the albumentations preprocessing pipeline for images.

    The original function's "train" and eval branches were byte-identical,
    so the duplicated if/else has been collapsed into a single pipeline.
    The ``mode`` parameter is kept for interface compatibility and so that
    train-only augmentations can be added later without touching callers.

    Args:
        config: object exposing ``size`` (target square resolution in pixels).
        mode: split name ("train" or anything else); currently has no effect.

    Returns:
        An ``albumentations.Compose`` that resizes images to
        (config.size, config.size) and normalizes with max_pixel_value=255.0.
    """
    # NOTE(review): if train-time augmentation was intended, insert it under
    # an `if mode == "train":` branch here — the two branches were identical.
    return A.Compose(
        [
            A.Resize(config.size, config.size, always_apply=True),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ]
    )
-
-
def _load_split_csv(path, classes):
    """Read one labeled split CSV: drop rows with missing text, shuffle the
    rows, and map each string label to its index in ``classes``."""
    df = pd.read_csv(path)
    df.dropna(subset=['text'], inplace=True)
    df = df.sample(frac=1).reset_index(drop=True)
    df.label = df.label.apply(lambda x: classes.index(x))
    return df


def make_dfs(config):
    """Assemble the train/test/validation/unlabeled dataframes.

    Loads ``config.train_text_path`` (cleaned, shuffled, labels encoded as
    indices into ``config.classes``), then either loads the test and
    validation splits from their own CSVs or carves them off the end of the
    already-shuffled training frame (last 20% -> test, then last 10% of the
    remainder -> validation).

    Side effect: sets ``config.class_weights`` from the full training labels
    — computed BEFORE test/validation rows are split off, matching the
    original behavior.

    Returns:
        Tuple ``(train_dataframe, test_dataframe, validation_dataframe,
        unlabeled_dataframe)``; the last is None unless
        ``config.has_unlabeled_data`` is truthy.
    """
    train_dataframe = _load_split_csv(config.train_text_path, config.classes)
    config.class_weights = get_class_weights(train_dataframe.label.values)

    if config.test_text_path is None:
        offset = int(train_dataframe.shape[0] * 0.80)
        test_dataframe = train_dataframe[offset:]
        train_dataframe = train_dataframe[:offset]
    else:
        test_dataframe = _load_split_csv(config.test_text_path, config.classes)

    if config.validation_text_path is None:
        offset = int(train_dataframe.shape[0] * 0.90)
        validation_dataframe = train_dataframe[offset:]
        train_dataframe = train_dataframe[:offset]
    else:
        validation_dataframe = _load_split_csv(
            config.validation_text_path, config.classes
        )

    unlabeled_dataframe = None
    if config.has_unlabeled_data:
        # Unlabeled data has no label column to encode; just load and shuffle.
        unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path)
        unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True)

    return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe
-
-
def get_tokenizer(config):
    """Load the tokenizer matching ``config.text_encoder_model``.

    Dispatches by substring of the encoder model name: 'bigbird' ->
    BigBirdTokenizer, 'xlnet' -> XLNetTokenizer, anything else ->
    BertTokenizer. The tokenizer weights/vocab are loaded from
    ``config.text_tokenizer``.
    """
    dispatch = (
        ('bigbird', BigBirdTokenizer),
        ('xlnet', XLNetTokenizer),
    )
    for marker, tokenizer_cls in dispatch:
        if marker in config.text_encoder_model:
            return tokenizer_cls.from_pretrained(config.text_tokenizer)
    return BertTokenizer.from_pretrained(config.text_tokenizer)
-
-
def build_loaders(config, dataframe, mode):
    """Create a ``DataLoader`` over ``config.DatasetLoader`` for one split.

    Resnet image backbones get the albumentations pipeline from
    ``get_transforms``; any other backbone gets a HuggingFace
    ``ViTFeatureExtractor`` loaded from ``config.image_model_name``.

    Args:
        config: experiment config exposing ``image_model_name``,
            ``batch_size``, ``num_workers`` and the ``DatasetLoader`` class.
        dataframe: frame with "image", "text" and "label" columns.
        mode: split name; "train" enables shuffling and pinned memory.

    Returns:
        ``torch.utils.data.DataLoader`` for the split.
    """
    if 'resnet' in config.image_model_name:
        transforms = get_transforms(config, mode=mode)
    else:
        # NOTE(review): ViTFeatureExtractor is deprecated upstream in favor
        # of ViTImageProcessor — confirm the pinned transformers version.
        transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)

    dataset = config.DatasetLoader(
        config,
        dataframe["image"].values,
        dataframe["text"].values,
        dataframe["label"].values,
        tokenizer=get_tokenizer(config),
        transforms=transforms,
        mode=mode,
    )
    # The original duplicated the whole DataLoader construction; the two
    # branches differed only in pin_memory and shuffle.
    is_train = mode == 'train'
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=is_train,
        shuffle=is_train,
    )
-
-
def get_class_weights(y):
    """Compute balanced per-class weights for the label vector ``y``.

    Bug fix: since scikit-learn 0.24, ``compute_class_weight`` takes
    keyword-only arguments — the original positional call
    ``compute_class_weight('balanced', np.unique(y), y)`` raises
    ``TypeError`` on modern sklearn, so the arguments are now named.

    Args:
        y: 1-D array-like of integer-encoded class labels.

    Returns:
        ``np.ndarray`` of weights, one per class in ``np.unique(y)`` order.
    """
    return compute_class_weight(
        class_weight='balanced', classes=np.unique(y), y=y
    )
|