
model_handler.py 6.2KB

import torch
from torch import nn
from torch.nn import functional as F

from segment_anything import SamPredictor, sam_model_registry

from utils import create_prompt_main

# Default CUDA device; forward() below also receives the device explicitly.
device = "cuda:0"


class panc_sam(nn.Module):
    # Two-stage SAM wrapper. No __init__ is defined in this file: the SAM
    # modules prompt_encoder / mask_decoder (stage 1, box prompt) and
    # prompt_encoder2 / mask_decoder2 (stage 2, point prompts derived from the
    # stage-1 mask) are expected to be attached to the instance beforehand.
    def forward(self, batched_input, device):
        # Fixed bounding-box prompt (x1, y1, x2, y2) used for every image.
        box = torch.tensor([[200, 200, 750, 800]]).to(device)
        outputs = []
        outputs_prompt = []
        for image_record in batched_input:
            # Precomputed SAM image embedding for this record.
            image_embeddings = image_record["image_embedd"].to(device)
            if "point_coords" in image_record:
                # The record's own point prompts are built here but are not
                # used below; stage 1 relies on the fixed box prompt instead.
                point_coords = image_record["point_coords"].to(device)
                point_labels = image_record["point_labels"].to(device)
                points = (point_coords.unsqueeze(0), point_labels.unsqueeze(0))
            else:
                raise ValueError("image_record is missing 'point_coords'")
            # input_images = torch.stack([x["image"] for x in batched_input], dim=0)

            # Stage 1: decode a mask from the fixed box prompt.
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder(
                    points=None,
                    boxes=box,
                    masks=None,
                )
            low_res_masks, _ = self.mask_decoder(
                image_embeddings=image_embeddings,
                image_pe=self.prompt_encoder.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs.append(low_res_masks)

            # Stage 2: derive point prompts from the stage-1 mask and decode again.
            # points, point_labels = create_prompt((low_res_masks > 0).float())
            # points, point_labels = create_prompt(low_res_masks)
            points, point_labels = create_prompt_main(low_res_masks)
            points = points * 4  # low-res masks are 1/4 of the SAM input resolution
            points = (points, point_labels)
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = self.prompt_encoder2(
                    points=points,
                    boxes=None,
                    masks=None,
                )
            low_res_masks, _ = self.mask_decoder2(
                image_embeddings=image_embeddings,
                image_pe=self.prompt_encoder2.get_dense_pe().detach(),
                sparse_prompt_embeddings=sparse_embeddings.detach(),
                dense_prompt_embeddings=dense_embeddings.detach(),
                multimask_output=False,
            )
            outputs_prompt.append(low_res_masks)

        low_res_masks_prompt = torch.cat(outputs_prompt, dim=1)
        low_res_masks = torch.cat(outputs, dim=1)
        return low_res_masks, low_res_masks_prompt
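

# Hypothetical wiring sketch (added for illustration, not part of the original
# file). Because panc_sam defines no __init__ above, the SAM prompt encoders
# and mask decoders are assumed to be attached from a loaded checkpoint before
# forward() is called; the registry key, checkpoint filename and the deepcopy
# of the stage-2 modules are assumptions.
def build_panc_sam_example(checkpoint="sam_vit_b.pth", model_type="vit_b"):
    import copy

    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    model = panc_sam()
    model.prompt_encoder = sam.prompt_encoder                  # stage 1: box prompt
    model.mask_decoder = sam.mask_decoder
    model.prompt_encoder2 = copy.deepcopy(sam.prompt_encoder)  # stage 2: point prompts
    model.mask_decoder2 = copy.deepcopy(sam.mask_decoder)
    return model.to(device)


# Each record passed to forward() is expected to carry a precomputed SAM image
# embedding plus point prompts, e.g.:
#     masks, masks_prompt = model(
#         [{"image_embedd": emb, "point_coords": pts, "point_labels": lbls}], device
#     )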


def double_conv_3d(in_channels, out_channels):
    # Factorized 3-D conv block: an in-plane (1, 3, 3) conv followed by a
    # depth-wise (3, 1, 1) conv; padding keeps all spatial dimensions unchanged.
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
        nn.ReLU(inplace=True),
    )
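

# Illustrative shape check (added example, not part of the original file):
# double_conv_3d changes only the channel count; the (1, 3, 3) kernel with
# (0, 1, 1) padding and the (3, 1, 1) kernel with (1, 0, 0) padding both
# preserve depth, height and width.
def _demo_double_conv_3d():
    block = double_conv_3d(1, 32)
    x = torch.zeros(2, 1, 8, 64, 64)  # (batch, channels, depth, H, W)
    with torch.no_grad():
        y = block(x)
    assert y.shape == (2, 32, 8, 64, 64)
    return y.shape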


# Was not used
class UNet3D(nn.Module):
    def __init__(self):
        super(UNet3D, self).__init__()
        self.dconv_down1 = double_conv_3d(1, 32)
        self.dconv_down2 = double_conv_3d(32, 64)
        self.dconv_down3 = double_conv_3d(64, 96)
        self.maxpool = nn.MaxPool3d((1, 2, 2))
        self.upsample = nn.Upsample(
            scale_factor=(1, 2, 2), mode="trilinear", align_corners=True
        )
        self.dconv_up2 = double_conv_3d(64 + 96, 64)
        self.dconv_up1 = double_conv_3d(64 + 32, 32)
        self.conv_last = nn.Conv3d(32, 1, kernel_size=1)

    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, depth, H, W) -> (batch, 1, depth, H, W)
        # Encoder: pooling halves H and W only, depth is preserved.
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        x = self.dconv_down3(x)
        # Decoder: upsample and concatenate the skip connections.
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out
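

# Illustrative shape check (added example, not part of the original file):
# UNet3D takes a (batch, depth, H, W) volume, adds the channel axis itself and
# returns logits of shape (batch, 1, depth, H, W); H and W must be divisible
# by 4 because of the two (1, 2, 2) max-pooling steps.
def _demo_unet3d():
    net = UNet3D()
    x = torch.zeros(1, 4, 64, 64)  # (batch, depth, H, W)
    with torch.no_grad():
        out = net(x)
    assert out.shape == (1, 1, 4, 64, 64)
    return out.shape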


class Conv3DFilter(nn.Module):
    # Residual stack of 3-D convolutions; with the default (3, 1, 1) kernels it
    # only mixes information along the depth (slice) axis.
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=((3, 1, 1), (3, 1, 1), (3, 1, 1), (3, 1, 1)),
        padding_sizes=None,
        custom_bias=0,
    ):
        super(Conv3DFilter, self).__init__()
        self.custom_bias = custom_bias  # not used in forward()
        self.bias = 1e-8  # not used in forward()
        if padding_sizes is None:
            # Default to "same"-style padding derived from each kernel size.
            padding_sizes = [tuple(k // 2 for k in ks) for ks in kernel_size]
        # Convolutional layers padded to preserve the input spatial dimensions.
        self.convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv3d(
                        in_channels,
                        out_channels,
                        kernel_size[0],
                        padding=padding_sizes[0],
                    ),
                    nn.ReLU(),
                    nn.Conv3d(
                        out_channels,
                        out_channels,
                        kernel_size[0],
                        padding=padding_sizes[0],
                    ),
                    nn.ReLU(),
                )
            ]
        )
        for kernel, padding in zip(kernel_size[1:-1], padding_sizes[1:-1]):
            self.convs.extend(
                [
                    nn.Sequential(
                        nn.Conv3d(
                            out_channels, out_channels, kernel, padding=padding
                        ),
                        nn.ReLU(),
                        nn.Conv3d(
                            out_channels, out_channels, kernel, padding=padding
                        ),
                        nn.ReLU(),
                    )
                ]
            )
        self.output_conv = nn.Conv3d(
            out_channels, 1, kernel_size[-1], padding=padding_sizes[-1]
        )
        # self.m = nn.LeakyReLU(0.1)

    def forward(self, input):
        x = input.unsqueeze(1)  # (batch, depth, H, W) -> (batch, 1, depth, H, W)
        for module in self.convs:
            # Residual connection; shapes match because in_channels equals
            # out_channels with the defaults (both 1).
            x = module(x) + x
        x = self.output_conv(x)
        x = torch.sigmoid(x).squeeze(1)  # per-voxel probabilities, channel axis removed
        return x
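

# Illustrative usage sketch (added example, not part of the original file):
# Conv3DFilter smooths a stack of per-slice masks treated as a
# (batch, slices, H, W) volume; with the default (3, 1, 1) kernels it mixes
# information only across neighbouring slices and returns values in [0, 1]
# with the input shape preserved.
def _demo_conv3d_filter():
    filt = Conv3DFilter(padding_sizes=[(1, 0, 0)] * 4)  # explicit "same" padding
    masks = torch.rand(1, 16, 128, 128)  # (batch, slices, H, W)
    with torch.no_grad():
        smoothed = filt(masks)
    assert smoothed.shape == masks.shape
    return smoothed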