You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import torch.nn.functional as F
  5. def distance_to_edge(point, image_shape):
  6. y, x = point
  7. height, width = image_shape
  8. distance_top = y
  9. distance_bottom = height - y
  10. distance_left = x
  11. distance_right = width - x
  12. return min(distance_top, distance_bottom, distance_left, distance_right)
  13. def sample_prompt(probabilities, forground=2, background=2):
  14. kernel_size = 9
  15. kernel = nn.Conv2d(
  16. in_channels=1,
  17. bias=False,
  18. out_channels=1,
  19. kernel_size=kernel_size,
  20. stride=1,
  21. padding=kernel_size // 2,
  22. )
  23. kernel.weight = nn.Parameter(
  24. torch.zeros(1, 1, kernel_size, kernel_size).to(probabilities.device),
  25. requires_grad=False,
  26. )
  27. kernel.weight[0, 0] = 1.0
  28. eroded_probs = kernel(probabilities).squeeze(1) / (kernel_size ** 2)
  29. probabilities = probabilities.squeeze(1)
  30. all_points = []
  31. all_labels = []
  32. for i in range(len(probabilities)):
  33. points = []
  34. labels = []
  35. prob_mask = probabilities[i]
  36. if torch.max(prob_mask) > 0.01:
  37. foreground_indices = torch.topk(prob_mask.view(-1), k=forground, dim=0).indices
  38. foreground_points = torch.nonzero(prob_mask > 0, as_tuple=False)
  39. n_foreground = len(foreground_points)
  40. if n_foreground >= forground:
  41. # Calculate distance to edge for each point
  42. distances = [distance_to_edge(point.cpu().numpy(), prob_mask.shape) for point in foreground_points]
  43. # Find the point with minimum distance to edge
  44. edge_point_idx = np.argmin(distances)
  45. edge_point = foreground_points[edge_point_idx]
  46. # Append the point closest to the edge and another random point
  47. points.append(edge_point[[1, 0]].unsqueeze(0))
  48. indices_foreground = np.random.choice(np.arange(n_foreground), size=forground-1, replace=False).tolist()
  49. selected_foreground = foreground_points[indices_foreground]
  50. points.append(selected_foreground[:, [1, 0]])
  51. labels.append(torch.ones(forground))
  52. else:
  53. if n_foreground > 0:
  54. points.append(foreground_points[:, [1, 0]])
  55. labels.append(torch.ones(n_foreground))
  56. # Select 2 background points, one from 0 to -15 and one less than -15
  57. background_indices_1 = torch.nonzero((prob_mask < 0) & (prob_mask > -15), as_tuple=False)
  58. background_indices_2 = torch.nonzero(prob_mask < -15, as_tuple=False)
  59. # Randomly sample from each set of background points
  60. indices_1 = np.random.choice(np.arange(len(background_indices_1)), size=1, replace=False).tolist()
  61. indices_2 = np.random.choice(np.arange(len(background_indices_2)), size=1, replace=False).tolist()
  62. points.append(background_indices_1[indices_1])
  63. points.append(background_indices_2[indices_2])
  64. labels.append(torch.zeros(2))
  65. else:
  66. # If no probability is greater than 0, return 4 background points
  67. # print(prob_mask.unique())
  68. background_indices_1 = torch.nonzero(prob_mask < 0, as_tuple=False)
  69. indices_1 = np.random.choice(np.arange(len(background_indices_1)), size=4, replace=False).tolist()
  70. points.append(background_indices_1[indices_1])
  71. labels.append(torch.zeros(4))
  72. points = torch.cat(points, dim=0)
  73. labels = torch.cat(labels, dim=0)
  74. all_points.append(points)
  75. all_labels.append(labels)
  76. all_points = torch.stack(all_points, dim=0)
  77. all_labels = torch.stack(all_labels, dim=0)
  78. # print(all_points, all_labels)
  79. return all_points, all_labels
  80. device = "cuda:0"
  81. def main_prompt(probabilities):
  82. probabilities = probabilities.sigmoid()
  83. # Thresholding function
  84. def threshold(tensor, thresh):
  85. return (tensor > thresh).float()
  86. # Morphological operations
  87. def morphological_op(tensor, operation, kernel_size):
  88. kernel = torch.ones(1, 1, kernel_size[0], kernel_size[1]).to(tensor.device)
  89. if kernel_size[0] % 2 == 0:
  90. padding = [(k - 1) // 2 for k in kernel_size]
  91. extra_pad = [0, 2, 0, 2]
  92. else:
  93. padding = [(k - 1) // 2 for k in kernel_size]
  94. extra_pad = [0, 0, 0, 0]
  95. if operation == 'erode':
  96. tensor = F.conv2d(F.pad(tensor, extra_pad), kernel, padding=padding).clamp(max=1)
  97. elif operation == 'dilate':
  98. tensor = F.max_pool2d(F.pad(tensor, extra_pad), kernel_size, stride=1, padding=padding).clamp(max=1)
  99. if kernel_size[0] % 2 == 0:
  100. tensor = tensor[:, :, :tensor.shape[2] - 1, :tensor.shape[3] - 1]
  101. return tensor.squeeze(1)
  102. # Foreground prompts
  103. th_O = threshold(probabilities, 0.5)
  104. M_f = morphological_op(morphological_op(th_O, 'erode', (10, 10)), 'dilate', (5, 5))
  105. foreground_indices = torch.nonzero(M_f.squeeze(0), as_tuple=False)
  106. n_for = 2 if len(foreground_indices) >= 2 else len(foreground_indices)
  107. n_back = 4 - n_for
  108. # Background prompts
  109. M_b1 = 1 - morphological_op(threshold(probabilities, 0.5), 'dilate', (10, 10))
  110. M_b2 = 1 - threshold(probabilities, 0.4)
  111. M_b2 = M_b2.squeeze(1)
  112. M_b = M_b1 * M_b2
  113. M_b = M_b.squeeze(0)
  114. background_indices = torch.nonzero(M_b, as_tuple=False)
  115. if n_for > 0:
  116. indices = torch.concat([foreground_indices[np.random.choice(np.arange(len(foreground_indices)), size=n_for)],
  117. background_indices[np.random.choice(np.arange(len(background_indices)), size=n_back)]
  118. ])
  119. values = torch.tensor([1] * n_for + [0] * n_back)
  120. else:
  121. indices = background_indices[np.random.choice(np.arange(len(background_indices)), size=4)]
  122. values = torch.tensor([0] * 4)
  123. # raise ValueError(indices, values)
  124. return indices.unsqueeze(0), values.unsqueeze(0)