"""Train an image+text transformer that predicts a RecipeDB target
(e.g. region/cuisine) from ingredient/recipe FastText embeddings and
image-ingredient features.

Flat training script: parse ``--config``, build train/val RecipeDB
datasets and loaders, train with SAM (two forward/backward passes per
batch) under a OneCycleLR schedule, and log metrics and the model to
MLflow. The best checkpoint/config are written to ``output_dir``.
"""
from datetime import datetime
import argparse
import json
import logging
import os

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module
from einops import reduce
from gensim.models import FastText
from tqdm import tqdm
import mlflow
import mlflow.pytorch
import nltk

from network import ImageTextTransformer
from utils.bypass_bn import enable_running_stats, disable_running_stats
from utils.io import load_config, save_config
from utils.recipedb_dataset import RecipeDBDataset
from utils.sam import SAM

# One-time NLTK corpus downloads; uncomment on a fresh machine.
# nltk.download('wordnet')
# nltk.download('omw-1.4')
# nltk.download('punkt')

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)

experiment_name = 'Parham BS Project Region Prediction'
experiment_code = experiment_name.replace(' - ', '.').replace(' ', '_').lower()
mlflow.set_experiment(experiment_name)

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str)
args = parser.parse_args()
config = load_config(args.config)

# Hyperparameters and paths pulled from the YAML config.
epochs = config.optim.epochs
batch_size = config.optim.batch_size
learning_rate = config.optim.max_lr
weight_decay = config.optim.weight_decay
embedding_size = config.data.embedding_size
num_classes = config.model.final_classes
sam_rho = config.optim.sam_rho
num_workers = config.optim.num_workers
data_path = config.data.dataset_path
target = config.data.target

# Label dictionary (class name -> index) for the chosen prediction target.
with open(os.path.join(data_path, f'{target}.json'), 'r') as f:
    target_dictionary = json.load(f)

# Optional entropy-regularization weight; 0 disables the extra loss term.
entropy_weight = config.optim.entropy if 'entropy' in config.optim else 0

# The real class count comes from the label dictionary, not the config default.
config.model.final_classes = len(target_dictionary)
epsilon = 1e-8

print(target)
print(target_dictionary)

# NOTE(review): 'taext' looks like a typo for 'text', but the path is kept
# verbatim in case external tooling already globs this directory name.
output_dir = f'parham-models_image_taext_transformer/{target}/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
os.makedirs(output_dir, exist_ok=True)


class EmbedderFasttext():
    """Adapter exposing the ``has``/``get`` lookup interface that
    RecipeDBDataset expects, backed by a gensim FastText model."""

    def __init__(self, path):
        self.model = FastText.load(path)
        print(
            f'FastText Embedding Loaded:\n'
            f'\t Embedding Size = {self.model.wv.vector_size}\n'
            f'\t Vocabulary Size = {self.model.wv.vectors.shape[0]}'
        )

    def has(self, word):
        """FastText composes vectors from character n-grams, so every
        non-empty token is embeddable."""
        return word != ""

    def get(self, word):
        """Embed an underscore-joined phrase as the mean of its component
        word vectors.

        Raises:
            ValueError: if ``word`` contains no components.
        """
        words = word.split('_')
        n = len(words)
        if n == 0:
            raise ValueError('Empty string was given.')
        out = np.zeros(self.model.wv.vector_size)
        for item in words:
            out += self.model.wv.get_vector(item) / n
        return list(out)


embedder = EmbedderFasttext(config.data.fasttext_path)


def _build_dataset(split):
    """Construct the RecipeDBDataset for one split ('train' or 'val')."""
    return RecipeDBDataset(
        os.path.join(data_path, f'{split}.json'),
        cousine_dict=target_dictionary,
        extract_ingredients=True,
        extract_recipes=True,
        # 'category' targets are not cuisine labels, so skip cuisine extraction.
        extract_cousine=(target != 'category'),
        embedder=embedder,
        target=target,
        occr_path=os.path.join(data_path, "ingredient_counts.json"),
        # NOTE(review): mask_path points at the same counts file as
        # occr_path — confirm this is intentional.
        mask_path=os.path.join(data_path, "ingredient_counts.json"),
        include_id=True,
        image_model=config.image_model,
    )


datasets = {"train": _build_dataset('train'), "val": _build_dataset('val')}
print('Dataset constructed.')
print(len(datasets['train']), len(datasets['val']))
print(f'target: {target}')
print(f'number of classes: {len(target_dictionary)}')

device = config.optim.device
dataloaders = {
    "train": DataLoader(datasets["train"], batch_size=batch_size,
                        collate_fn=datasets['train'].rdb_collate,
                        shuffle=True, num_workers=num_workers),
    "val": DataLoader(datasets["val"], batch_size=batch_size,
                      collate_fn=datasets['val'].rdb_collate,
                      shuffle=False, num_workers=num_workers),
}
loss_fn = torch.nn.CrossEntropyLoss().to(device)
print('Dataloader constructed.')

model = ImageTextTransformer(config)
print(model)
model = model.to(device)

# SAM wraps Adam; the base lr (max_lr / 10) is only a placeholder because
# OneCycleLR below drives the actual per-step learning rate up to max_lr.
optimizer = SAM(model.parameters(), rho=sam_rho,
                base_optimizer=torch.optim.Adam,
                lr=learning_rate / 10, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer.base_optimizer, max_lr=learning_rate,
    epochs=epochs, steps_per_epoch=len(dataloaders["train"]))


def stable_log_sigmoid(x):
    """Overflow-safe evaluation of ``-m - log(exp(-m) + exp(x - m))``
    with ``m = max(x, 0)``.

    NOTE(review): this equals ``log(1 - sigmoid(x))`` (i.e.
    ``log sigmoid(-x)``), not ``log sigmoid(x)`` as the name suggests —
    confirm the entropy term below really wants
    ``sigmoid(out) * log(1 - sigmoid(out))``.
    """
    max_value = torch.maximum(x, torch.zeros_like(x))
    return -max_value - torch.log(torch.exp(-max_value) + torch.exp(x - max_value))


def argtopk(tensor, k, dim):
    """Return the indices of the ``k`` largest entries along ``dim``,
    ordered by descending value."""
    return torch.topk(tensor, k=k, dim=dim).indices


with mlflow.start_run():
    mlflow.log_params(dict(config))
    result = None
    best_val_acc = 0
    best_val_top3 = 0
    best_val_top5 = 0

    for epoch in range(epochs):
        for mode in ["train", "val"]:
            if mode == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            top_5_corrects = 0
            top_3_corrects = 0
            num_samples = 0

            for data_batch in tqdm(dataloaders[mode]):
                embeddings = data_batch['ingredients'].to(device)
                masks = data_batch['masks'].to(device)
                # Labels live under 'cousines' when cuisine extraction is on,
                # otherwise under 'targets'.
                targets = (data_batch['cousines'] if 'cousines' in data_batch
                           else data_batch['targets']).to(device)
                image_ingredients = data_batch['image_ingredients'].to(device)
                recipe_embeddings = data_batch['recipe_embeddings'].to(device)

                with torch.set_grad_enabled(mode == 'train'):
                    # First SAM pass: regular forward/backward to compute the
                    # adversarial weight perturbation (BN stats updating).
                    enable_running_stats(model)
                    out = model(embeddings, masks, image_ingredients, recipe_embeddings)
                    entropy = -torch.sum(torch.sigmoid(out) * stable_log_sigmoid(out)) / embeddings.shape[0]
                    loss = loss_fn(out, targets) + entropy_weight * entropy
                    if mode == 'train':
                        loss.backward()
                        optimizer.first_step(zero_grad=True)

                        # Second SAM pass at the perturbed weights; BN running
                        # stats are frozen so they are not polluted.
                        disable_running_stats(model)
                        out = model(embeddings, masks, image_ingredients, recipe_embeddings)
                        entropy = -torch.sum(torch.sigmoid(out) * stable_log_sigmoid(out)) / embeddings.shape[0]
                        (loss_fn(out, targets) + entropy_weight * entropy).backward()
                        optimizer.second_step(zero_grad=True)
                        scheduler.step()  # OneCycleLR steps per batch

                n_batch = embeddings.shape[0]
                running_loss += loss.item() * n_batch
                running_corrects += (out.argmax(dim=1) == targets).sum().item()
                num_samples += n_batch
                top_5_corrects += (argtopk(out, k=5, dim=1) == targets.unsqueeze(1)).sum().item()
                top_3_corrects += (argtopk(out, k=3, dim=1) == targets.unsqueeze(1)).sum().item()

            print(f"epoch: {epoch}, loss: {running_loss/num_samples}, acc: {running_corrects/num_samples}, top3: {top_3_corrects/num_samples}, top5: {top_5_corrects/num_samples}")

            metrics = {
                '{}_loss'.format(mode): running_loss / num_samples,
                '{}_acc'.format(mode): running_corrects / num_samples * 100,
                '{}_acc3'.format(mode): top_3_corrects / num_samples * 100,
                '{}_acc5'.format(mode): top_5_corrects / num_samples * 100,
            }
            if mode == "val":
                best_val_acc = max(best_val_acc, running_corrects / num_samples * 100)
                best_val_top3 = max(best_val_top3, top_3_corrects / num_samples * 100)
                best_val_top5 = max(best_val_top5, top_5_corrects / num_samples * 100)
                metrics["best_val_acc"] = best_val_acc
                metrics["best_val_acc3"] = best_val_top3
                metrics["best_val_acc5"] = best_val_top5
                result = running_corrects / num_samples * 100
            mlflow.log_metrics(metrics)

    # Persist the trained model, its weights, and the config annotated with
    # the final validation accuracy.
    mlflow.pytorch.log_model(model, 'model')
    config.result = result
    torch.save(model.state_dict(), os.path.join(output_dir, "checkpoint.pth"))
    save_config(config, os.path.join(output_dir, "config.yml"))