BSc project of Parham Saremi. The goal of the project was to predict the geographical region of origin of a dish from textual and visual features extracted from its recipe and ingredients.

train.py:

from datetime import datetime
import os

experiment_name = 'Parham BS Project Region Prediction'
experiment_code = experiment_name.replace(' - ', '.').replace(' ', '_').lower()

import nltk
# nltk.download('wordnet')
# nltk.download('omw-1.4')
# nltk.download('punkt')
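# The nltk.download(...) calls above are one-time setup; uncomment them on the
# first run if the WordNet/punkt data is not already installed locally.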
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module
import torch
from tqdm import tqdm
from gensim.models import FastText
from utils.sam import SAM
from utils.bypass_bn import enable_running_stats, disable_running_stats
from einops import reduce
from utils.recipedb_dataset import RecipeDBDataset
import logging
import argparse
import mlflow
import mlflow.pytorch

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)

from network import ImageTextTransformer
from utils.io import load_config, save_config
mlflow.set_experiment(experiment_name)

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str)
args = parser.parse_args()

config = load_config(args.config)
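# Typical invocation (the config path below is an example, not a file shipped
# with the repo):
#   python train.py --config configs/region.yml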
epochs = config.optim.epochs
batch_size = config.optim.batch_size
learning_rate = config.optim.max_lr
weight_decay = config.optim.weight_decay
embedding_size = config.data.embedding_size
num_classes = config.model.final_classes
sam_rho = config.optim.sam_rho
num_workers = config.optim.num_workers
data_path = config.data.dataset_path
target = config.data.target

target_dictionary = json.load(open(os.path.join(data_path, f'{target}.json'), 'r'))
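# target_dictionary maps each class label (e.g. a region name) to an integer
# index; its size determines the width of the classification head below.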
# Optional entropy-regularization weight; defaults to 0 when absent from the config.
if 'entropy' in config.optim:
    entropy_weight = config.optim.entropy
else:
    entropy_weight = 0

config.model.final_classes = len(target_dictionary)
epsilon = 1e-8  # defined but not used below

print(target)
print(target_dictionary)

output_dir = f'parham-models_image_text_transformer/{target}/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
os.makedirs(output_dir, exist_ok=True)
class EmbedderFasttext():
    def __init__(self, path):
        self.model = FastText.load(path)
        print(f'FastText Embedding Loaded:\n\t Embedding Size = {self.model.wv.vector_size}\n\t Vocabulary Size = {self.model.wv.vectors.shape[0]}')

    def has(self, word):
        # FastText composes vectors from character n-grams, so any non-empty
        # word can be embedded, including out-of-vocabulary ones.
        return word != ""

    def get(self, word):
        # Multi-word ingredients arrive as '_'-joined tokens; average their vectors.
        words = word.split('_')
        out = np.zeros(self.model.wv.vector_size)
        n = len(words)
        if n == 0:
            raise ValueError('Empty string was given.')
        for item in words:
            out += self.model.wv.get_vector(item) / n
        return list(out)
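# For example, 'olive_oil' is embedded as the mean of the FastText vectors
# for 'olive' and 'oil'.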
embedder = EmbedderFasttext(config.data.fasttext_path)

datasets = {
    "train": RecipeDBDataset(os.path.join(data_path, 'train.json'),
                             cousine_dict=target_dictionary,
                             extract_ingredients=True, extract_recipes=True, extract_cousine=(target != 'category'),
                             embedder=embedder, target=target, occr_path=os.path.join(data_path, "ingredient_counts.json"),
                             mask_path=os.path.join(data_path, "ingredient_counts.json"), include_id=True, image_model=config.image_model),
    "val": RecipeDBDataset(os.path.join(data_path, "val.json"),
                           cousine_dict=target_dictionary,
                           extract_ingredients=True, extract_recipes=True, extract_cousine=(target != 'category'),
                           embedder=embedder, target=target, occr_path=os.path.join(data_path, "ingredient_counts.json"),
                           mask_path=os.path.join(data_path, "ingredient_counts.json"), include_id=True, image_model=config.image_model)
}
print('Dataset constructed.')
print(len(datasets['train']), len(datasets['val']))
print(f'target: {target}')
print(f'number of classes: {len(target_dictionary)}')

device = config.optim.device
dataloaders = {
    "train": DataLoader(datasets["train"], batch_size=batch_size, collate_fn=datasets['train'].rdb_collate, shuffle=True, num_workers=num_workers),
    "val": DataLoader(datasets["val"], batch_size=batch_size, collate_fn=datasets['val'].rdb_collate, shuffle=False, num_workers=num_workers)
}
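# rdb_collate is the dataset's custom collate_fn; it batches the variable-length
# ingredient sequences together with their padding masks (see utils/recipedb_dataset.py).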
loss_fn = torch.nn.CrossEntropyLoss().to(device)
print('Dataloader constructed.')

model = ImageTextTransformer(config)
print(model)
model = model.to(device)

optimizer = SAM(model.parameters(), rho=sam_rho, base_optimizer=torch.optim.Adam, lr=learning_rate / 10, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(max_lr=learning_rate, epochs=epochs, steps_per_epoch=len(dataloaders["train"]), optimizer=optimizer.base_optimizer)
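# SAM (sharpness-aware minimization) wraps Adam and requires two forward/backward
# passes per batch. The one-cycle scheduler drives the wrapped base optimizer, so
# the initial lr passed to SAM is superseded by the scheduler's ramp to max_lr.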
def stable_log_sigmoid(x):
    max_value = torch.maximum(x, torch.zeros(*x.shape, dtype=torch.float32, device=x.device))
    return -max_value - torch.log(torch.exp(-max_value) + torch.exp(x - max_value))
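# Note: with m = max(x, 0), the expression -m - log(exp(-m) + exp(x - m))
# equals -log(1 + exp(x)), i.e. log(1 - sigmoid(x)) rather than log(sigmoid(x)),
# computed in a numerically stable way. The regularizer in the training loop
# below therefore penalizes -sum(sigmoid(out) * log(1 - sigmoid(out))).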
def argtopk(tensor, k, dim):
    indices = torch.argsort(tensor, dim=dim, descending=True)
    topk_indices = indices.narrow(dim, 0, k)
    return topk_indices
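# Equivalent to torch.topk(tensor, k, dim=dim).indices; used below for top-3 and
# top-5 accuracy by checking whether the target appears among the k largest logits.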
with mlflow.start_run():
    mlflow.log_params(dict(config))
    result = None
    best_val_acc = 0
    best_val_top3 = 0
    best_val_top5 = 0
    for epoch in range(epochs):
        for mode in ["train", "val"]:
            if mode == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            top_5_corrects = 0
            top_3_corrects = 0
            num_samples = 0
            for data_batch in tqdm(dataloaders[mode]):
                embeddings = data_batch['ingredients'].to(device)
                masks = data_batch['masks'].to(device)
                targets = data_batch['cousines'].to(device) if 'cousines' in data_batch else data_batch['targets'].to(device)
                image_ingredients = data_batch['image_ingredients'].to(device)
                recipe_embeddings = data_batch['recipe_embeddings'].to(device)
                with torch.set_grad_enabled(mode == 'train'):
                    enable_running_stats(model)
                    out = model(embeddings, masks, image_ingredients, recipe_embeddings)
                    entropy = -torch.sum(torch.sigmoid(out) * stable_log_sigmoid(out)) / embeddings.shape[0]
                    loss = loss_fn(out, targets) + entropy_weight * entropy
                    if mode == 'train':
                        # First SAM step: ascend to the local worst-case weights.
                        loss.backward()
                        optimizer.first_step(zero_grad=True)
                        # Second SAM step: recompute the loss at the perturbed
                        # weights (with batch-norm statistics frozen) and descend.
                        disable_running_stats(model)
                        out = model(embeddings, masks, image_ingredients, recipe_embeddings)
                        entropy = -torch.sum(torch.sigmoid(out) * stable_log_sigmoid(out)) / embeddings.shape[0]
                        (loss_fn(out, targets) + entropy_weight * entropy).backward()
                        optimizer.second_step(zero_grad=True)
                        scheduler.step()
                running_loss += loss.item() * embeddings.shape[0]
                running_corrects += (out.argmax(dim=1) == targets).sum().item()
                num_samples += embeddings.shape[0]
                top_5_corrects += (argtopk(out, k=5, dim=1) == targets.unsqueeze(1)).sum().item()
                top_3_corrects += (argtopk(out, k=3, dim=1) == targets.unsqueeze(1)).sum().item()
            print(f"[{mode}] epoch: {epoch}, loss: {running_loss/num_samples}, acc: {running_corrects/num_samples}, top3: {top_3_corrects/num_samples}, top5: {top_5_corrects/num_samples}")
            if mode == "val":
                best_val_acc = max(best_val_acc, running_corrects / num_samples * 100)
                best_val_top3 = max(best_val_top3, top_3_corrects / num_samples * 100)
                best_val_top5 = max(best_val_top5, top_5_corrects / num_samples * 100)
            metrics = {
                '{}_loss'.format(mode): running_loss / num_samples,
                '{}_acc'.format(mode): running_corrects / num_samples * 100,
                '{}_acc3'.format(mode): top_3_corrects / num_samples * 100,
                '{}_acc5'.format(mode): top_5_corrects / num_samples * 100
            }
            if mode == 'val':
                metrics["best_val_acc"] = best_val_acc
                metrics["best_val_acc3"] = best_val_top3
                metrics["best_val_acc5"] = best_val_top5
                result = running_corrects / num_samples * 100
            mlflow.log_metrics(metrics)
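    # After training, persist the final model both to MLflow and to the local
    # output directory alongside the resolved config.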
    mlflow.pytorch.log_model(model, 'model')
    config.result = result
    torch.save(model.state_dict(), os.path.join(output_dir, "checkpoint.pth"))
    save_config(config, os.path.join(output_dir, "config.yml"))
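# The script expects a config file exposing at least the keys read above
# (attribute-style access via load_config): optim.{epochs, batch_size, max_lr,
# weight_decay, sam_rho, num_workers, device} plus the optional optim.entropy,
# data.{embedding_size, dataset_path, target, fasttext_path},
# model.final_classes, and image_model.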