import json
import pickle

import numpy as np
import torch
import torch.nn.functional as F
from einops import reduce
from gensim.models import FastText
from torch import nn
from torch.nn import Module

epsilon = 1e-8
VECTORIZER_SIZE = 1500

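# Wraps a trained gensim FastText model behind a minimal has()/get()
# interface. Multi-token words are expected to be joined with underscores
# (e.g. "olive_oil"); get() returns the mean of the per-token vectors.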
class EmbedderFasttext:
    def __init__(self, path):
        self.model = FastText.load(path)
        print(
            f'FastText Embedding Loaded:\n'
            f'\t Embedding Size = {self.model.wv.vector_size}\n'
            f'\t Vocabulary Size = {self.model.wv.vectors.shape[0]}'
        )

    def has(self, word):
        # FastText composes vectors from character n-grams, so any non-empty
        # word can be embedded, even out-of-vocabulary ones.
        return word != ""

    def get(self, word):
        # Average the vectors of the underscore-separated parts of the word.
        words = [w for w in word.split('_') if w]
        if not words:
            raise ValueError('Empty string was given.')
        out = np.zeros(self.model.wv.vector_size)
        for item in words:
            out += self.model.wv.get_vector(item) / len(words)
        return list(out)
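
# Usage sketch (the model path is hypothetical):
#   embedder = EmbedderFasttext('models/fasttext.model')
#   if embedder.has('olive_oil'):
#       vec = embedder.get('olive_oil')  # mean of the 'olive' and 'oil' vectors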

class Transformer(Module):
    """Thin wrapper around nn.TransformerEncoder. With aggregate=True, padded
    positions are zeroed out, the remaining token features are sum-pooled,
    and a linear head maps the result to num_classes logits."""

    def __init__(self, input_size, nhead, num_layers, dim_feedforward, num_classes, aggregate=True):
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_size,
            dim_feedforward=dim_feedforward,
            nhead=nhead,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.aggregate = aggregate
        if self.aggregate:
            self.linear = nn.Linear(input_size, num_classes, bias=True)

    def forward(self, x, padding_mask):
        # padding_mask follows the src_key_padding_mask convention:
        # True marks a padded (ignored) position.
        out = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        if self.aggregate:
            # Zero out padded positions, sum-pool over the sequence, classify.
            out = (out * ~padding_mask.unsqueeze(-1)).sum(dim=1)
            out = self.linear(torch.relu(out))
        return out
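
# Shape sketch (hypothetical sizes): a batch of 2 sequences of 5 ingredient
# features of size 16, where the last 2 positions of the second sequence are
# padding:
#   block = Transformer(input_size=16, nhead=4, num_layers=2,
#                       dim_feedforward=64, num_classes=10)
#   x = torch.randn(2, 5, 16)
#   pad = torch.tensor([[False] * 5, [False] * 3 + [True] * 2])
#   logits = block(x, pad)  # shape (2, 10)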
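
# Multimodal recipe classifier: per-ingredient text embeddings and/or
# per-ingredient image features are fused by a stack of Transformer blocks,
# optionally concatenated with a whole-recipe text feature, and classified
# by a small MLP head.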
class ImageTextTransformer(Module):
    def __init__(self, config):
        super().__init__()
        self.embedding_size = config.data.embedding_size
        self.custom_embed = False
        self.layers = config.model.ingredient_feature_extractor.layers

        if "G" in self.layers:
            raise ValueError("This model does not support GNN ('G') layers")

        self.use_recipe_text = config.use_recipe_text
        self.use_text_ingredients = config.use_text_ingredients
        self.use_image_ingredients = config.use_image_ingredients

        if not (self.use_recipe_text or self.use_text_ingredients or self.use_image_ingredients):
            raise ValueError("The model can't work without any features")

        if self.use_text_ingredients or self.use_image_ingredients:
            # The ingredient transformer consumes the concatenation of
            # whichever per-ingredient features are enabled.
            transformer_input_feature_size = 0
            if self.use_image_ingredients:
                transformer_input_feature_size += config.model.image_feature_size
            if self.use_text_ingredients:
                transformer_input_feature_size += self.embedding_size

            # One Transformer block per entry in `layers`; only the last block
            # aggregates the sequence and projects to the final feature size.
            n_blocks = len(self.layers)
            blocks = [
                Transformer(
                    input_size=transformer_input_feature_size,
                    nhead=config.model.ingredient_feature_extractor.transformer.n_heads,
                    num_layers=config.model.ingredient_feature_extractor.transformer.L,
                    dim_feedforward=config.model.ingredient_feature_extractor.H,
                    num_classes=config.model.ingredient_feature_extractor.final_ingredient_feature_size
                    if i == n_blocks - 1 else None,
                    aggregate=(i == n_blocks - 1),
                )
                for i in range(n_blocks)
            ]
            self.ingredient_feature_module = nn.ModuleList(blocks)

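        # Output feature dimension of each supported image/text backbone,
        # keyed by model name.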
        feature_size = {
            'resnet18': 512,
            'resnet50': 2048,
            'resnet101': 2048,
            'efficientnet_b0': 1280,
            'efficientnet_b3': 1536,
            'bert-base-uncased': 768,
        }

        if self.use_image_ingredients:
            self.image_feature_extractor = nn.Linear(feature_size[config.image_model], config.model.image_feature_size)

        if self.use_recipe_text:
            self.text_feature_extractor = nn.Linear(feature_size[config.text_model], config.model.text_feature_size)

        # The classifier head sees the aggregated ingredient features and/or
        # the recipe-level text features, depending on the configuration.
        classifier_input_size = 0
        if self.use_image_ingredients or self.use_text_ingredients:
            classifier_input_size += config.model.ingredient_feature_extractor.final_ingredient_feature_size
        if self.use_recipe_text:
            classifier_input_size += config.model.text_feature_size

        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_size, 300),
            nn.ReLU(),
            nn.Linear(300, 300),
            nn.ReLU(),
            nn.Linear(300, config.model.final_classes),
        )
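
    # Expected input shapes (inferred from the code; B = batch size,
    # N = max ingredients per recipe):
    #   embeddings:        (B, N, embedding_size)  per-ingredient text embeddings
    #   mask:              (B, N) bool, True = real ingredient, False = padding
    #   image_ingredients: (B, N, image backbone feature size)
    #   recipe_embeddings: (B, text backbone feature size)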
    def forward(self, embeddings, mask, image_ingredients, recipe_embeddings):
        if self.use_recipe_text:
            text_features = self.text_feature_extractor(recipe_embeddings)

        if self.use_image_ingredients:
            image_features = self.image_feature_extractor(image_ingredients)

        if self.use_image_ingredients or self.use_text_ingredients:
            # Concatenate whichever per-ingredient features are available.
            if self.use_text_ingredients and self.use_image_ingredients:
                ingredient_features = torch.cat([embeddings, image_features], dim=2)
            elif self.use_text_ingredients:
                ingredient_features = embeddings
            else:
                ingredient_features = image_features
            out = ingredient_features

            # `mask` marks real ingredients, so its negation is the
            # src_key_padding_mask expected by the Transformer blocks.
            for i, m in enumerate(self.layers):
                if m == "T":
                    out = self.ingredient_feature_module[i](out, ~mask)
                else:
                    raise ValueError(f"Invalid module: {m}")

            aggregated_ingredient_features = out

            if self.use_recipe_text:
                recipe_features = torch.cat([text_features, aggregated_ingredient_features], dim=1)
            else:
                recipe_features = aggregated_ingredient_features
        else:
            recipe_features = text_features

        return self.classifier(F.relu(recipe_features))

    def freeze_features(self):
        # Switch the feature extractors to eval mode (eval() fixes dropout and
        # batch-norm statistics but does not stop gradient updates).
        if self.use_image_ingredients:
            self.image_feature_extractor.eval()
        if self.use_recipe_text:
            self.text_feature_extractor.eval()

    def freeze_function(self):
        self.classifier.eval()
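
# Hypothetical usage sketch (all sizes illustrative; `config` must provide the
# fields read in __init__):
#   model = ImageTextTransformer(config)
#   B, N = 8, 20                                  # batch size, max ingredients
#   embeddings = torch.randn(B, N, config.data.embedding_size)
#   mask = torch.ones(B, N, dtype=torch.bool)     # True = real ingredient
#   image_ingredients = torch.randn(B, N, 512)    # e.g. resnet18 features
#   recipe_embeddings = torch.randn(B, 768)       # e.g. bert-base-uncased features
#   logits = model(embeddings, mask, image_ingredients, recipe_embeddings)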