BSc project of Parham Saremi. The goal of the project was to detect the geographical region of a dish using textual and visual features extracted from its recipe and ingredients.

network.py

from torch.nn import Module
from torch import nn
import torch
import torch.nn.functional as F
from einops import reduce
from gensim.models import FastText
import numpy as np
import json
import pickle

epsilon = 1e-8
VECTORIZER_SIZE = 1500
class EmbedderFasttext():
    """Wraps a trained FastText model and returns one averaged vector per ingredient."""

    def __init__(self, path):
        self.model = FastText.load(path)
        print(f'FastText Embedding Loaded:\n\t Embedding Size = {self.model.wv.vector_size}\n\t Vocabulary Size = {self.model.wv.vectors.shape[0]}')

    def has(self, word):
        # FastText builds vectors from character n-grams, so any non-empty word can be embedded.
        if word == "":
            return False
        return True

    def get(self, word):
        # Multi-word ingredients are underscore-joined (e.g. 'olive_oil'); average their vectors.
        words = word.split('_')
        out = np.zeros(self.model.wv.vector_size)
        n = len(words)
        if n == 0:
            raise ValueError('Empty string was given.')
        for item in words:
            out += self.model.wv.get_vector(item) / n
        return list(out)
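
# Usage sketch (the path and ingredient name below are hypothetical, not part of the project):
#   embedder = EmbedderFasttext('path/to/fasttext.model')
#   if embedder.has('olive_oil'):
#       vec = embedder.get('olive_oil')   # average of the 'olive' and 'oil' vectors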
class Transformer(Module):
    """Transformer encoder over a sequence of ingredient features; optionally aggregates
    the non-padded positions into a single recipe-level feature vector."""

    def __init__(self, input_size, nhead, num_layers, dim_feedforward, num_classes, aggregate=True):
        super(Transformer, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, dim_feedforward=dim_feedforward, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.aggregate = aggregate
        if self.aggregate:
            self.linear = nn.Linear(input_size, num_classes, True)

    def forward(self, x, padding_mask):
        # `padding_mask` is True at padded positions (the convention of src_key_padding_mask).
        out = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        if self.aggregate:
            # Zero out padded positions, sum over the sequence, then project to num_classes.
            out = (out * ~padding_mask.unsqueeze(-1)).sum(dim=1)
            out = self.linear(torch.relu(out))
        return out
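
# Shape sketch (illustrative sizes only): a batch of 4 recipes with 12 ingredient slots
# and 228-dimensional per-ingredient features.
#   t = Transformer(input_size=228, nhead=4, num_layers=2, dim_feedforward=512,
#                   num_classes=256, aggregate=True)
#   x = torch.randn(4, 12, 228)
#   pad = torch.zeros(4, 12, dtype=torch.bool)   # True would mark padded positions
#   t(x, pad).shape                              # -> torch.Size([4, 256])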
class ImageTextTransformer(Module):
    """Classifies a recipe's geographical region from ingredient text embeddings,
    per-ingredient image features, and/or a recipe-level text embedding."""

    def __init__(self, config):
        super(ImageTextTransformer, self).__init__()
        self.embedding_size = config.data.embedding_size
        self.custom_embed = False
        self.layers = config.model.ingredient_feature_extractor.layers
        if "G" in config.model.ingredient_feature_extractor.layers:
            assert False, "No GNN for this model"

        # Flags selecting which modalities feed the classifier.
        self.use_recipe_text = config.use_recipe_text
        self.use_text_ingredients = config.use_text_ingredients
        self.use_image_ingredients = config.use_image_ingredients
        if not self.use_recipe_text and not self.use_text_ingredients and not self.use_image_ingredients:
            raise Exception("The model can't work without any features")

        # Transformer stack over the per-ingredient features (text and/or image).
        if self.use_text_ingredients or self.use_image_ingredients:
            transformer_input_feature_size = 0
            if self.use_image_ingredients:
                transformer_input_feature_size += config.model.image_feature_size
            if self.use_text_ingredients:
                transformer_input_feature_size += self.embedding_size
            blocks = [
                Transformer(
                    input_size=transformer_input_feature_size,
                    nhead=config.model.ingredient_feature_extractor.transformer.n_heads,
                    num_layers=config.model.ingredient_feature_extractor.transformer.L,
                    dim_feedforward=config.model.ingredient_feature_extractor.H,
                    # Only the last block aggregates and projects to the final ingredient feature size.
                    num_classes=config.model.ingredient_feature_extractor.final_ingredient_feature_size if i == len(config.model.ingredient_feature_extractor.layers) - 1 else None,
                    aggregate=(i == len(config.model.ingredient_feature_extractor.layers) - 1)
                ) for i, m in enumerate(config.model.ingredient_feature_extractor.layers)
            ]
            self.ingredient_feature_module = nn.ModuleList(blocks)

        # Output sizes of the pretrained backbones used upstream.
        feature_size = {
            'resnet18': 512,
            'resnet50': 2048,
            'resnet101': 2048,
            'efficientnet_b0': 1280,
            'efficientnet_b3': 1536,
            'bert-base-uncased': 768,
        }
        if self.use_image_ingredients:
            self.image_feature_extractor = torch.nn.Linear(feature_size[config.image_model], config.model.image_feature_size)
        if self.use_recipe_text:
            self.text_feature_extractor = torch.nn.Linear(feature_size[config.text_model], config.model.text_feature_size)

        # Final MLP classifier over the concatenated recipe features.
        classifier_input_size = 0
        if self.use_image_ingredients or self.use_text_ingredients:
            classifier_input_size += config.model.ingredient_feature_extractor.final_ingredient_feature_size
        if self.use_recipe_text:
            classifier_input_size += config.model.text_feature_size
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(classifier_input_size, 300),
            torch.nn.ReLU(),
            torch.nn.Linear(300, 300),
            torch.nn.ReLU(),
            torch.nn.Linear(300, config.model.final_classes)
        )

    def forward(self, embeddings, mask, image_ingredients, recipe_embeddings):
        if self.use_recipe_text:
            text_features = self.text_feature_extractor(recipe_embeddings)
        if self.use_image_ingredients:
            image_features = self.image_feature_extractor(image_ingredients)

        if self.use_image_ingredients or self.use_text_ingredients:
            # Build the per-ingredient feature sequence from whichever modalities are enabled.
            if self.use_text_ingredients and self.use_image_ingredients:
                ingredient_features = torch.cat([embeddings, image_features], dim=2)
            elif self.use_text_ingredients:
                ingredient_features = embeddings
            else:
                ingredient_features = image_features

            # `mask` marks real ingredients, so its complement is the padding mask.
            out = ingredient_features
            for i, m in enumerate(self.layers):
                if m == "T":
                    out = self.ingredient_feature_module[i](out, ~mask)
                else:
                    raise Exception("Invalid module")
            aggregated_ingredient_features = out

            if self.use_recipe_text:
                recipe_features = torch.cat([text_features, aggregated_ingredient_features], dim=1)
            else:
                recipe_features = aggregated_ingredient_features
        else:
            recipe_features = text_features

        final_result = self.classifier(torch.nn.functional.relu(recipe_features))
        return final_result

    def freeze_features(self):
        # Put the feature-projection layers into eval mode.
        if self.use_image_ingredients:
            self.image_feature_extractor.eval()
        if self.use_recipe_text:
            self.text_feature_extractor.eval()

    def freeze_function(self):
        self.classifier.eval()
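
A minimal end-to-end sketch of how ImageTextTransformer might be instantiated and called. It assumes the attribute-style config read in __init__ can be emulated with types.SimpleNamespace and uses random tensors in place of the real FastText, ResNet-18, and BERT features; all sizes and names below are illustrative, not taken from the project's actual configs.

import torch
from types import SimpleNamespace as NS
from network import ImageTextTransformer  # assuming the file above is importable as network.py

# Hypothetical config mirroring the attributes accessed in ImageTextTransformer.__init__.
config = NS(
    use_recipe_text=True, use_text_ingredients=True, use_image_ingredients=True,
    image_model='resnet18', text_model='bert-base-uncased',
    data=NS(embedding_size=100),
    model=NS(
        image_feature_size=128, text_feature_size=256, final_classes=10,
        ingredient_feature_extractor=NS(
            layers='T', H=512, final_ingredient_feature_size=256,
            transformer=NS(n_heads=4, L=2),
        ),
    ),
)
model = ImageTextTransformer(config)

batch, n_ingredients = 4, 12
embeddings = torch.randn(batch, n_ingredients, 100)          # FastText ingredient vectors
image_ingredients = torch.randn(batch, n_ingredients, 512)   # per-ingredient ResNet-18 features
recipe_embeddings = torch.randn(batch, 768)                  # BERT embedding of the recipe text
mask = torch.ones(batch, n_ingredients, dtype=torch.bool)    # True marks real (non-padded) ingredients
logits = model(embeddings, mask, image_ingredients, recipe_embeddings)
print(logits.shape)  # torch.Size([4, 10]), one score per region class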