Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

reid_model.py 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import numpy as np
  5. import cv2
  6. import logging
  7. import torchvision.transforms as transforms
  class BasicBlock(nn.Module):
      """Residual block: two 3x3 conv/BN layers plus an identity or 1x1 shortcut.

      Args:
          c_in: channels of the input tensor.
          c_out: channels produced by the block.
          is_downsample: if truthy, the first 3x3 conv and the shortcut
              projection use stride 2 (halving height and width).

      A 1x1 conv + BN projection (``self.downsample``) is built whenever the
      shortcut would otherwise not match the main branch: on stride-2 blocks,
      or when ``c_in != c_out``. In the channel-only case
      ``self.is_downsample`` is forced to True because ``forward`` uses that
      flag to decide whether to project the shortcut.
      """

      def __init__(self, c_in, c_out, is_downsample=False):
          super(BasicBlock, self).__init__()
          self.is_downsample = is_downsample
          stride = 2 if is_downsample else 1

          # Main branch. Submodule names and creation order are significant:
          # they fix the state_dict keys and, under a fixed seed, the order in
          # which weights are randomly initialized.
          self.conv1 = nn.Conv2d(
              c_in, c_out, 3, stride=stride, padding=1, bias=False)
          self.bn1 = nn.BatchNorm2d(c_out)
          self.relu = nn.ReLU(True)
          self.conv2 = nn.Conv2d(
              c_out, c_out, 3, stride=1, padding=1, bias=False)
          self.bn2 = nn.BatchNorm2d(c_out)

          # Shortcut projection, needed when spatial size or channels change.
          if is_downsample or c_in != c_out:
              self.downsample = nn.Sequential(
                  nn.Conv2d(c_in, c_out, 1, stride=stride, bias=False),
                  nn.BatchNorm2d(c_out),
              )
              if not is_downsample:
                  self.is_downsample = True

      def forward(self, x):
          """Return ``relu(shortcut(x) + bn2(conv2(relu(bn1(conv1(x))))))``."""
          main = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
          shortcut = self.downsample(x) if self.is_downsample else x
          return F.relu(shortcut.add(main), True)
  def make_layers(c_in, c_out, repeat_times, is_downsample=False):
      """Stack ``repeat_times`` BasicBlocks into one ``nn.Sequential`` stage.

      Only the first block maps ``c_in`` -> ``c_out`` and receives
      ``is_downsample``; every later block is a plain ``c_out`` -> ``c_out``
      block with default arguments. ``repeat_times <= 0`` yields an empty
      ``nn.Sequential``.
      """
      stage = [
          BasicBlock(c_in, c_out, is_downsample=is_downsample)
          if index == 0
          else BasicBlock(c_out, c_out)
          for index in range(repeat_times)
      ]
      return nn.Sequential(*stage)
  class Net(nn.Module):
      """Residual CNN that maps person crops to embeddings or class logits.

      Args:
          num_classes: size of the final ``Linear`` in ``self.classifier``;
              the head is only used when ``reid`` is False. The default 751
              presumably matches the Market-1501 training identities -- confirm
              against the checkpoint.
          reid: if True, ``forward`` returns L2-normalized 512-d feature
              vectors and skips the classifier.

      The fixed ``AvgPool2d((8, 4), 1)`` assumes a 3 x 128 x 64 (C x H x W)
      input, which ``layer4`` reduces to 512 x 8 x 4. Inputs whose ``layer4``
      output is smaller than 8 x 4 fail in the pooling layer; larger ones
      flatten to more than 512 values (longer reid vectors, and a shape error
      in the classifier's ``Linear(512, 256)``).
      """

      def __init__(self, num_classes=751, reid=False):
          super(Net, self).__init__()
          # Shapes below are C x H x W for a 3 x 128 x 64 input.
          # Stem: conv -> 64 x 128 x 64, maxpool -> 64 x 64 x 32.
          self.conv = nn.Sequential(
              nn.Conv2d(3, 64, 3, stride=1, padding=1),
              nn.BatchNorm2d(64),
              nn.ReLU(inplace=True),
              nn.MaxPool2d(3, 2, padding=1),
          )
          # -> 64 x 64 x 32
          self.layer1 = make_layers(64, 64, 2, False)
          # -> 128 x 32 x 16
          self.layer2 = make_layers(64, 128, 2, True)
          # -> 256 x 16 x 8
          self.layer3 = make_layers(128, 256, 2, True)
          # -> 512 x 8 x 4
          self.layer4 = make_layers(256, 512, 2, True)
          # -> 512 x 1 x 1
          self.avgpool = nn.AvgPool2d((8, 4), 1)
          self.reid = reid
          self.classifier = nn.Sequential(
              nn.Linear(512, 256),
              nn.BatchNorm1d(256),
              nn.ReLU(inplace=True),
              nn.Dropout(),
              nn.Linear(256, num_classes),
          )

      def forward(self, x):
          """Return (B, 512) unit-norm embeddings if ``self.reid``, else logits."""
          x = self.conv(x)
          x = self.layer1(x)
          x = self.layer2(x)
          x = self.layer3(x)
          x = self.layer4(x)
          x = self.avgpool(x)
          x = x.view(x.size(0), -1)  # B x 512
          if self.reid:
              # F.normalize clamps the norm at eps (1e-12), so an all-zero
              # feature vector (possible after the blocks' final ReLU) yields
              # zeros instead of the NaNs produced by dividing by a zero norm.
              # For any non-zero vector the result equals x / ||x||_2.
              return F.normalize(x, p=2, dim=1)
          return self.classifier(x)
  class Extractor(object):
      """Wrap a reid-mode ``Net`` to turn image crops into embedding vectors.

      Args:
          model_path: path to a checkpoint readable by ``torch.load`` whose
              ``'net_dict'`` entry is a ``Net`` state dict.
          use_cuda: run on CUDA when it is available; otherwise use the CPU.
      """

      def __init__(self, model_path, use_cuda=True):
          self.net = Net(reid=True)
          self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
          # NOTE(review): torch.load unpickles the file, which can execute
          # arbitrary code -- only load checkpoints from trusted sources.
          state_dict = torch.load(model_path, map_location=torch.device(self.device))[
              'net_dict']
          self.net.load_state_dict(state_dict)
          logger = logging.getLogger("root.tracker")
          logger.info("Loading weights from {}... Done!".format(model_path))
          self.net.to(self.device)
          # Inference only: BatchNorm must use its stored running statistics
          # and must not overwrite them. In training mode (the nn.Module
          # default) each crop's embedding would depend on the other crops in
          # the same batch, and every call would mutate the running stats.
          self.net.eval()
          # (width, height), the order cv2.resize expects -> crops of 128 x 64 (H x W).
          self.size = (64, 128)
          self.norm = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
          ])

      def _preprocess(self, im_crops):
          """Convert a sequence of image crops into one normalized batch tensor.

          Steps:
              1. convert each crop to float32 and scale from [0, 255] to [0, 1];
              2. resize to 64 x 128 (width x height), as the Market-1501 crops are;
              3. convert to a CHW torch tensor and normalize with the mean/std
                 in ``self.norm``;
              4. concatenate into a single (N, C, 128, 64) batch.

          Assumes ``im_crops`` is a non-empty sequence of HxWx3 numpy arrays with
          values in [0, 255] (TODO confirm channel order: the mean/std above are
          listed in the usual RGB order). An empty sequence makes ``torch.cat``
          raise.
          """
          def _resize(im, size):
              return cv2.resize(im.astype(np.float32) / 255., size)

          im_batch = torch.cat(
              [self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops],
              dim=0,
          ).float()
          return im_batch

      def __call__(self, im_crops):
          """Return a numpy array with one embedding row per crop in ``im_crops``."""
          im_batch = self._preprocess(im_crops)
          with torch.no_grad():
              im_batch = im_batch.to(self.device)
              features = self.net(im_batch)
          return features.cpu().numpy()