- import os
- import numpy as np
- import torch
- import torchvision.models as models
- import torchvision.transforms as transforms
- from PIL import Image
- from tqdm import tqdm
-
-
# Silence all warnings globally (e.g. torchvision's 'pretrained' deprecation
# notice). NOTE(review): this is very broad — it also hides genuine problems;
# consider filtering only the specific categories you expect.
import warnings
warnings.filterwarnings("ignore")
-
# Backbone names to run feature extraction with; each must be an attribute
# of torchvision.models.
# NOTE(review): 'efficientnet_t0' is not a torchvision.models attribute
# (torchvision ships efficientnet_b0..b7), so getattr on it will raise
# AttributeError in the main loop — confirm the intended name.
# NOTE(review): 'resnet18' has entries in image_size/normalize/transform
# below but is absent here, so it is never run — confirm whether that is
# deliberate.
models_names = [
    'efficientnet_t0',
    'resnet50',
    'resnet101',
    'efficientnet_b0',
    'efficientnet_b3'
]

# Root of the crawled images: one sub-folder per item, images inside.
input_dir = '/home/dml/food/CuisineAdaptation/crawled-images-full-384'
# One averaged feature vector (.npy) is written per input sub-folder,
# under output_root_dir/<model_name>/.
output_root_dir = 'image-features-full'

# Square input resolution (resize + center-crop) expected by each backbone.
image_size = {
    'resnet18': 224,
    'resnet50': 224,
    'resnet101': 224,
    'efficientnet_b0': 224,
    'efficientnet_t0': 224,
    'efficientnet_b3': 300
}
-
# Per-model preprocessing pipelines.
#
# Every backbone here uses the standard ImageNet normalization statistics,
# so both dicts are derived from image_size instead of repeating the same
# literals per model.
#
# Bug fixed: the hand-written dicts omitted the 'efficientnet_b0' entry even
# though it appears in models_names, so transform['efficientnet_b0'] raised
# KeyError at runtime. Deriving the dicts from image_size guarantees that
# every model with a configured size also has a Normalize and a Compose
# pipeline (a strict superset of the original keys).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# model name -> Normalize with ImageNet statistics
normalize = {
    name: transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
    for name in image_size
}

# model name -> full preprocessing: resize, center-crop to the model's
# expected square input, convert to tensor, normalize.
transform = {
    name: transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize[name],
    ])
    for name, size in image_size.items()
}
-
# Use the GPU when present; the original hard-coded "cuda", which crashes
# on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
counter = 0  # total images processed across all models (diagnostic only)

for model_name in models_names:
    # Guard against names that are not torchvision models (e.g. the
    # configured 'efficientnet_t0'): the original unguarded
    # getattr(models, ...) raised AttributeError and aborted the whole run.
    if not hasattr(models, model_name):
        print('Model not found in torchvision.models: {}'.format(model_name))
        continue

    # Load the pretrained backbone and replace its classification head with
    # Identity so the forward pass returns pooled feature vectors.
    if 'resnet' in model_name:
        model = getattr(models, model_name)(pretrained=True)
        num_features = model.fc.in_features
        model.fc = torch.nn.Identity()
    elif 'efficientnet' in model_name:
        model = getattr(models, model_name)(pretrained=True)
        num_features = model.classifier[1].in_features
        model.classifier = torch.nn.Identity()
    else:
        print('Unknown model name: {}'.format(model_name))
        continue

    # Skip models with no preprocessing pipeline instead of raising KeyError
    # deep inside the per-image loop (as the original did for any
    # models_names entry missing from transform).
    model_transform = transform.get(model_name)
    if model_transform is None:
        print('No transform configured for model: {}'.format(model_name))
        continue

    model = model.eval().to(device)
    output_dir = os.path.join(output_root_dir, model_name)
    os.makedirs(output_dir, exist_ok=True)

    # One averaged feature vector per image sub-folder.
    for folder_name in tqdm(os.listdir(input_dir)):
        folder_dir = os.path.join(input_dir, folder_name)
        if not os.path.isdir(folder_dir):
            continue

        image_tensors = []
        for image_filename in os.listdir(folder_dir):
            # Only .png / .jpg (case-insensitive) are treated as images.
            if not image_filename.lower().endswith(('.png', '.jpg')):
                continue
            counter += 1
            image_path = os.path.join(folder_dir, image_filename)

            image = Image.open(image_path).convert('RGB')
            image_tensors.append(model_transform(image).unsqueeze(0).to(device))

        if image_tensors:
            input_tensors = torch.cat(image_tensors)
            with torch.no_grad():
                # Average the (batch, num_features) outputs over the batch.
                # Bug fixed: the original `.squeeze(0).mean(dim=0)` collapsed
                # a single-image folder's (1, num_features) output to a
                # scalar (the squeeze removed the batch dim, so the mean ran
                # over the feature dim) instead of a num_features vector.
                avg_features = model(input_tensors).mean(dim=0).cpu().numpy()
        else:
            # Empty folder: emit a zero vector; float32 matches the dtype
            # the model path produces, so downstream loading is uniform.
            avg_features = np.zeros(num_features, dtype=np.float32)

        output_filename = '{}.npy'.format(folder_name)
        output_path = os.path.join(output_dir, output_filename)
        np.save(output_path, avg_features)