
modules.py

import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange, repeat
import math


""" Local Context Attention Module"""


def scaled_dot_product(q, k, v, mask=None):
    # Standard scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
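
# Example usage (sketch; the shapes below are assumptions, not values from this repo):
#   q = k = v = torch.randn(2, 16, 64)
#   values, attention = scaled_dot_product(q, k, v)
#   # values: (2, 16, 64), attention: (2, 16, 16)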

class LCA(nn.Module):
    def __init__(self):
        super(LCA, self).__init__()

    def forward(self, x, pred):
        # NOTE: as written, this forward is an identity mapping; `pred` is unused
        # and the input feature map is returned unchanged.
        residual = x
        out = residual
        return out
  23. """ Global Context Module"""
  24. class GCM(nn.Module):
  25. def __init__(self, in_channels, out_channels):
  26. super(GCM, self).__init__()
  27. pool_size = [1, 3, 5]
  28. out_channel_list = [256, 128, 64, 64]
  29. upsampe_scale = [2, 4, 8, 16]
  30. GClist = []
  31. GCoutlist = []
  32. for ps in pool_size:
  33. GClist.append(nn.Sequential(
  34. nn.AdaptiveAvgPool2d(ps),
  35. nn.Conv2d(in_channels, out_channels, 1, 1),
  36. nn.ReLU(inplace=True)))
  37. GClist.append(nn.Sequential(
  38. nn.Conv2d(in_channels, out_channels, 1, 1),
  39. nn.ReLU(inplace=True),
  40. NonLocalBlock(out_channels)))
  41. self.GCmodule = nn.ModuleList(GClist)
  42. for i in range(4):
  43. GCoutlist.append(nn.Sequential(nn.Conv2d(out_channels * 4, out_channel_list[i], 3, 1, 1),
  44. nn.ReLU(inplace=True),
  45. nn.Upsample(scale_factor=upsampe_scale[i], mode='bilinear')))
  46. self.GCoutmodel = nn.ModuleList(GCoutlist)
  47. def forward(self, x):
  48. xsize = x.size()[2:]
  49. global_context = []
  50. for i in range(len(self.GCmodule) - 1):
  51. global_context.append(F.interpolate(self.GCmodule[i](x), xsize, mode='bilinear', align_corners=True))
  52. global_context.append(self.GCmodule[-1](x))
  53. global_context = torch.cat(global_context, dim=1)
  54. output = []
  55. for i in range(len(self.GCoutmodel)):
  56. output.append(self.GCoutmodel[i](global_context))
  57. return output
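
# Example usage (sketch; the channel counts below are assumptions, not values from
# the original configuration):
#   gcm = GCM(in_channels=512, out_channels=64)
#   outs = gcm(torch.randn(1, 512, 16, 16))
#   # outs is a list of 4 tensors with shapes (1, 256, 32, 32), (1, 128, 64, 64),
#   # (1, 64, 128, 128) and (1, 64, 256, 256).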
  58. """ Adaptive Selection Module"""
  59. class ASM(nn.Module):
  60. def __init__(self, in_channels, all_channels):
  61. super(ASM, self).__init__()
  62. self.non_local = NonLocalBlock(in_channels)
  63. def forward(self, lc, fuse, gc):
  64. fuse = self.non_local(fuse)
  65. fuse = torch.cat([lc, fuse, gc], dim=1)
  66. return fuse
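
# Example usage (sketch; shapes are assumptions): with lc, fuse and gc all of shape
# (B, 64, H, W),
#   asm = ASM(in_channels=64, all_channels=192)
#   out = asm(lc, fuse, gc)   # (B, 192, H, W)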
  67. """
  68. Squeeze and Excitation Layer
  69. https://arxiv.org/abs/1709.01507
  70. """
  71. class SELayer(nn.Module):
  72. def __init__(self, channel, reduction=16):
  73. super(SELayer, self).__init__()
  74. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  75. self.fc = nn.Sequential(
  76. nn.Linear(channel, channel // reduction, bias=False),
  77. nn.ReLU(inplace=True),
  78. nn.Linear(channel // reduction, channel, bias=False),
  79. nn.Sigmoid()
  80. )
  81. def forward(self, x):
  82. b, c, _, _ = x.size()
  83. y = self.avg_pool(x).view(b, c)
  84. y = self.fc(y).view(b, c, 1, 1)
  85. return x * y.expand_as(x)
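
# Example usage (sketch): channel-wise recalibration of a feature map.
#   se = SELayer(channel=256, reduction=16)
#   out = se(torch.randn(2, 256, 32, 32))   # same shape as the input, (2, 256, 32, 32)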
  86. """
  87. Non Local Block
  88. https://arxiv.org/abs/1711.07971
  89. """
  90. class NonLocalBlock(nn.Module):
  91. def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
  92. super(NonLocalBlock, self).__init__()
  93. self.sub_sample = sub_sample
  94. self.in_channels = in_channels
  95. self.inter_channels = inter_channels
  96. if self.inter_channels is None:
  97. self.inter_channels = in_channels // 2
  98. if self.inter_channels == 0:
  99. self.inter_channels = 1
  100. self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
  101. kernel_size=1, stride=1, padding=0)
  102. if bn_layer:
  103. self.W = nn.Sequential(
  104. nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
  105. kernel_size=1, stride=1, padding=0),
  106. nn.BatchNorm2d(self.in_channels)
  107. )
  108. nn.init.constant_(self.W[1].weight, 0)
  109. nn.init.constant_(self.W[1].bias, 0)
  110. else:
  111. self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
  112. kernel_size=1, stride=1, padding=0)
  113. nn.init.constant_(self.W.weight, 0)
  114. nn.init.constant_(self.W.bias, 0)
  115. self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
  116. kernel_size=1, stride=1, padding=0)
  117. self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
  118. kernel_size=1, stride=1, padding=0)
  119. if sub_sample:
  120. self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size=(2, 2)))
  121. self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size=(2, 2)))
  122. def forward(self, x):
  123. batch_size = x.size(0)
  124. g_x = self.g(x).view(batch_size, self.inter_channels, -1)
  125. g_x = g_x.permute(0, 2, 1)
  126. theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
  127. theta_x = theta_x.permute(0, 2, 1)
  128. phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
  129. f = torch.matmul(theta_x, phi_x)
  130. f_div_C = F.softmax(f, dim=-1)
  131. y = torch.matmul(f_div_C, g_x)
  132. y = y.permute(0, 2, 1).contiguous()
  133. y = y.view(batch_size, self.inter_channels, *x.size()[2:])
  134. W_y = self.W(y)
  135. z = W_y + x
  136. return z
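
# Example usage (sketch): self-attention over all spatial positions of a feature map.
#   nl = NonLocalBlock(in_channels=64)
#   out = nl(torch.randn(2, 64, 32, 32))   # same shape as the input, (2, 64, 32, 32)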

# AGCM Module
class CrossNonLocalBlock(nn.Module):
    # Non-local attention from a source feature map onto a target feature map.
    def __init__(self, in_channels_source, in_channels_target, inter_channels, sub_sample=False, bn_layer=True):
        super(CrossNonLocalBlock, self).__init__()
        self.sub_sample = sub_sample
        self.in_channels_source = in_channels_source
        self.in_channels_target = in_channels_target
        self.inter_channels = inter_channels
        """
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        """
        # g and theta embed the source; phi embeds the target.
        self.g = nn.Conv2d(in_channels=self.in_channels_source, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        self.theta = nn.Conv2d(in_channels=self.in_channels_source, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels_target, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels_target,
                          kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(self.in_channels_target)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            # Mirrors NonLocalBlock: without this branch, self.W would be
            # undefined when bn_layer=False.
            self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels_target,
                               kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        if sub_sample:
            self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size=(2, 2)))
            self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size=(2, 2)))

    def forward(self, x, l):
        # x: source feature map, l: target feature map; the output has l's shape.
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)  # source
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)  # source
        phi_x = self.phi(l).view(batch_size, self.inter_channels, -1)  # target
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)
        f_div_C = f_div_C.permute(0, 2, 1)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *l.size()[2:])
        W_y = self.W(y)
        z = W_y + l
        return z
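
# Example usage (sketch; channel counts are assumptions): attend from a
# low-resolution source onto a higher-resolution target.
#   cross = CrossNonLocalBlock(in_channels_source=256, in_channels_target=64, inter_channels=32)
#   x = torch.randn(2, 256, 8, 8)    # source
#   l = torch.randn(2, 64, 32, 32)   # target
#   out = cross(x, l)                # (2, 64, 32, 32), i.e. the target's shape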

# SFEM module
class NonLocalBlock_PatchWise(nn.Module):
    def __init__(self, in_channel, inter_channel, patch_factor):
        super(NonLocalBlock_PatchWise, self).__init__()
        self.in_channel = in_channel
        self.patch_factor = patch_factor
        # The input is pooled to 8x8, then split into non-overlapping patches of
        # size (8 / patch_factor) x (8 / patch_factor).
        self.patch_width = int(8 / self.patch_factor)
        self.patch_height = int(8 / self.patch_factor)
        self.stride_width = int(8 / self.patch_factor)
        self.stride_height = int(8 / self.patch_factor)
        self.unfold = nn.Unfold(kernel_size=(self.patch_width, self.patch_height),
                                stride=(self.stride_width, self.stride_height))
        self.adp = nn.AdaptiveAvgPool2d(8)
        self.bottleneck = nn.Conv2d(64, inter_channel, kernel_size=(1, 1))
        self.non_block = NonLocalBlock(self.in_channel)
        self.adp_post = nn.AdaptiveAvgPool2d((8, 8))

    def forward(self, x):
        batch_size = x.size(0)
        x_up = self.adp(x)
        x_up = self.unfold(x_up)
        batch_size, p_dim, p_size = x_up.size()
        x_up = x_up.view(batch_size, -1, self.in_channel, p_size)
        final_output = torch.tensor([], device=x.device)
        index = torch.arange(0, p_size, 1, dtype=torch.int64, device=x.device)
        # Apply the non-local block to each patch independently, then stitch the
        # patch outputs back into an 8x8 map.
        for i in range(int(p_size)):
            divide = torch.index_select(x_up, 3, index[i])
            divide = divide.view(batch_size, -1, self.in_channel)
            patch_width = int(divide.size(1) ** 0.5)
            divide = divide.reshape(batch_size, self.in_channel, patch_width, patch_width)
            attn = self.non_block(divide)
            output = attn.view(batch_size, -1, self.in_channel, 1)
            final_output = torch.cat((final_output, output), dim=3)
        final_output = final_output.view(batch_size, self.in_channel, 8, 8)
        return final_output

class GCM_up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCM_up, self).__init__()
        # NOTE: the 256-channel constants below assume in_channels == 256.
        self.adp = nn.AdaptiveAvgPool2d((8, 8))
        self.patch1 = NonLocalBlock_PatchWise(in_channels, out_channels, 2)
        self.patch2 = NonLocalBlock_PatchWise(in_channels, out_channels, 4)
        self.patch3 = NonLocalBlock(256, 64)
        self.fuse = SELayer(3 * 256)
        self.conv = nn.Conv2d(3 * 256, out_channels, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        b, c, h, w = x.size()
        # Pool to a fixed 8x8 grid, run patch-wise and global non-local attention,
        # fuse the three branches with an SE layer, then restore the input resolution.
        x = self.adp(x)
        patch1 = self.patch1(x)
        patch2 = self.patch2(x)
        patch3 = self.patch3(x)
        global_cat = torch.cat((patch1, patch2, patch3), dim=1)
        fuse = self.relu(self.conv(self.fuse(global_cat)))
        adp_post = nn.AdaptiveAvgPool2d((h, w))
        fuse = adp_post(fuse)
        return fuse
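

if __name__ == "__main__":
    # Minimal smoke test (sketch): the channel counts below are assumptions chosen
    # to satisfy the hard-coded 256-channel constants in GCM_up, not values taken
    # from the original training configuration.
    x = torch.randn(1, 256, 16, 16)

    gcm_up = GCM_up(in_channels=256, out_channels=64)
    print(gcm_up(x).shape)  # expected: torch.Size([1, 64, 16, 16])

    gcm = GCM(in_channels=256, out_channels=64)
    for o in gcm(x):
        print(o.shape)  # expected channels/sizes: 256@32, 128@64, 64@128, 64@256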