BSc project of Parham Saremi. The goal of the project was to detect the geographical region of a food using textual and visual features extracted from its recipe and ingredients.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

extract_image_vector.py 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
import os
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
import warnings

# Silence all library warnings (e.g. torchvision's `pretrained=` deprecation
# notice) so the tqdm progress bars below stay readable.
# NOTE(review): this suppresses *every* warning process-wide — confirm that is
# intentional before reusing this script elsewhere.
warnings.filterwarnings("ignore")
  10. models_names = [
  11. 'efficientnet_t0',
  12. 'resnet50',
  13. 'resnet101',
  14. 'efficientnet_b0',
  15. 'efficientnet_b3'
  16. ]
  17. input_dir = '/home/dml/food/CuisineAdaptation/crawled-images-full-384'
  18. output_root_dir = 'image-features-full'
  19. image_size = {
  20. 'resnet18': 224,
  21. 'resnet50': 224,
  22. 'resnet101': 224,
  23. 'efficientnet_b0': 224,
  24. 'efficientnet_t0': 224,
  25. 'efficientnet_b3': 300
  26. }
  27. normalize = {
  28. 'resnet18': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  29. 'resnet50': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  30. 'resnet101': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  31. 'efficientnet_t0': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  32. 'efficientnet_b3': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  33. }
  34. transform = {
  35. 'resnet18': transforms.Compose([
  36. transforms.Resize(image_size['resnet18']),
  37. transforms.CenterCrop(image_size['resnet18']),
  38. transforms.ToTensor(),
  39. normalize['resnet18']
  40. ]),
  41. 'resnet50': transforms.Compose([
  42. transforms.Resize(image_size['resnet50']),
  43. transforms.CenterCrop(image_size['resnet50']),
  44. transforms.ToTensor(),
  45. normalize['resnet50']
  46. ]),
  47. 'resnet101': transforms.Compose([
  48. transforms.Resize(image_size['resnet101']),
  49. transforms.CenterCrop(image_size['resnet101']),
  50. transforms.ToTensor(),
  51. normalize['resnet101']
  52. ]),
  53. 'efficientnet_t0': transforms.Compose([
  54. transforms.Resize(image_size['efficientnet_t0']),
  55. transforms.CenterCrop(image_size['efficientnet_t0']),
  56. transforms.ToTensor(),
  57. normalize['efficientnet_t0']
  58. ]),
  59. 'efficientnet_b3': transforms.Compose([
  60. transforms.Resize(image_size['efficientnet_b3']),
  61. transforms.CenterCrop(image_size['efficientnet_b3']),
  62. transforms.ToTensor(),
  63. normalize['efficientnet_b3']
  64. ])
  65. }
  66. device = torch.device("cuda")
  67. counter = 0
  68. for model_name in models_names:
  69. if 'resnet' in model_name:
  70. model = getattr(models, model_name)(pretrained=True)
  71. num_features = model.fc.in_features
  72. model.fc = torch.nn.Identity()
  73. elif 'efficientnet' in model_name:
  74. model = getattr(models, model_name)(pretrained=True)
  75. num_features = model.classifier[1].in_features
  76. model.classifier = torch.nn.Identity()
  77. else:
  78. print('Unknown model name: {}'.format(model_name))
  79. continue
  80. num_classes = num_features
  81. model = model.eval().to(device)
  82. output_dir = os.path.join(output_root_dir, model_name)
  83. os.makedirs(output_dir, exist_ok=True)
  84. for folder_name in tqdm(os.listdir(input_dir)):
  85. folder_dir = os.path.join(input_dir, folder_name)
  86. if not os.path.isdir(folder_dir):
  87. continue
  88. image_tensors = []
  89. for image_filename in os.listdir(folder_dir):
  90. if not image_filename.lower().endswith(".png") and not image_filename.lower().endswith(".jpg"):
  91. continue
  92. counter += 1
  93. image_path = os.path.join(folder_dir, image_filename)
  94. image = Image.open(image_path).convert('RGB')
  95. image_tensor = transform[model_name](image).unsqueeze(0).to(device)
  96. image_tensors.append(image_tensor)
  97. if len(image_tensors) > 0:
  98. input_tensors = torch.cat(image_tensors)
  99. with torch.no_grad():
  100. avg_features = model(input_tensors).squeeze(0).mean(dim=0).cpu().numpy()
  101. else:
  102. avg_features = np.zeros(num_features)
  103. output_filename = '{}.npy'.format(folder_name)
  104. output_path = os.path.join(output_dir, output_filename)
  105. np.save(output_path, avg_features)