import json
import os
import warnings
from typing import Any

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

warnings.filterwarnings(action='ignore', category=UserWarning, module='gensim')
warnings.filterwarnings(action='ignore', category=FutureWarning, module='gensim')


def mask_count(num):
    """Return ``num`` floor-divided by 5."""
    return num // 5


def generate_ing_dict(path, threshold):
    """Index every ingredient whose count is strictly above ``threshold``.

    Args:
        path: Path to a JSON file holding an ``{ingredient: count}`` object.
        threshold: Ingredients with ``count <= threshold`` are dropped.

    Returns:
        A dict mapping each kept ingredient to a consecutive integer starting
        at 0, in the order the ingredients appear in the JSON file.

    Raises:
        ValueError: If ``path`` is ``None``.
    """
    # An explicit raise (rather than ``assert``) keeps the check alive under
    # ``python -O``.
    if path is None:
        raise ValueError("path to the ingredient count JSON must not be None")
    with open(path, "r") as json_file:
        full_ing_count_list: dict = json.load(json_file)
    frequent = (
        ing for ing, count in full_ing_count_list.items() if count > threshold
    )
    return {ing: index for index, ing in enumerate(frequent)}


def get_ingredient_frequencies(occr_path):
    """Load the JSON object at ``occr_path`` and drop its empty-string key.

    Args:
        occr_path: Path to a JSON file; presumably an
            ``{ingredient: occurrence count}`` object (verify against the
            file that is actually passed in).

    Returns:
        The loaded dict without the ``''`` entry, if it had one.
    """
    with open(occr_path, "r") as json_file:
        occr = json.load(json_file)
    occr.pop('', None)
    return occr


class RecipeDBDataset(Dataset):
    """Recipe dataset yielding ingredient embeddings, image features and labels.

    Each sample is a dict that may contain (depending on the ``extract_*``
    flags) the keys ``"ingredients"``, ``"image_ingredients"``,
    ``"recipe_embedding"`` and ``"cousine"``.  The spelling ``cousine`` is part
    of the runtime keys and is kept as is.
    """

    # Feature vector length produced by each supported image backbone.
    _IMAGE_FEATURE_SIZES = {
        'resnet18': 512,
        'resnet50': 2048,
        'resnet101': 2048,
        'efficientnet_b0': 1280,
        'efficientnet_b3': 1536,
        'efficientnet_t0': 1280,
    }

    def __init__(
        self,
        json_path,
        cousine_dict=None,
        extract_ingredients=False,
        extract_recipes=False,
        extract_cousine=False,
        embedder=None,
        include_id=False,
        mask_threshold=1000,
        mask_path=None,
        occr_path=None,
        target='country',
        image_model="resnet18",
        image_dict_path="Data/image_dict_ings.json",
        image_feature_path="/home/dml/food/CuisineAdaptation/IngredientsEncoding/image-features-full",
        text_feature_path="/home/dml/food/CuisineAdaptation/IngredientsEncoding/text-features",
        text_feature_model="bert-base-uncased",
    ) -> None:
        """Read the recipe JSON and prepare the per-sample records.

        Args:
            json_path: JSON file containing a list of recipes.  Each recipe
                needs an ``"ingredients"`` list of dicts with an
                ``"Ingredient Name"`` entry when ``extract_ingredients`` is
                set, a ``target`` entry when ``extract_cousine`` is set and an
                ``"id"`` entry when ids are stored.
            cousine_dict: Maps ``recipe[target]`` to the stored label.
            extract_ingredients: Include ingredient and image features.
            extract_recipes: Include the precomputed recipe text embedding.
            extract_cousine: Include the label.
            embedder: Object exposing ``has(name)`` and ``get(name)``; used in
                ``__getitem__`` when ``extract_ingredients`` is set.
            include_id: Store the recipe id in every record.
            mask_threshold: Threshold forwarded to ``generate_ing_dict``.
            mask_path: Ingredient count JSON forwarded to
                ``generate_ing_dict``; must not be ``None``.
            occr_path: Optional ingredient frequency JSON.  When given,
                ``freqs``, ``all_ingredients`` and the normalised
                ``all_ingredient_probs`` are populated.
            target: Recipe key that holds the label.
            image_model: Key of the image backbone; selects the feature
                sub-directory and the feature size.
            image_dict_path: JSON file whose keys are the ingredient names
                that have image features.
            image_feature_path: Root directory of the ``.npy`` image features.
            text_feature_path: Root directory of the ``.npy`` text features.
            text_feature_model: Sub-directory of ``text_feature_path`` to use.

        Raises:
            ValueError: If ``mask_path`` is ``None``.
            KeyError: If ``image_model`` is not a supported backbone.
        """
        super().__init__()
        with open(json_path, "r") as json_file:
            data = json.load(json_file)

        if occr_path is not None:
            self.freqs = get_ingredient_frequencies(occr_path)
            # Sorting by name gives a deterministic ingredient order.
            self.all_ingredients, self.all_ingredient_probs = zip(
                *sorted(self.freqs.items())
            )
            self.all_ingredients = list(self.all_ingredients)
            self.all_ingredient_probs = np.array(
                self.all_ingredient_probs, dtype=np.float32
            )
            # Turn raw frequencies into a probability distribution.
            self.all_ingredient_probs /= np.sum(self.all_ingredient_probs)

        self.ing_dict: dict = generate_ing_dict(mask_path, mask_threshold)
        self.len_mask_ing = len(self.ing_dict)

        self.data = []
        self.embedder = embedder
        self.extract_ingredients = extract_ingredients
        self.extract_recipes = extract_recipes
        self.extract_cousine = extract_cousine
        self.ingredient_set = set()

        self.image_path = image_dict_path
        with open(self.image_path, 'r') as jf:
            self.image_ing_dict = json.load(jf)
        self.image_feature_path = image_feature_path
        self.image_model = image_model
        self.image_feature_size = self._IMAGE_FEATURE_SIZES[self.image_model]
        # Ingredients for which no image feature file could be resolved.
        self.not_found_ings = set()

        self.text_feature_path = text_feature_path
        self.text_feature_model = text_feature_model

        # The recipe embedding file is looked up by id, so the id has to be
        # kept whenever recipe embeddings are requested.
        store_id = include_id or extract_recipes

        failed_ing_count = 0
        for recipe in data:
            temp_data = {}
            if extract_ingredients:
                temp_data["ingredients"] = [
                    ing["Ingredient Name"]
                    for ing in recipe["ingredients"]
                    if ing["Ingredient Name"] != ""
                ]
                # Recipes without any named ingredient are skipped entirely.
                if not temp_data["ingredients"]:
                    failed_ing_count += 1
                    continue
            if extract_cousine:
                temp_data["cousine"] = cousine_dict[recipe[target]]
            if store_id:
                temp_data["id"] = recipe["id"]
            self.data.append(temp_data)

        self.cousine_dict = cousine_dict
        print(f"failed ings count: {failed_ing_count}")

    def _image_feature_file(self, ing):
        """Return the ``.npy`` path holding image features for ``ing``.

        Lookup order: the exact name, the name with spaces replaced by
        underscores, then the first whitespace-separated word that is a key of
        ``image_ing_dict``.  Returns ``""`` and records ``ing`` in
        ``not_found_ings`` when nothing matches.
        """
        model_dir = os.path.join(self.image_feature_path, self.image_model)
        underscored = ing.replace(" ", "_")
        if ing in self.image_ing_dict:
            return os.path.join(model_dir, f"{ing}.npy")
        if underscored in self.image_ing_dict:
            return os.path.join(model_dir, f"{underscored}.npy")
        for ing_part in ing.split():
            if ing_part in self.image_ing_dict:
                return os.path.join(model_dir, f"{ing_part}.npy")
        self.not_found_ings.add(ing)
        return ""

    def __getitem__(self, index: Any):
        """Build the sample dict for ``self.data[index]``.

        Returns:
            A dict with ``"ingredients"`` and ``"image_ingredients"`` float32
            tensors when ``extract_ingredients`` is set, ``"recipe_embedding"``
            (float32 tensor loaded from ``<id>.npy``) when ``extract_recipes``
            is set, and ``"cousine"`` when ``extract_cousine`` is set.
        """
        d = self.data[index]
        out = {}

        if self.extract_ingredients:
            # NOTE(review): ingredients unknown to the embedder are skipped
            # here, while image features below are produced for every
            # ingredient, so the two tensors can differ in their first
            # dimension - confirm this is intended.
            ings = [
                self.embedder.get(ing)
                for ing in d["ingredients"]
                if self.embedder.has(ing)
            ]
            ings = torch.tensor(ings, dtype=torch.float32)

            image_ingredients = []
            for ing in d["ingredients"]:
                npy_path = self._image_feature_file(ing)
                if npy_path == "":
                    # No image feature available: fall back to a zero vector.
                    image_ingredients.append(np.zeros(self.image_feature_size))
                else:
                    image_ingredients.append(np.load(npy_path))
            # Stack into one ndarray first; building a tensor straight from a
            # list of ndarrays is very slow.
            image_ingredients = torch.tensor(
                np.array(image_ingredients), dtype=torch.float32
            )

            out["ingredients"] = ings
            out["image_ingredients"] = image_ingredients

        if self.extract_recipes:
            embedding_file = os.path.join(
                self.text_feature_path,
                self.text_feature_model,
                f'{d["id"]}.npy',
            )
            out["recipe_embedding"] = torch.tensor(
                np.load(embedding_file), dtype=torch.float32
            )

        if self.extract_cousine:
            out["cousine"] = d["cousine"]
        return out

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)

    def rdb_collate(self, batch):
        """Collate a list of sample dicts into one batch dict.

        Args:
            batch: Sample dicts as produced by ``__getitem__``.

        Returns:
            A dict holding, for whichever parts were present in the samples:
            ``"masks"`` (bool, ``True`` on real ingredient positions),
            ``"ingredients"`` and ``"image_ingredients"`` (zero padded to the
            longest sample), ``"recipe_embeddings"`` and ``"cousines"``
            (``LongTensor``).  An empty batch yields an empty dict.
        """
        cousines = []
        ingredients = []
        masks = []
        image_ingredients = []
        recipe_embeddings = []

        for sample in batch:
            if "cousine" in sample:
                cousines.append(sample["cousine"])
            if "recipe_embedding" in sample:
                recipe_embeddings.append(sample["recipe_embedding"])
            if "ingredients" in sample:
                ingredients.append(sample["ingredients"])
                masks.append(torch.ones(sample["ingredients"].shape[0]))
                image_ingredients.append(sample["image_ingredients"])

        # Decide from what was actually collected instead of re-inspecting the
        # loop variable, which is undefined for an empty batch.
        outs = {}
        if ingredients:
            outs["masks"] = pad_sequence(
                masks, batch_first=True, padding_value=0
            ).type(torch.bool)
            outs["ingredients"] = pad_sequence(
                ingredients, batch_first=True, padding_value=0
            )
            outs["image_ingredients"] = pad_sequence(
                image_ingredients, batch_first=True, padding_value=0
            )
        if recipe_embeddings:
            # Joined along the existing first axis; presumably each embedding
            # already carries a leading axis of size 1 - TODO confirm.
            outs["recipe_embeddings"] = torch.cat(recipe_embeddings, dim=0)
        if cousines:
            outs["cousines"] = torch.LongTensor(cousines)
        return outs


def dict_to_device(data: dict, device, return_new_dict=False):
    """Move every value of ``data`` to ``device`` via its ``.to`` method.

    Args:
        data: Dict whose values all expose ``.to(device)``.
        device: Target passed unchanged to each ``.to`` call.
        return_new_dict: When true, leave ``data`` untouched and return a new
            dict; otherwise overwrite the values of ``data`` in place.

    Returns:
        The new dict when ``return_new_dict`` is true, else ``data`` itself.
    """
    if return_new_dict:
        return {key: value.to(device) for key, value in data.items()}
    # Only existing keys are reassigned, so the dict size never changes
    # during iteration.
    for key, value in data.items():
        data[key] = value.to(device)
    return data