BSc project of Parham Saremi. The goal of the project was to detect the geographical region of a dish using textual and visual features extracted from its recipe and ingredients.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

recipedb_dataset.py 7.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. from typing import Any
  2. import torch
  3. from torch.utils.data import Dataset
  4. import json
  5. import numpy as np
  6. from torch.nn.utils.rnn import pad_sequence
  7. import warnings
  8. import os
  9. warnings.filterwarnings(action='ignore',category=UserWarning,module='gensim')
  10. warnings.filterwarnings(action='ignore',category=FutureWarning,module='gensim')
  11. def mask_count(num):
  12. return num//5
  13. def generate_ing_dict(path, threshold):
  14. assert path != None
  15. with open(path, "r") as json_file:
  16. full_ing_count_list:dict = json.load(json_file)
  17. filtered_ing_list = {}
  18. counter = 0
  19. for ing, count in full_ing_count_list.items():
  20. if count > threshold:
  21. filtered_ing_list[ing] = counter
  22. counter += 1
  23. return filtered_ing_list
  24. def get_ingredient_frequencies(occr_path):
  25. occr = None
  26. with open(occr_path, "r") as json_file:
  27. occr = json.load(json_file)
  28. if '' in occr:
  29. del occr['']
  30. return occr
  31. class RecipeDBDataset(Dataset):
  32. def __init__(self, json_path, cousine_dict=None,
  33. extract_ingredients=False, extract_recipes=False, extract_cousine=False,
  34. embedder=None, include_id=False, mask_threshold=1000, mask_path=None,
  35. occr_path = None, target='country',
  36. image_model="resnet18") -> None:
  37. super(RecipeDBDataset, self).__init__()
  38. with open(json_path, "r") as json_file:
  39. data = json.load(json_file)
  40. if occr_path is not None:
  41. self.freqs = get_ingredient_frequencies(occr_path)
  42. self.all_ingredients, self.all_ingredient_probs = zip(*sorted(self.freqs.items()))
  43. self.all_ingredients = list(self.all_ingredients)
  44. self.all_ingredient_probs = np.array(self.all_ingredient_probs, dtype=np.float32)
  45. self.all_ingredient_probs /= np.sum(self.all_ingredient_probs)
  46. self.ing_dict:dict = generate_ing_dict(mask_path, mask_threshold)
  47. self.len_mask_ing = len(self.ing_dict)
  48. self.data = []
  49. self.embedder = embedder
  50. self.extract_ingredients = extract_ingredients
  51. self.extract_recipes = extract_recipes
  52. self.extract_cousine = extract_cousine
  53. self.ingredient_set = set()
  54. self.image_path = "Data/image_dict_ings.json"
  55. with open(self.image_path, 'r') as jf:
  56. self.image_ing_dict = json.load(jf)
  57. self.image_feature_path = "/home/dml/food/CuisineAdaptation/IngredientsEncoding/image-features-full"
  58. feature_size = {
  59. 'resnet18': 512,
  60. 'resnet50': 2048,
  61. 'resnet101': 2048,
  62. 'efficientnet_b0': 1280,
  63. 'efficientnet_b3': 1536,
  64. 'efficientnet_t0': 1280
  65. }
  66. self.image_model = image_model
  67. self.image_feature_size = feature_size[self.image_model]
  68. self.not_found_ings = set()
  69. self.text_feature_path = "/home/dml/food/CuisineAdaptation/IngredientsEncoding/text-features"
  70. self.text_feature_model = "bert-base-uncased"
  71. failed_ing_count = 0
  72. for recipe in data:
  73. temp_data = {}
  74. if extract_ingredients:
  75. temp_data["ingredients"] = []
  76. for ing in recipe["ingredients"]:
  77. if ing["Ingredient Name"] != "":
  78. temp_data["ingredients"].append(ing["Ingredient Name"])
  79. if len(temp_data["ingredients"]) == 0:
  80. failed_ing_count += 1
  81. continue
  82. if extract_cousine:
  83. temp_data["cousine"] = cousine_dict[recipe[target]]
  84. if include_id:
  85. temp_data["id"] = recipe["id"]
  86. self.data.append(temp_data)
  87. self.cousine_dict = cousine_dict
  88. print(f"failed ings count: {failed_ing_count}")
  89. def __getitem__(self, index: Any):
  90. d = self.data[index]
  91. out = {}
  92. ings = []
  93. if self.extract_ingredients:
  94. for ing in d["ingredients"]:
  95. if self.embedder.has(ing):
  96. ings.append(self.embedder.get(ing))
  97. ings = torch.tensor(ings, dtype=torch.float32)
  98. image_ingredients = []
  99. for ing in d["ingredients"]:
  100. npy_path = ""
  101. if ing in self.image_ing_dict:
  102. npy_path = os.path.join(self.image_feature_path, self.image_model, f"{ing}.npy")
  103. elif ing.replace(" ", "_") in self.image_ing_dict:
  104. npy_path = os.path.join(self.image_feature_path, self.image_model, f"{ing.replace(' ', '_')}.npy")
  105. else:
  106. for ing_part in ing.split():
  107. if ing_part in self.image_ing_dict:
  108. npy_path = os.path.join(self.image_feature_path, self.image_model, f"{ing_part}.npy")
  109. break
  110. else:
  111. self.not_found_ings.add(ing)
  112. if npy_path == "":
  113. image_ingredients.append(np.zeros(self.image_feature_size))
  114. else:
  115. image_ingredients.append(np.load(npy_path))
  116. image_ingredients = torch.tensor(image_ingredients, dtype=torch.float32)
  117. out["ingredients"] = ings
  118. out["image_ingredients"] = image_ingredients
  119. if self.extract_recipes:
  120. out["recipe_embedding"] = torch.tensor(np.load(os.path.join(self.text_feature_path, self.text_feature_model, f'{d["id"]}.npy')), dtype=torch.float32)
  121. if self.extract_cousine:
  122. out["cousine"] = d["cousine"]
  123. return out
  124. def __len__(self):
  125. return self.data.__len__()
  126. def rdb_collate(self, batch):
  127. cousines = []
  128. ingredients = []
  129. masks = []
  130. image_ingredients = []
  131. recipe_embeddings = []
  132. for data in batch:
  133. if "cousine" in data:
  134. cousines.append(data["cousine"])
  135. if "recipe_embedding" in data:
  136. recipe_embeddings.append(data["recipe_embedding"])
  137. if "ingredients" in data:
  138. ingredients.append(data["ingredients"])
  139. masks.append(torch.ones(data["ingredients"].shape[0]))
  140. image_ingredients.append(data["image_ingredients"])
  141. outs = {}
  142. if "ingredients" in data:
  143. masks = pad_sequence(masks, batch_first=True, padding_value=0).type(torch.bool)
  144. ingredients = pad_sequence(ingredients, batch_first=True, padding_value=0)
  145. image_ingredients = pad_sequence(image_ingredients, batch_first=True, padding_value=0)
  146. outs["masks"] = masks
  147. outs["ingredients"] = ingredients
  148. outs["image_ingredients"] = image_ingredients
  149. if "recipe_embedding" in data:
  150. outs["recipe_embeddings"] = torch.cat(recipe_embeddings, dim=0)
  151. if "cousine" in data:
  152. cousines = torch.LongTensor(cousines)
  153. outs["cousines"] = cousines
  154. return outs
  155. def dict_to_device(data:dict, device, return_new_dict=False):
  156. new_dict = {}
  157. for k, v in data.items():
  158. if not return_new_dict:
  159. data[k] = v.to(device)
  160. else:
  161. new_dict[k] = v.to(device)
  162. return new_dict if return_new_dict else data