# --- Imports (stdlib / third-party / project-local) and experiment setup ----
# Fix: removed duplicate `import json` and duplicate `from tqdm import tqdm`,
# and dropped the leftover debug `print("here")`.
import argparse
import json
import logging
import os
from datetime import datetime

import mlflow
import mlflow.pytorch
import nltk  # kept for the optional one-time corpus downloads below
import numpy as np
import torch
from einops import reduce
from gensim.models import FastText
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from network import ImageTextTransformer
from utils.bypass_bn import disable_running_stats, enable_running_stats
from utils.io import load_config, save_config
from utils.recipedb_dataset import RecipeDBDataset
from utils.sam import SAM

# One-time NLTK corpus downloads; uncomment on a fresh machine.
# nltk.download('wordnet')
# nltk.download('omw-1.4')
# nltk.download('punkt')

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)

# Human-readable MLflow experiment name plus a slug derived from it.
# NOTE(review): experiment_code is never referenced later in this file —
# confirm it is needed before deleting.
experiment_name = 'Parham BS Project Region Prediction'
experiment_code = experiment_name.replace(' - ', '.').replace(' ', '_').lower()

mlflow.set_experiment(experiment_name)
# --- CLI arguments and run configuration ------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str)
args = parser.parse_args()
config = load_config(args.config)

# Hoist frequently used hyper-parameters out of the config object.
epochs = config.optim.epochs
batch_size = config.optim.batch_size
learning_rate = config.optim.max_lr
weight_decay = config.optim.weight_decay
embedding_size = config.data.embedding_size
num_classes = config.model.final_classes
sam_rho = config.optim.sam_rho
num_workers = config.optim.num_workers
data_path = config.data.dataset_path
target = config.data.target

# Label dictionary mapping target labels -> class indices.
# Fix: use a context manager so the file handle is closed deterministically
# (the original `json.load(open(...))` leaked the handle).
with open(os.path.join(data_path, f'{target}.json'), 'r') as f:
    target_dictionary = json.load(f)

# Optional entropy-regularization weight; defaults to 0 when absent.
entropy_weight = config.optim.entropy if 'entropy' in config.optim else 0

# The class count is dictated by the label dictionary, not the config file.
# NOTE(review): this overwrites config.model.final_classes after `num_classes`
# was read above, so `num_classes` may be stale — confirm it is unused later.
config.model.final_classes = len(target_dictionary)
epsilon = 1e-8

print(target)
print(target_dictionary)

# Timestamped output directory for checkpoints and the saved config.
# NOTE(review): 'taext' looks like a typo for 'text' — kept as-is so tooling
# that expects the existing path keeps working.
output_dir = f'parham-models_image_taext_transformer/{target}/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
# makedirs(exist_ok=True) already tolerates an existing directory; the
# original `isdir` pre-check was redundant.
os.makedirs(output_dir, exist_ok=True)
-
class EmbedderFasttext():
    """Word embedder backed by a pre-trained gensim FastText model.

    Multi-word tokens joined with underscores (e.g. ``olive_oil``) are
    embedded as the average of their component word vectors.
    """

    def __init__(self, path):
        """Load the FastText model stored at *path*."""
        self.model = FastText.load(path)
        # Fix: dropped the stray leading 's' from the original log message
        # ('sFastText Embedding Loaded').
        print(f'FastText Embedding Loaded:\n\t Embedding Size = {self.model.wv.vector_size}\n\t Vocabulary Size = {self.model.wv.vectors.shape[0]}')

    def has(self, word):
        """Return True for any non-empty word.

        FastText can embed out-of-vocabulary words via subword n-grams,
        so only the empty string is rejected.
        """
        return word != ""

    def get(self, word):
        """Return the embedding of *word* as a plain list of floats.

        Underscore-separated components are averaged. Raises ValueError for
        an empty token list (callers should check :meth:`has` first).
        """
        words = word.split('_')
        out = np.zeros(self.model.wv.vector_size)
        n = len(words)
        if n == 0:
            # Unreachable in practice: str.split always yields >= 1 item.
            raise ValueError('Empty string was given.')
        for item in words:
            out += self.model.wv.get_vector(item) / n
        return list(out)
# Shared FastText embedder used by both dataset splits.
embedder = EmbedderFasttext(config.data.fasttext_path)

# Train/val datasets. `extract_cousine` is disabled for the 'category'
# target, whose labels arrive through a different batch key (see the
# training loop's 'cousines' / 'targets' lookup).
# NOTE(review): occr_path and mask_path both point at ingredient_counts.json —
# confirm mask_path is not meant to reference a separate file.
datasets = {
    "train": RecipeDBDataset(os.path.join(data_path, 'train.json'),
                             cousine_dict=target_dictionary,
                             extract_ingredients=True, extract_recipes=True, extract_cousine=(target != 'category'),
                             embedder=embedder, target=target, occr_path=os.path.join(data_path, "ingredient_counts.json"),
                             mask_path=os.path.join(data_path, "ingredient_counts.json"), include_id=True, image_model = config.image_model),

    "val": RecipeDBDataset(os.path.join(data_path, "val.json"),
                           cousine_dict=target_dictionary,
                           extract_ingredients=True, extract_recipes=True, extract_cousine=(target != 'category'),
                           embedder=embedder, target=target, occr_path=os.path.join(data_path, "ingredient_counts.json"),
                           mask_path=os.path.join(data_path, "ingredient_counts.json"), include_id=True, image_model = config.image_model)
}
print('Dataset constructed.')
print(len(datasets['train']), len(datasets['val']))
print(f'target: {target}')
print(f'number of classes: {len(target_dictionary)}')

device = config.optim.device

# The dataset's custom collate_fn assembles the variable-length ingredient
# tensors and masks; only the training split is shuffled.
dataloaders = {
    "train":DataLoader(datasets["train"], batch_size=batch_size, collate_fn=datasets['train'].rdb_collate, shuffle=True, num_workers=num_workers),
    "val":DataLoader(datasets["val"], batch_size=batch_size, collate_fn=datasets['val'].rdb_collate, shuffle=False,num_workers=num_workers)
}
loss_fn = torch.nn.CrossEntropyLoss().to(device)
print('Dataloader constructed.')

model = ImageTextTransformer(config)
print(model)
model = model.to(device)
# SAM (sharpness-aware minimization) wraps Adam as its base optimizer.
# NOTE(review): base lr is max_lr/10, but OneCycleLR below manages the LR on
# optimizer.base_optimizer, so the initial value is presumably overridden by
# the schedule — confirm the /10 is intentional.
optimizer = SAM(model.parameters(), rho=sam_rho, base_optimizer=torch.optim.Adam, lr=learning_rate/10, weight_decay=weight_decay)
# steps_per_epoch matches one scheduler.step() per training batch in the loop.
scheduler = torch.optim.lr_scheduler.OneCycleLR(max_lr = learning_rate, epochs=epochs, steps_per_epoch=len(dataloaders["train"]), optimizer=optimizer.base_optimizer)
-
def stable_log_sigmoid(x):
    """Numerically stable element-wise computation of ``-log(1 + exp(x))``.

    Uses the shifted-max trick so large positive inputs do not overflow.

    NOTE(review): despite the name, ``-log(1 + exp(x))`` equals
    ``log(sigmoid(-x))`` = ``log(1 - sigmoid(x))``, not ``log(sigmoid(x))`` —
    confirm this matches the intent of the entropy term in the training loop.
    """
    m = torch.clamp(x, min=0.0)
    return -(m + torch.log(torch.exp(-m) + torch.exp(x - m)))
-
def argtopk(tensor, k, dim):
    """Return the indices of the *k* largest entries along *dim*, ordered by
    descending value.

    Fix: the original sorted the entire tensor with ``argsort`` just to keep
    the first *k* indices; ``torch.topk`` computes the same descending-order
    indices without a full sort. Tie-breaking between exactly equal values
    may differ from ``argsort``, which is acceptable for the top-k accuracy
    counts this feeds in the training loop.
    """
    return torch.topk(tensor, k=k, dim=dim).indices
-
# --- Training / evaluation loop, tracked as one MLflow run ------------------
with mlflow.start_run():
    mlflow.log_params(dict(config))
    result = None  # most recent validation accuracy (%), persisted into config
    best_val_acc = 0
    best_val_top3 = 0
    best_val_top5 = 0
    for epoch in range(epochs):
        for mode in ["train", "val"]:
            # Toggle dropout / batch-norm behaviour per phase.
            if mode == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            top_5_corrects = 0
            top_3_corrects = 0
            num_samples = 0
            s = 0  # NOTE(review): unused — candidate for removal
            for data_batch in tqdm(dataloaders[mode]):
                embeddings= data_batch['ingredients'].to(device)
                masks = data_batch['masks'].to(device)
                # Label key depends on the target type ('cousines' vs 'targets').
                targets = data_batch['cousines'].to(device) if 'cousines' in data_batch else data_batch['targets'].to(device)
                image_ingredients = data_batch['image_ingredients'].to(device)
                recipe_embeddings = data_batch['recipe_embeddings'].to(device)
                # Gradients only in the training phase; val reuses the same path.
                with torch.set_grad_enabled(mode == 'train'):
                    # SAM step 1: normal forward/backward with running BN stats.
                    enable_running_stats(model)
                    out = model(embeddings, masks, image_ingredients, recipe_embeddings)
                    # Per-batch entropy-style regularizer on the logits.
                    # NOTE(review): stable_log_sigmoid(x) = log(1 - sigmoid(x)),
                    # so this is -sum(sig*log(1-sig)), not the usual
                    # -sum(sig*log(sig)) — confirm this is intended.
                    entropy = -torch.sum(torch.sigmoid(out) * stable_log_sigmoid(out)) / embeddings.shape[0]
                    loss = loss_fn(out, targets) + entropy_weight * entropy
                    if mode == 'train':
                        loss.backward()
                        optimizer.first_step(zero_grad=True)
                        # SAM step 2: re-evaluate at the perturbed weights with
                        # BN statistics frozen, then take the actual step.
                        disable_running_stats(model)
                        out = model(embeddings, masks, image_ingredients, recipe_embeddings)
                        entropy = -torch.sum(torch.sigmoid(out) * stable_log_sigmoid(out)) / embeddings.shape[0]
                        (loss_fn(out, targets) + entropy_weight * entropy).backward()
                        optimizer.second_step(zero_grad=True)
                        scheduler.step()  # OneCycleLR advances once per batch

                # Accumulate loss/accuracy; during training `out` is from the
                # second (perturbed-weight) forward pass.
                running_loss+=loss.item()*embeddings.shape[0]
                running_corrects += (out.argmax(dim=1) == targets).sum().item()
                num_samples+=embeddings.shape[0]
                top_5_corrects += (argtopk(out, k=5, dim=1) == targets.unsqueeze(1)).sum().item()
                top_3_corrects += (argtopk(out, k=3, dim=1) == targets.unsqueeze(1)).sum().item()
            print(f"epoch: {epoch}, loss: {running_loss/num_samples}, acc: {running_corrects/num_samples}, top3: {top_3_corrects/num_samples}, top5: {top_5_corrects/num_samples}")
            if mode=="val":
                # Track the best validation metrics seen so far (percentages).
                best_val_acc = running_corrects/num_samples*100 if running_corrects/num_samples*100 > best_val_acc else best_val_acc
                best_val_top3 = top_3_corrects/num_samples*100 if top_3_corrects/num_samples*100 > best_val_top3 else best_val_top3
                best_val_top5 = top_5_corrects/num_samples*100 if top_5_corrects/num_samples*100 > best_val_top5 else best_val_top5
            metrics = {
                '{}_loss'.format(mode): running_loss/num_samples,
                '{}_acc'.format(mode): running_corrects/num_samples*100,
                '{}_acc3'.format(mode): top_3_corrects/num_samples*100,
                '{}_acc5'.format(mode): top_5_corrects/num_samples*100
            }
            if mode == 'val':
                metrics["best_val_acc"] = best_val_acc
                metrics["best_val_acc3"] = best_val_top3
                metrics["best_val_acc5"] = best_val_top5

                result = running_corrects/num_samples*100
            # NOTE(review): no step= is passed here — confirm the metric
            # history renders as intended in the MLflow UI.
            mlflow.log_metrics(metrics)
        # Checkpoint every epoch; the same file path is overwritten each time.
        os.makedirs(output_dir, exist_ok=True)
        mlflow.pytorch.log_model(model, 'model')
        config.result = result
        torch.save(model.state_dict(), os.path.join(output_dir, "checkpoint.pth"))
        save_config(config, os.path.join(output_dir, "config.yml"))
- save_config(config, os.path.join(output_dir, "config.yml"))
|