You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

EUNet.py 5.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
import torch
import torch.nn as nn
import torchvision.models as models

from models.modules import LCA, ASM, GCM_up, GCM, CrossNonLocalBlock
  5. class ConvBlock(nn.Module):
  6. def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
  7. super(ConvBlock, self).__init__()
  8. self.conv = nn.Conv2d(in_channels, out_channels,
  9. kernel_size=kernel_size,
  10. stride=stride,
  11. padding=padding)
  12. self.bn = nn.BatchNorm2d(out_channels)
  13. self.relu = nn.ReLU(inplace=True)
  14. def forward(self, x):
  15. x = self.conv(x)
  16. x = self.bn(x)
  17. x = self.relu(x)
  18. return x
  19. class DecoderBlock(nn.Module):
  20. def __init__(self, in_channels, out_channels,
  21. kernel_size=3, stride=1, padding=1):
  22. super(DecoderBlock, self).__init__()
  23. self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size,
  24. stride=stride, padding=padding)
  25. self.conv2 = ConvBlock(in_channels // 4, out_channels, kernel_size=kernel_size,
  26. stride=stride, padding=padding)
  27. self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  28. def forward(self, x):
  29. x = self.conv1(x)
  30. x = self.conv2(x)
  31. x = self.upsample(x)
  32. return x
  33. class SideoutBlock(nn.Module):
  34. def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
  35. super(SideoutBlock, self).__init__()
  36. self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size,
  37. stride=stride, padding=padding)
  38. self.dropout = nn.Dropout2d(0.1)
  39. self.conv2 = nn.Conv2d(in_channels // 4, out_channels, 1)
  40. def forward(self, x):
  41. x = self.conv1(x)
  42. x = self.dropout(x)
  43. x = self.conv2(x)
  44. return x
  45. class EUNet(nn.Module):
  46. def __init__(self, num_classes):
  47. super(EUNet, self).__init__()
  48. resnet = models.resnet34(pretrained=True)
  49. # Encoder
  50. self.encoder1_conv = resnet.conv1
  51. self.encoder1_bn = resnet.bn1
  52. self.encoder1_relu = resnet.relu
  53. self.maxpool = resnet.maxpool
  54. self.encoder2 = resnet.layer1
  55. self.encoder3 = resnet.layer2
  56. self.encoder4 = resnet.layer3
  57. self.encoder5 = resnet.layer4
  58. # Decoder
  59. self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
  60. self.decoder4 = DecoderBlock(in_channels=1024, out_channels=256)
  61. self.decoder3 = DecoderBlock(in_channels=512, out_channels=128)
  62. self.decoder2 = DecoderBlock(in_channels=256, out_channels=64)
  63. self.decoder1 = DecoderBlock(in_channels=192, out_channels=64)
  64. self.outconv = nn.Sequential(ConvBlock(64, 32, kernel_size=3, stride=1, padding=1),
  65. nn.Dropout2d(0.1),
  66. nn.Conv2d(32, num_classes, 1))
  67. self.outenc = ConvBlock(512,256,kernel_size=1, stride=1,padding=0)
  68. # Sideout
  69. self.sideout2 = SideoutBlock(64, 1)
  70. self.sideout3 = SideoutBlock(128, 1)
  71. self.sideout4 = SideoutBlock(256, 1)
  72. self.sideout5 = SideoutBlock(512, 1)
  73. # global context module
  74. self.gcm_up = GCM_up(256,64)
  75. self.gcm_e5 = GCM_up(256, 256)#3
  76. self.gcm_e4 = GCM_up(256, 128)#2
  77. self.gcm_e3 = GCM_up(256, 64)#1
  78. self.gcm_e2 = GCM_up(256, 64)#0
  79. # adaptive selection module
  80. self.asm4 = ASM(512, 1024)
  81. self.asm3 = ASM(256, 512)
  82. self.asm2 = ASM(128, 256)
  83. self.asm1 = ASM(64, 192)
  84. self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  85. self.up2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
  86. self.up3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
  87. self.up4 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
  88. self.lca_cross_1 = CrossNonLocalBlock(512,256,256)
  89. self.lca_cross_2 = CrossNonLocalBlock(1024,128,128)
  90. self.lca_cross_3 = CrossNonLocalBlock(512,64,64)
  91. self.lca_cross_4 = CrossNonLocalBlock(256,64,64)
  92. def forward(self, x):
  93. e1 = self.encoder1_conv(x)
  94. e1 = self.encoder1_bn(e1)
  95. e1 = self.encoder1_relu(e1)
  96. e1_pool = self.maxpool(e1)
  97. e2 = self.encoder2(e1_pool)
  98. e3 = self.encoder3(e2)
  99. e4 = self.encoder4(e3)
  100. e5 = self.encoder5(e4)
  101. e_ex = self.outenc(e5)
  102. global_contexts_up = self.gcm_up(e_ex)
  103. d5 = self.decoder5(e5)
  104. out5 = self.sideout5(d5)
  105. lc4 = self.lca_cross_1(d5,e4)
  106. gc4 = self.gcm_e5(e_ex)
  107. gc4 = self.up1(gc4)
  108. comb4 = self.asm4(lc4, d5, gc4)
  109. d4 = self.decoder4(comb4)
  110. out4 = self.sideout4(d4)
  111. lc3 = self.lca_cross_2(comb4,e3)
  112. gc3 = self.gcm_e4(e_ex)
  113. gc3 = self.up2(gc3)
  114. comb3 = self.asm3(lc3, d4, gc3)
  115. d3 = self.decoder3(comb3)
  116. out3 = self.sideout3(d3)
  117. lc2= self.lca_cross_3(comb3,e2)
  118. gc2 = self.gcm_e3(e_ex)
  119. gc2 = self.up3(gc2)
  120. comb2 = self.asm2(lc2, d3, gc2)
  121. d2 = self.decoder2(comb2)
  122. out2 = self.sideout2(d2)
  123. lc1 = self.lca_cross_4(comb2,e1)
  124. gc1 = self.gcm_e2(e_ex)
  125. gc1 = self.up4(gc1)
  126. comb1 = self.asm1(lc1, d2, gc1)
  127. d1 = self.decoder1(comb1)
  128. out1 = self.outconv(d1)
  129. return torch.sigmoid(out1), torch.sigmoid(out2), torch.sigmoid(out3), \
  130. torch.sigmoid(out4), torch.sigmoid(out5)