from torch.nn import Module
from torch import nn
import torch
import torch.nn.functional as F
from einops import reduce
from gensim.models import FastText
import numpy as np
import json
import pickle

epsilon = 1e-8
VECTORIZER_SIZE = 1500


class EmbedderFasttext():
    def __init__(self, path):
        self.model = FastText.load(path)
        print(f'FastText Embedding Loaded:\n\t Embedding Size = {self.model.wv.vector_size}\n\t Vocabulary Size = {self.model.wv.vectors.shape[0]}')

    def has(self, word):
        # FastText can build a vector for any non-empty token from character n-grams,
        # so only the empty string is rejected.
        if word == "":
            return False
        return True

    def get(self, word):
        # Multi-word ingredients are joined with '_'; return the average of their word vectors.
        words = word.split('_')
        out = np.zeros(self.model.wv.vector_size)
        n = len(words)
        if word == "":
            raise ValueError('Empty string was given.')
        for item in words:
            out += self.model.wv.get_vector(item) / n
        return list(out)


class Transformer(Module):
    def __init__(self, input_size, nhead, num_layers, dim_feedforward, num_classes, aggregate=True):
        super(Transformer, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, dim_feedforward=dim_feedforward, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.aggregate = aggregate
        if self.aggregate:
            self.linear = nn.Linear(input_size, num_classes, bias=True)

    def forward(self, x, padding_mask):
        # padding_mask follows the src_key_padding_mask convention: True marks padded positions.
        out = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        if self.aggregate:
            # Sum-pool over non-padded positions, then project to the output size.
            out = (out * ~padding_mask.unsqueeze(-1)).sum(dim=1)
            out = self.linear(torch.relu(out))
        return out


class ImageTextTransformer(Module):
    def __init__(self, config):
        super(ImageTextTransformer, self).__init__()
        self.embedding_size = config.data.embedding_size
        self.custom_embed = False
        self.layers = config.model.ingredient_feature_extractor.layers
        if "G" in config.model.ingredient_feature_extractor.layers:
            assert False, "No GNN for this model"

        self.use_recipe_text = config.use_recipe_text
        self.use_text_ingredients = config.use_text_ingredients
        self.use_image_ingredients = config.use_image_ingredients
        if not self.use_recipe_text and not self.use_text_ingredients and not self.use_image_ingredients:
            raise Exception("The model can't work without any features")

        if self.use_text_ingredients or self.use_image_ingredients:
            transformer_input_feature_size = 0
            if self.use_image_ingredients:
                transformer_input_feature_size += config.model.image_feature_size
            if self.use_text_ingredients:
                transformer_input_feature_size += self.embedding_size
            # Only the last block aggregates over the sequence and projects to the final feature size.
            blocks = [
                Transformer(
                    input_size=transformer_input_feature_size,
                    nhead=config.model.ingredient_feature_extractor.transformer.n_heads,
                    num_layers=config.model.ingredient_feature_extractor.transformer.L,
                    dim_feedforward=config.model.ingredient_feature_extractor.H,
                    num_classes=config.model.ingredient_feature_extractor.final_ingredient_feature_size if i == len(config.model.ingredient_feature_extractor.layers) - 1 else None,
                    aggregate=(i == len(config.model.ingredient_feature_extractor.layers) - 1)
                )
                for i, m in enumerate(config.model.ingredient_feature_extractor.layers)
            ]
            self.ingredient_feature_module = nn.ModuleList(blocks)

        feature_size = {
            'resnet18': 512,
            'resnet50': 2048,
            'resnet101': 2048,
            'efficientnet_b0': 1280,
            'efficientnet_b3': 1536,
            'bert-base-uncased': 768,
        }
        if self.use_image_ingredients:
            self.image_feature_extractor = torch.nn.Linear(feature_size[config.image_model], config.model.image_feature_size)
        if self.use_recipe_text:
            self.text_feature_extractor = torch.nn.Linear(feature_size[config.text_model], config.model.text_feature_size)

        classifier_input_size = 0
        if self.use_image_ingredients or self.use_text_ingredients:
            classifier_input_size += config.model.ingredient_feature_extractor.final_ingredient_feature_size
        if self.use_recipe_text:
            classifier_input_size += config.model.text_feature_size

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(classifier_input_size, 300),
            torch.nn.ReLU(),
            torch.nn.Linear(300, 300),
            torch.nn.ReLU(),
            torch.nn.Linear(300, config.model.final_classes)
        )

    def forward(self, embeddings, mask, image_ingredients, recipe_embeddings):
        if self.use_recipe_text:
            text_features = self.text_feature_extractor(recipe_embeddings)
        if self.use_image_ingredients:
            image_features = self.image_feature_extractor(image_ingredients)

        if self.use_image_ingredients or self.use_text_ingredients:
            if self.use_text_ingredients and self.use_image_ingredients:
                ingredient_features = torch.cat([embeddings, image_features], dim=2)
            elif self.use_text_ingredients:
                ingredient_features = embeddings
            else:
                ingredient_features = image_features

            out = ingredient_features
            for i, m in enumerate(self.layers):
                if m == "T":
                    # mask is True for valid ingredients, so it is inverted to get the padding mask.
                    out = self.ingredient_feature_module[i](out, ~mask)
                else:
                    raise Exception("Invalid module")
            aggregated_ingredient_features = out

            if self.use_recipe_text:
                recipe_features = torch.cat([text_features, aggregated_ingredient_features], dim=1)
            else:
                recipe_features = aggregated_ingredient_features
        else:
            recipe_features = text_features

        final_result = self.classifier(torch.nn.functional.relu(recipe_features))
        return final_result

    def freeze_features(self):
        # Put whichever feature extractors exist in eval mode
        # (the original referenced a non-existent self.feature_extractor).
        if self.use_image_ingredients:
            self.image_feature_extractor.eval()
        if self.use_recipe_text:
            self.text_feature_extractor.eval()

    def freeze_function(self):
        self.classifier.eval()
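

# --- Minimal usage sketch (illustrative only) ---
# The sizes below are assumptions chosen for the example, not values taken from
# any project config; they show how the standalone Transformer block consumes a
# padded ingredient sequence and returns one aggregated vector per recipe.
if __name__ == "__main__":
    batch_size, seq_len, feat = 4, 10, 64  # hypothetical shapes
    model = Transformer(input_size=feat, nhead=8, num_layers=2,
                        dim_feedforward=128, num_classes=32, aggregate=True)
    x = torch.randn(batch_size, seq_len, feat)
    # True marks padded positions, matching the src_key_padding_mask convention used in forward().
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    padding_mask[:, 6:] = True  # last 4 positions of every sequence are padding
    out = model(x, padding_mask)
    print(out.shape)  # expected: torch.Size([4, 32])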