Meta Byte Track

reid_model.py 8.2KB

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import os
from torch.nn.modules import CrossMapLRN2d as SpatialCrossMapLRN
# from torch.legacy.nn import SpatialCrossMapLRN as SpatialCrossMapLRNOld
from torch.autograd import Function, Variable
from torch.nn import Module


def clip_boxes(boxes, im_shape):
    """
    Clip boxes to image boundaries.
    """
    boxes = np.asarray(boxes)
    if boxes.shape[0] == 0:
        return boxes
    boxes = np.copy(boxes)
    # x1 >= 0
    boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
    # y1 >= 0
    boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
    # x2 < im_shape[1]
    boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
    # y2 < im_shape[0]
    boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
    return boxes


def load_net(fname, net, prefix='', load_state_dict=False):
    import h5py
    with h5py.File(fname, mode='r') as h5f:
        # Detect checkpoints saved from an nn.DataParallel wrapper
        # (every key prefixed with 'module.').
        h5f_is_module = True
        for k in h5f.keys():
            if not str(k).startswith('module.'):
                h5f_is_module = False
                break
        if prefix == '' and not isinstance(net, nn.DataParallel) and h5f_is_module:
            prefix = 'module.'

        # Copy matching parameters from the checkpoint into the model.
        for k, v in net.state_dict().items():
            k = prefix + k
            if k in h5f:
                param = torch.from_numpy(np.asarray(h5f[k]))
                if v.size() != param.size():
                    print('Inconsistent shape: {}, {}'.format(v.size(), param.size()))
                else:
                    v.copy_(param)
            else:
                print('Warning: no layer: {}'.format(k))

        epoch = h5f.attrs['epoch'] if 'epoch' in h5f.attrs else -1

        if not load_state_dict:
            if 'learning_rates' in h5f.attrs:
                lr = h5f.attrs['learning_rates']
            else:
                lr = h5f.attrs.get('lr', -1)
                lr = np.asarray([lr] if lr > 0 else [], dtype=np.float64)
            return epoch, lr

    # Optionally restore the optimizer state saved alongside the weights.
    state_file = fname + '.optimizer_state.pk'
    if os.path.isfile(state_file):
        with open(state_file, 'rb') as f:
            state_dicts = pickle.load(f)
        if not isinstance(state_dicts, list):
            state_dicts = [state_dicts]
    else:
        state_dicts = None
    return epoch, state_dicts


# Legacy LRN wrappers kept for reference; they relied on torch.legacy.nn,
# which newer PyTorch no longer provides, so CrossMapLRN2d is imported above instead.
# class SpatialCrossMapLRNFunc(Function):
#     def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
#         self.size = size
#         self.alpha = alpha
#         self.beta = beta
#         self.k = k
#
#     def forward(self, input):
#         self.save_for_backward(input)
#         self.lrn = SpatialCrossMapLRNOld(self.size, self.alpha, self.beta, self.k)
#         self.lrn.type(input.type())
#         return self.lrn.forward(input)
#
#     def backward(self, grad_output):
#         input, = self.saved_tensors
#         return self.lrn.backward(input, grad_output)
#
#
# # use this one instead
# class SpatialCrossMapLRN(Module):
#     def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
#         super(SpatialCrossMapLRN, self).__init__()
#         self.size = size
#         self.alpha = alpha
#         self.beta = beta
#         self.k = k
#
#     def forward(self, input):
#         return SpatialCrossMapLRNFunc(self.size, self.alpha, self.beta, self.k)(input)


class Inception(nn.Module):
    def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
        super(Inception, self).__init__()
        # 1x1 conv branch
        self.b1 = nn.Sequential(
            nn.Conv2d(in_planes, n1x1, kernel_size=1),
            nn.ReLU(True),
        )
        # 1x1 conv -> 3x3 conv branch
        self.b2 = nn.Sequential(
            nn.Conv2d(in_planes, n3x3red, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),
            nn.ReLU(True),
        )
        # 1x1 conv -> 5x5 conv branch
        self.b3 = nn.Sequential(
            nn.Conv2d(in_planes, n5x5red, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(n5x5red, n5x5, kernel_size=5, padding=2),
            nn.ReLU(True),
        )
        # 3x3 pool -> 1x1 conv branch
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_planes, pool_planes, kernel_size=1),
            nn.ReLU(True),
        )

    def forward(self, x):
        y1 = self.b1(x)
        y2 = self.b2(x)
        y3 = self.b3(x)
        y4 = self.b4(x)
        return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
    output_channels = 832

    def __init__(self):
        super(GoogLeNet, self).__init__()
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(True),
            nn.MaxPool2d(3, stride=2, ceil_mode=True),
            SpatialCrossMapLRN(5),
            nn.Conv2d(64, 64, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 192, 3, padding=1),
            nn.ReLU(True),
            SpatialCrossMapLRN(5),
            nn.MaxPool2d(3, stride=2, ceil_mode=True),
        )
        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    def forward(self, x):
        out = self.pre_layers(x)
        out = self.a3(out)
        out = self.b3(out)
        out = self.maxpool(out)
        out = self.a4(out)
        out = self.b4(out)
        out = self.c4(out)
        out = self.d4(out)
        out = self.e4(out)
        return out


class Model(nn.Module):
    def __init__(self, n_parts=8):
        super(Model, self).__init__()
        self.n_parts = n_parts

        self.feat_conv = GoogLeNet()
        self.conv_input_feat = nn.Conv2d(self.feat_conv.output_channels, 512, 1)

        # part net: one attention map and one 64-d embedding per part
        self.conv_att = nn.Conv2d(512, self.n_parts, 1)
        for i in range(self.n_parts):
            setattr(self, 'linear_feature{}'.format(i + 1), nn.Linear(512, 64))

    def forward(self, x):
        feature = self.feat_conv(x)
        feature = self.conv_input_feat(feature)

        att_weights = torch.sigmoid(self.conv_att(feature))

        linear_features = []
        for i in range(self.n_parts):
            masked_feature = feature * torch.unsqueeze(att_weights[:, i], 1)
            pooled_feature = F.avg_pool2d(masked_feature, masked_feature.size()[2:4])
            linear_features.append(
                getattr(self, 'linear_feature{}'.format(i + 1))(pooled_feature.view(pooled_feature.size(0), -1))
            )

        concat_features = torch.cat(linear_features, 1)
        # L2-normalize the concatenated part embeddings
        normed_feature = concat_features / torch.clamp(torch.norm(concat_features, 2, 1, keepdim=True), min=1e-6)

        return normed_feature


def load_reid_model(ckpt):
    model = Model(n_parts=8)
    model.inp_size = (80, 160)
    load_net(ckpt, model)
    print('Load ReID model from {}'.format(ckpt))

    model = model.cuda()
    model.eval()
    return model


def im_preprocess(image):
    # Convert to float, subtract the per-channel BGR mean, and move channels first (HWC -> CHW).
    image = np.asarray(image, np.float32)
    image -= np.array([104, 117, 123], dtype=np.float32).reshape(1, 1, -1)
    image = image.transpose((2, 0, 1))
    return image


def extract_image_patches(image, bboxes):
    bboxes = np.round(bboxes).astype(int)
    bboxes = clip_boxes(bboxes, image.shape)
    patches = [image[box[1]:box[3], box[0]:box[2]] for box in bboxes]
    return patches


def extract_reid_features(reid_model, image, tlbrs):
    if len(tlbrs) == 0:
        return torch.FloatTensor()

    # Crop and resize each detection, normalize it, then run one batched forward pass.
    patches = extract_image_patches(image, tlbrs)
    patches = np.asarray([im_preprocess(cv2.resize(p, reid_model.inp_size)) for p in patches], dtype=np.float32)

    with torch.no_grad():
        im_var = Variable(torch.from_numpy(patches))
        im_var = im_var.cuda()
        features = reid_model(im_var).data

    return features
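
For context, a minimal usage sketch (not part of reid_model.py) showing how load_reid_model and extract_reid_features fit together for a single frame. The checkpoint filename, image path, and box coordinates are illustrative placeholders, and a CUDA-capable GPU is assumed because the model is moved to .cuda().

# example usage (illustrative; paths and boxes below are placeholders)
import cv2
import numpy as np

from reid_model import load_reid_model, extract_reid_features

model = load_reid_model('reid_ckpt.h5')        # hypothetical checkpoint path
frame = cv2.imread('frame_000001.jpg')         # BGR image, as im_preprocess expects
tlbrs = np.array([[100., 50., 180., 210.],     # detections as (x1, y1, x2, y2)
                  [300., 40., 370., 200.]])
features = extract_reid_features(model, frame, tlbrs)
print(features.shape)                          # torch.Size([2, 512]): 8 parts x 64-dim, L2-normalized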