You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. def create_prompt_simple(masks, forground=2, background=2):
  5. kernel_size = 9
  6. kernel = nn.Conv2d(
  7. in_channels=1,
  8. bias=False,
  9. out_channels=1,
  10. kernel_size=kernel_size,
  11. stride=1,
  12. padding=kernel_size // 2,
  13. )
  14. # print(kernel.weight.shape)
  15. kernel.weight = nn.Parameter(
  16. torch.zeros(1, 1, kernel_size, kernel_size).to(masks.device),
  17. requires_grad=False,
  18. )
  19. kernel.weight[0, 0] = 1.0
  20. eroded_masks = kernel(masks).squeeze(1)//(kernel_size**2)
  21. masks = masks.squeeze(1)
  22. use_eroded = (eroded_masks.sum(dim=(1, 2), keepdim=True) >= forground).float()
  23. new_masks = (eroded_masks * use_eroded) + (masks * (1 - use_eroded))
  24. all_points = []
  25. all_labels = []
  26. for i in range(len(new_masks)):
  27. new_background = background
  28. points = []
  29. labels = []
  30. new_mask = new_masks[i]
  31. nonzeros = torch.nonzero(new_mask, as_tuple=False)
  32. n_nonzero = len(nonzeros)
  33. if n_nonzero >= forground:
  34. indices = np.random.choice(
  35. np.arange(n_nonzero), size=forground, replace=False
  36. ).tolist()
  37. # raise ValueError(nonzeros[:, [0, 1]][indices])
  38. points.append(nonzeros[:, [1,0]][indices])
  39. labels.append(torch.ones(forground))
  40. else:
  41. if n_nonzero > 0:
  42. points.append(nonzeros)
  43. labels.append(torch.ones(n_nonzero))
  44. new_background += forground - n_nonzero
  45. # print(points, new_background)
  46. zeros = torch.nonzero(1 - masks[i], as_tuple=False)
  47. n_zero = len(zeros)
  48. indices = np.random.choice(
  49. np.arange(n_zero), size=new_background, replace=False
  50. ).tolist()
  51. points.append(zeros[:, [1, 0]][indices])
  52. labels.append(torch.zeros(new_background))
  53. points = torch.cat(points, dim=0)
  54. labels = torch.cat(labels, dim=0)
  55. all_points.append(points)
  56. all_labels.append(labels)
  57. all_points = torch.stack(all_points, dim=0)
  58. all_labels = torch.stack(all_labels, dim=0)
  59. return all_points, all_labels
  60. def distance_to_edge(point, image_shape):
  61. y, x = point
  62. height, width = image_shape
  63. distance_top = y
  64. distance_bottom = height - y
  65. distance_left = x
  66. distance_right = width - x
  67. return min(distance_top, distance_bottom, distance_left, distance_right)
  68. def create_prompt(probabilities, foreground=2, background=2):
  69. kernel_size = 9
  70. kernel = nn.Conv2d(
  71. in_channels=1,
  72. bias=False,
  73. out_channels=1,
  74. kernel_size=kernel_size,
  75. stride=1,
  76. padding=kernel_size // 2,
  77. )
  78. kernel.weight = nn.Parameter(
  79. torch.zeros(1, 1, kernel_size, kernel_size).to(probabilities.device),
  80. requires_grad=False,
  81. )
  82. kernel.weight[0, 0] = 1.0
  83. eroded_probs = kernel(probabilities).squeeze(1) / (kernel_size ** 2)
  84. probabilities = probabilities.squeeze(1)
  85. all_points = []
  86. all_labels = []
  87. for i in range(len(probabilities)):
  88. points = []
  89. labels = []
  90. prob_mask = probabilities[i]
  91. if torch.max(prob_mask) > 0.01:
  92. foreground_indices = torch.topk(prob_mask.view(-1), k=foreground, dim=0).indices
  93. foreground_points = torch.nonzero(prob_mask > 0, as_tuple=False)
  94. n_foreground = len(foreground_points)
  95. if n_foreground >= foreground:
  96. # Get the index of the point with the highest probability
  97. top_prob_idx = torch.topk(prob_mask.view(-1), k=1).indices[0]
  98. # Convert the flat index to 2D coordinates
  99. top_prob_point = np.unravel_index(top_prob_idx.item(), prob_mask.shape)
  100. top_prob_point = torch.tensor(top_prob_point, device=probabilities.device) # Move to the same device
  101. # Add the point with the highest probability to the points list
  102. points.append(torch.tensor([top_prob_point[1], top_prob_point[0]], device=probabilities.device).unsqueeze(0))
  103. labels.append(torch.ones(1, device=probabilities.device))
  104. # Exclude the top probability point when finding the point closest to the edge
  105. remaining_foreground_points = foreground_points[(foreground_points != top_prob_point.unsqueeze(0)).all(dim=1)]
  106. if remaining_foreground_points.numel() > 0:
  107. distances = [distance_to_edge(point.cpu().numpy(), prob_mask.shape) for point in remaining_foreground_points]
  108. edge_point_idx = np.argmin(distances)
  109. edge_point = remaining_foreground_points[edge_point_idx]
  110. # Add the edge point to the points list
  111. points.append(edge_point[[1, 0]].unsqueeze(0))
  112. labels.append(torch.ones(1, device=probabilities.device))
  113. # raise ValueError(points , labels)
  114. else:
  115. if n_foreground > 0:
  116. points.append(foreground_points[:, [1, 0]])
  117. labels.append(torch.ones(n_foreground))
  118. # Select 2 background points, one from 0 to -15 and one less than -15
  119. background_indices_1 = torch.nonzero((prob_mask < 0) & (prob_mask > -15), as_tuple=False)
  120. background_indices_2 = torch.nonzero(prob_mask < -15, as_tuple=False)
  121. # Randomly sample from each set of background points
  122. indices_1 = np.random.choice(np.arange(len(background_indices_1)), size=1, replace=False).tolist()
  123. indices_2 = np.random.choice(np.arange(len(background_indices_2)), size=1, replace=False).tolist()
  124. points.append(background_indices_1[indices_1])
  125. points.append(background_indices_2[indices_2])
  126. labels.append(torch.zeros(2))
  127. else:
  128. # If no probability is greater than 0, return 4 background points
  129. # print(prob_mask.unique())
  130. background_indices_1 = torch.nonzero(prob_mask < 0, as_tuple=False)
  131. indices_1 = np.random.choice(np.arange(len(background_indices_1)), size=4, replace=False).tolist()
  132. points.append(background_indices_1[indices_1])
  133. labels.append(torch.zeros(4))
  134. labels = [label.to(probabilities.device) for label in labels]
  135. points = torch.cat(points, dim=0)
  136. all_points.append(points)
  137. all_labels.append(torch.cat(labels, dim=0))
  138. all_points = torch.stack(all_points, dim=0)
  139. all_labels = torch.stack(all_labels, dim=0)
  140. # print(all_points, all_labels)
  141. return all_points, all_labels