lime added

master
FaezeGhorbanpour 2 years ago
parent commit: 47c42ec9e8
14 changed files with 219 additions and 195 deletions

  1. data/config.py (+2 / -2)
  2. data/data_loader.py (+42 / -11)
  3. data/twitter/config.py (+12 / -8)
  4. data/twitter/data_loader.py (+4 / -37)
  5. data/weibo/config.py (+16 / -14)
  6. data/weibo/data_loader.py (+4 / -36)
  7. data_loaders.py (+4 / -51)
  8. lime_main.py (+110 / -0)
  9. main.py (+7 / -6)
  10. model.py (+4 / -26)
  11. optuna_main.py (+2 / -1)
  12. test_main.py (+1 / -1)
  13. torch_main.py (+1 / -1)
  14. utils.py (+10 / -1)

data/config.py (+2 / -2)

@@ -59,9 +59,9 @@ class Config:
     size = 224

     num_projection_layers = 1
-    projection_size = 256
+    projection_size = 64
     dropout = 0.3
-    hidden_size = 256
+    hidden_size = 128
     num_region = 64 # 16 #TODO
     region_size = projection_size // num_region
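
Side note on the new values: since region_size = projection_size // num_region, dropping projection_size to 64 while num_region stays at 64 makes region_size evaluate to 1 (it was 256 // 64 = 4 before).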


data/data_loader.py (+42 / -11)

@@ -1,17 +1,48 @@
 import torch
 from torch.utils.data import Dataset
+import albumentations as A
+from transformers import ViTFeatureExtractor
+from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer


+def get_transforms(config):
+    return A.Compose(
+        [
+            A.Resize(config.size, config.size, always_apply=True),
+            A.Normalize(max_pixel_value=255.0, always_apply=True),
+        ]
+    )
+
+
+def get_tokenizer(config):
+    if 'bigbird' in config.text_encoder_model:
+        tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
+    elif 'xlnet' in config.text_encoder_model:
+        tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
+    else:
+        tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
+    return tokenizer
+
+
 class DatasetLoader(Dataset):
-    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
+    def __init__(self, config, dataframe, mode):
         self.config = config
-        self.image_filenames = image_filenames
-        self.text = list(texts)
-        self.encoded_text = tokenizer(
-            list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt'
-        )
-        self.labels = labels
-        self.transforms = transforms
         self.mode = mode
+        if mode == 'lime':
+            self.image_filenames = [dataframe["image"],]
+            self.text = [dataframe["text"],]
+            self.labels = [dataframe["label"],]
+        else:
+            self.image_filenames = dataframe["image"].values
+            self.text = list(dataframe["text"].values)
+            self.labels = dataframe["label"].values
+
+        tokenizer = get_tokenizer(config)
+        self.encoded_text = tokenizer(self.text, padding=True, truncation=True, max_length=config.max_length, return_tensors='pt')
+        if 'resnet' in config.image_model_name:
+            self.transforms = get_transforms(config)
+        else:
+            self.transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)

     def set_text(self, idx):
         item = {
@@ -23,11 +54,11 @@ class DatasetLoader(Dataset):
         item['id'] = idx
         return item

-    def set_image(self, image, item):
+    def set_image(self, image):
         if 'resnet' in self.config.image_model_name:
             image = self.transforms(image=image)['image']
-            item['image'] = torch.as_tensor(image).reshape((3, 224, 224))
+            return {'image': torch.as_tensor(image).reshape((3, 224, 224))}
         else:
             image = self.transforms(images=image, return_tensors='pt')
             image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
-            item['image'] = image.reshape((3, 224, 224))
+            return {'image': image.reshape((3, 224, 224))}
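
With this refactor the tokenizer and the image transforms are resolved inside DatasetLoader from the config, so callers only hand over a dataframe and a mode. A minimal usage sketch (assuming TwitterConfig can be instantiated with its class defaults and that the listed image file exists under config.test_image_path; the dataframe contents are invented for illustration):

import pandas as pd

from data.twitter.config import TwitterConfig
from data.twitter.data_loader import TwitterDatasetLoader

config = TwitterConfig()
df = pd.DataFrame({
    "image": ["sample.jpg"],         # resolved against config.test_image_path in 'test' mode
    "text": ["example tweet text"],
    "label": [0],                    # already mapped to a class index
})

dataset = TwitterDatasetLoader(config, dataframe=df, mode="test")
item = dataset[0]  # tokenized text fields plus 'label', 'id' and a (3, 224, 224) image tensor

In 'lime' mode the constructor instead expects a single row (as lime_main.py below passes), which is why the image, text and label fields are wrapped in one-element lists on that path.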

data/twitter/config.py (+12 / -8)

@@ -8,15 +8,17 @@ class TwitterConfig(Config):
     name = 'twitter'
     DatasetLoader = TwitterDatasetLoader

-    data_path = '/twitter/'
-    output_path = ''
+    data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/twitter/'
+    # data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/twitter/'
+    output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/'
+    # output_path = ''

     train_image_path = data_path + 'images_train/'
-    validation_image_path = data_path + 'images_validation/'
+    validation_image_path = data_path + 'images_test/'
     test_image_path = data_path + 'images_test/'

     train_text_path = data_path + 'twitter_train_translated.csv'
-    validation_text_path = data_path + 'twitter_validation_translated.csv'
+    validation_text_path = data_path + 'twitter_test_translated.csv'
     test_text_path = data_path + 'twitter_test_translated.csv'

     batch_size = 128
@@ -32,12 +34,14 @@ class TwitterConfig(Config):
     attention_weight_decay = 0.001
     classification_weight_decay = 0.001

-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

-    image_model_name = 'vit-base-patch16-224'
+    image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
     image_embedding = 768
-    text_encoder_model = "bert-base-uncased"
-    text_tokenizer = "bert-base-uncased"
+    text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
+    # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
+    text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
+    # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
     text_embedding = 768
     max_length = 32


data/twitter/data_loader.py (+4 / -37)

@@ -10,7 +10,10 @@ from data.data_loader import DatasetLoader
 class TwitterDatasetLoader(DatasetLoader):
     def __getitem__(self, idx):
         item = self.set_text(idx)
+        item.update(self.set_img(idx))
+        return item
+
+    def set_img(self, idx):
         try:
             if self.mode == 'train':
                 image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}")
@@ -22,44 +25,8 @@ class TwitterDatasetLoader(DatasetLoader):
         except:
             image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB')
             image = asarray(image)
-        self.set_image(image, item)
-        return item
+        return self.set_image(image)

     def __len__(self):
         return len(self.text)
-
-
-class TwitterEmbeddingDatasetLoader(DatasetLoader):
-    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
-        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
-        self.config = config
-        self.image_filenames = image_filenames
-        self.text = list(texts)
-        self.encoded_text = tokenizer
-        self.labels = labels
-        self.transforms = transforms
-        self.mode = mode
-
-        if mode == 'train':
-            with open(config.train_image_embedding_path, 'rb') as handle:
-                self.image_embedding = pickle.load(handle)
-            with open(config.train_text_embedding_path, 'rb') as handle:
-                self.text_embedding = pickle.load(handle)
-        else:
-            with open(config.validation_image_embedding_path, 'rb') as handle:
-                self.image_embedding = pickle.load(handle)
-            with open(config.validation_text_embedding_path, 'rb') as handle:
-                self.text_embedding = pickle.load(handle)
-
-
-    def __getitem__(self, idx):
-        item = dict()
-        item['id'] = idx
-        item['image_embedding'] = self.image_embedding[idx]
-        item['text_embedding'] = self.text_embedding[idx]
-        item['label'] = self.labels[idx]
-        return item
-
-    def __len__(self):
-        return len(self.text)

data/weibo/config.py (+16 / -14)

@@ -9,32 +9,35 @@ class WeiboConfig(Config):
     name = 'weibo'
     DatasetLoader = WeiboDatasetLoader

-    data_path = 'weibo/'
-    output_path = ''
+    data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/weibo/'
+    # data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/weibo/'
+    output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/'
+    # output_path = ''

     rumor_image_path = data_path + 'rumor_images/'
     nonrumor_image_path = data_path + 'nonrumor_images/'

     train_text_path = data_path + 'weibo_train.csv'
-    validation_text_path = data_path + 'weibo_validation.csv'
+    validation_text_path = data_path + 'weibo_train.csv'
     test_text_path = data_path + 'weibo_test.csv'

-    batch_size = 128
+    batch_size = 64
     epochs = 100
-    num_workers = 2
+    num_workers = 4
     head_lr = 1e-03
     image_encoder_lr = 1e-02
     text_encoder_lr = 1e-05
     weight_decay = 0.001
     classification_lr = 1e-02

-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

-    image_model_name = 'vit-base-patch16-224' # 'resnet101'
-    image_embedding = 768 # 2048
-    num_img_region = 64 # TODO
-    text_encoder_model = "bert-base-chinese"
-    text_tokenizer = "bert-base-chinese"
+    image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
+    image_embedding = 768
+    text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
+    # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
+    text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
+    # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
     text_embedding = 768
     max_length = 200

@@ -53,10 +56,9 @@ class WeiboConfig(Config):
         self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)

         self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
-        # self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
         self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)

         self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
-        # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
-        # self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])
+        self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
+        self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])


data/weibo/data_loader.py (+4 / -36)

@@ -11,7 +11,10 @@ from data.data_loader import DatasetLoader
 class WeiboDatasetLoader(DatasetLoader):
     def __getitem__(self, idx):
         item = self.set_text(idx)
+        item.update(self.set_img(idx))
+        return item
+
+    def set_img(self, idx):
         if self.labels[idx] == 1:
             try:
                 image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}")
@@ -24,43 +27,8 @@ class WeiboDatasetLoader(DatasetLoader):
             except:
                 image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB')
                 image = asarray(image)
-        self.set_image(image, item)
-        return item
+        return self.set_image(image)

     def __len__(self):
         return len(self.text)
-
-
-class WeiboEmbeddingDatasetLoader(DatasetLoader):
-    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
-        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
-        self.config = config
-        self.image_filenames = image_filenames
-        self.text = list(texts)
-        self.encoded_text = tokenizer
-        self.labels = labels
-        self.transforms = transforms
-        self.mode = mode
-
-        if mode == 'train':
-            with open(config.train_image_embedding_path, 'rb') as handle:
-                self.image_embedding = pickle.load(handle)
-            with open(config.train_text_embedding_path, 'rb') as handle:
-                self.text_embedding = pickle.load(handle)
-        else:
-            with open(config.validation_image_embedding_path, 'rb') as handle:
-                self.image_embedding = pickle.load(handle)
-            with open(config.validation_text_embedding_path, 'rb') as handle:
-                self.text_embedding = pickle.load(handle)
-
-
-    def __getitem__(self, idx):
-        item = dict()
-        item['id'] = idx
-        item['image_embedding'] = self.image_embedding[idx]
-        item['text_embedding'] = self.text_embedding[idx]
-        item['label'] = self.labels[idx]
-        return item
-
-    def __len__(self):
-        return len(self.text)

data_loaders.py (+4 / -51)

@@ -1,27 +1,8 @@
-import albumentations as A
 import numpy as np
 import pandas as pd

 from sklearn.utils import compute_class_weight
 from torch.utils.data import DataLoader
-from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer


-def get_transforms(config, mode="train"):
-    if mode == "train":
-        return A.Compose(
-            [
-                A.Resize(config.size, config.size, always_apply=True),
-                A.Normalize(max_pixel_value=255.0, always_apply=True),
-            ]
-        )
-    else:
-        return A.Compose(
-            [
-                A.Resize(config.size, config.size, always_apply=True),
-                A.Normalize(max_pixel_value=255.0, always_apply=True),
-            ]
-        )
-
-
 def make_dfs(config):
@@ -51,44 +32,16 @@ def make_dfs(config):
     validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True)
     validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x))

-    unlabeled_dataframe = None
-    if config.has_unlabeled_data:
-        unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path)
-        unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True)
-
-    return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe
-
-
-def get_tokenizer(config):
-    if 'bigbird' in config.text_encoder_model:
-        tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
-    elif 'xlnet' in config.text_encoder_model:
-        tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
-    else:
-        tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
-    return tokenizer
+    return train_dataframe, test_dataframe, validation_dataframe


 def build_loaders(config, dataframe, mode):
-    if 'resnet' in config.image_model_name:
-        transforms = get_transforms(config, mode=mode)
-    else:
-        transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)
-
-    dataset = config.DatasetLoader(
-        config,
-        dataframe["image"].values,
-        dataframe["text"].values,
-        dataframe["label"].values,
-        tokenizer=get_tokenizer(config),
-        transforms=transforms,
-        mode=mode,
-    )
+    dataset = config.DatasetLoader(config, dataframe=dataframe, mode=mode)
     if mode != 'train':
         dataloader = DataLoader(
             dataset,
-            batch_size=config.batch_size,
-            num_workers=config.num_workers,
+            batch_size=config.batch_size // 2 if mode == 'lime' else 1,
+            num_workers=config.num_workers // 2,
             pin_memory=False,
             shuffle=False,
         )

lime_main.py (+110 / -0)

@@ -0,0 +1,110 @@
+import random
+
+import numpy as np
+import pandas as pd
+import torch
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+
+from data_loaders import make_dfs, build_loaders
+from learner import batch_constructor
+from model import FakeNewsModel
+
+from lime.lime_image import LimeImageExplainer
+from skimage.segmentation import mark_boundaries
+from lime.lime_text import LimeTextExplainer
+
+from utils import make_directory
+
+
+def lime_(config, test_loader, model):
+    for i, batch in enumerate(test_loader):
+        batch = batch_constructor(config, batch)
+        with torch.no_grad():
+            output, score = model(batch)
+
+            score = score.detach().cpu().numpy()
+
+            logit = model.logits.detach().cpu().numpy()
+
+    return score, logit
+
+
+def lime_main(config, trial_number=None):
+    _, test_df, _ = make_dfs(config, )
+    test_df = test_df[:1]
+
+    if trial_number:
+        try:
+            checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
+        except:
+            checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
+    else:
+        checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))
+
+    try:
+        parameters = checkpoint['parameters']
+        config.assign_hyperparameters(parameters)
+    except:
+        pass
+
+    model = FakeNewsModel(config).to(config.device)
+    try:
+        model.load_state_dict(checkpoint['model_state_dict'])
+    except:
+        model.load_state_dict(checkpoint)
+
+    model.eval()
+
+    torch.manual_seed(27)
+    random.seed(27)
+    np.random.seed(27)
+
+    text_explainer = LimeTextExplainer(class_names=config.classes)
+    image_explainer = LimeImageExplainer()
+
+    make_directory(config.output_path + '/lime/')
+    make_directory(config.output_path + '/lime/score/')
+    make_directory(config.output_path + '/lime/logit/')
+    make_directory(config.output_path + '/lime/text/')
+    make_directory(config.output_path + '/lime/image/')
+
+    def text_predict_proba(text):
+        scores = []
+        print(len(text))
+        for i in text:
+            row['text'] = i
+            test_loader = build_loaders(config, row, mode="lime")
+            score, _ = lime_(config, test_loader, model)
+            scores.append(score)
+        return np.array(scores)
+
+    def image_predict_proba(image):
+        scores = []
+        print(len(image))
+        for i in image:
+            test_loader = build_loaders(config, row, mode="lime")
+            test_loader['image'] = i.reshape((3, 224, 224))
+            score, _ = lime_(config, test_loader, model)
+            scores.append(score)
+        return np.array(scores)
+
+    for i, row in test_df.iterrows():
+        test_loader = build_loaders(config, row, mode="lime")
+        score, logit = lime_(config, test_loader, model)
+        np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",")
+        np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",")
+
+        text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5)
+        text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html')
+        print('text', i, 'finished')
+
+        data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0)
+        img_exp = image_explainer.explain_instance(data_items['image'].reshape((224, 224, 3)), image_predict_proba,
+                                                   top_labels=2, hide_color=0, num_samples=1)
+        temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5,
+                                                hide_rest=False)
+        img_boundry = mark_boundaries(temp / 255.0, mask)
+        plt.imshow(img_boundry)
+        plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png')
+        print('image', i, 'finished')
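
For reference, the contract the two *_predict_proba closures above must satisfy is LIME's classifier_fn: it receives a batch of perturbed inputs and returns an (n_samples, n_classes) array of class probabilities. A self-contained toy illustration of the text side (the keyword-based scorer is invented purely to show the interface, not the project's model):

import numpy as np
from lime.lime_text import LimeTextExplainer

def toy_predict_proba(texts):
    # one [p(real), p(fake)] row per perturbed text
    probs = []
    for t in texts:
        p_fake = 0.9 if "hoax" in t.lower() else 0.2
        probs.append([1.0 - p_fake, p_fake])
    return np.array(probs)

explainer = LimeTextExplainer(class_names=["real", "fake"])
exp = explainer.explain_instance("this hoax story spread fast", toy_predict_proba, num_features=5)
print(exp.as_list())  # [(word, weight), ...]

The image explainer follows the same pattern, except that classifier_fn receives a batch of perturbed H x W x 3 arrays, which is why image_predict_proba reshapes each one back to (3, 224, 224) before scoring.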

main.py (+7 / -6)

@@ -1,5 +1,6 @@
 import argparse

+from lime_main import lime_main
 from optuna_main import optuna_main
 from test_main import test_main
 from torch_main import torch_main
@@ -7,6 +8,8 @@ from data.twitter.config import TwitterConfig
 from data.weibo.config import WeiboConfig
 from torch.utils.tensorboard import SummaryWriter

+from utils import make_directory
+
 if __name__ == '__main__':
     import os

@@ -15,6 +18,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--data', type=str, required=True)
     parser.add_argument('--use_optuna', type=int, required=False)
+    parser.add_argument('--use_lime', type=int, required=False)
     parser.add_argument('--just_test', type=int, required=False)
     parser.add_argument('--batch', type=int, required=False)
     parser.add_argument('--epoch', type=int, required=False)
@@ -44,12 +48,7 @@ if __name__ == '__main__':
     config.output_path = str(args.extra)
     use_optuna = False

-    try:
-        os.mkdir(config.output_path)
-    except OSError:
-        print("Creation of the directory failed")
-    else:
-        print("Successfully created the directory")
+    make_directory(config.output_path)

     config.writer = SummaryWriter(config.output_path)

@@ -57,6 +56,8 @@ if __name__ == '__main__':
         optuna_main(config, args.use_optuna)
     elif args.just_test:
         test_main(config, args.just_test)
+    elif args.use_lime:
+        lime_main(config, args.use_lime)
     else:
         torch_main(config)
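
With the new flag wired through, a LIME run would presumably be launched along the lines of python main.py --data twitter --use_lime 1 (plus whatever output arguments the rest of the script already expects); any non-zero value routes execution into lime_main and is also forwarded as the trial number used to pick a checkpoint.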


model.py (+4 / -26)

@@ -66,7 +66,6 @@ class FakeNewsModel(nn.Module):
         self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
         self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
         self.classifier = Classifier(config)
-        self.temperature = config.temperature
         class_weights = torch.FloatTensor(config.class_weights)
         self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

@@ -82,36 +81,15 @@ class FakeNewsModel(nn.Module):
         self.text_embeddings = self.text_projection(text_features)

         self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
         score = self.classifier(self.multimodal_embeddings)
         probs, output = torch.max(score.data, dim=1)

         return output, score


-class FakeNewsEmbeddingModel(nn.Module):
-    def __init__(
-            self, config
-    ):
-        super().__init__()
-        self.config = config
-        self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
-        self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
-        self.classifier = Classifier(config)
-        self.temperature = config.temperature
-        class_weights = torch.FloatTensor(config.class_weights)
-        self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
-
-    def forward(self, batch):
-        self.image_embeddings = self.image_projection(batch['image_embedding'])
-        self.text_embeddings = self.text_projection(batch['text_embedding'])
-        self.logits = (self.text_embeddings @ self.image_embeddings.T)
-
-        self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
-        score = self.classifier(self.multimodal_embeddings)
-        probs, output = torch.max(score.data, dim=1)
-
-        return output, score


 def calculate_loss(model, score, label):
     similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings)
     # fake = (label == 1).nonzero()
@@ -124,10 +102,10 @@ def calculate_loss(model, score, label):

 def calculate_similarity_loss(config, image_embeddings, text_embeddings):
     # Calculating the Loss
-    logits = (text_embeddings @ image_embeddings.T) / config.temperature
+    logits = (text_embeddings @ image_embeddings.T)
     images_similarity = image_embeddings @ image_embeddings.T
     texts_similarity = text_embeddings @ text_embeddings.T
-    targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1)
+    targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
     texts_loss = cross_entropy(logits, targets, reduction='none')
     images_loss = cross_entropy(logits.T, targets.T, reduction='none')
     loss = (images_loss + texts_loss) / 2.0
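
With the temperature gone, the similarity term is a plain CLIP-style soft-target cross-entropy between the text-to-image logits and targets built from the two self-similarity matrices. A standalone sketch of the computation (the soft-target cross_entropy helper is spelled out here for completeness; in model.py it is an existing helper whose definition lies outside this diff):

import torch
import torch.nn.functional as F

def soft_target_cross_entropy(preds, targets, reduction='none'):
    # cross-entropy against a soft target distribution
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=1)
    return loss if reduction == 'none' else loss.mean()

def similarity_loss(image_embeddings, text_embeddings):
    logits = text_embeddings @ image_embeddings.T
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
    texts_loss = soft_target_cross_entropy(logits, targets, reduction='none')
    images_loss = soft_target_cross_entropy(logits.T, targets.T, reduction='none')
    return (images_loss + texts_loss) / 2.0

# e.g. eight projected pairs of size 64 -> one loss value per sample
print(similarity_loss(torch.randn(8, 64), torch.randn(8, 64)).shape)  # torch.Size([8])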

optuna_main.py (+2 / -1)

@@ -18,7 +18,7 @@ def objective(trial, config, train_loader, validation_loader):


 def optuna_main(config, n_trials=100):
-    train_df, test_df, validation_df, _ = make_dfs(config)
+    train_df, test_df, validation_df = make_dfs(config)
     train_loader = build_loaders(config, train_df, mode="train")
     validation_loader = build_loaders(config, validation_df, mode="validation")
     test_loader = build_loaders(config, test_df, mode="test")
@@ -47,6 +47,7 @@ def optuna_main(config, n_trials=100):
     trial = study.best_trial
     print(' Number: ', trial.number)
     print(" Value: ", trial.value)
     s += 'number: ' + str(trial.number) + '\n'
+    s += 'value: ' + str(trial.value) + '\n'

     print(" Params: ")

test_main.py (+1 / -1)

@@ -106,6 +106,6 @@ def test(config, test_loader, trial_number=None):


 def test_main(config, trial_number=None):
-    train_df, test_df, validation_df, _ = make_dfs(config, )
+    train_df, test_df, validation_df = make_dfs(config, )
     test_loader = build_loaders(config, test_df, mode="test")
     test(config, test_loader, trial_number)

torch_main.py (+1 / -1)

@@ -5,7 +5,7 @@ from test_main import test


 def torch_main(config):
-    train_df, test_df, validation_df, _ = make_dfs(config, )
+    train_df, test_df, validation_df = make_dfs(config, )
     train_loader = build_loaders(config, train_df, mode="train")
     validation_loader = build_loaders(config, validation_df, mode="validation")
     test_loader = build_loaders(config, test_df, mode="test")

utils.py (+10 / -1)

@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-
+import os

 class AvgMeter:
     def __init__(self, name="Metric"):
@@ -91,3 +91,12 @@ class EarlyStopping:
         # self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...')
         # torch.save(model.state_dict(), self.path)
         # self.val_loss_min = val_loss
+
+
+def make_directory(path):
+    try:
+        os.mkdir(path)
+    except OSError:
+        print("Creation of the directory failed")
+    else:
+        print("Successfully created the directory")
