You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import torch.nn.functional as F
  5. def create_prompt_simple(masks, forground=2, background=2):
  6. kernel_size = 9
  7. kernel = nn.Conv2d(
  8. in_channels=1,
  9. bias=False,
  10. out_channels=1,
  11. kernel_size=kernel_size,
  12. stride=1,
  13. padding=kernel_size // 2,
  14. )
  15. # print(kernel.weight.shape)
  16. kernel.weight = nn.Parameter(
  17. torch.zeros(1, 1, kernel_size, kernel_size).to(masks.device),
  18. requires_grad=False,
  19. )
  20. kernel.weight[0, 0] = 1.0
  21. eroded_masks = kernel(masks).squeeze(1)//(kernel_size**2)
  22. masks = masks.squeeze(1)
  23. use_eroded = (eroded_masks.sum(dim=(1, 2), keepdim=True) >= forground).float()
  24. new_masks = (eroded_masks * use_eroded) + (masks * (1 - use_eroded))
  25. all_points = []
  26. all_labels = []
  27. for i in range(len(new_masks)):
  28. new_background = background
  29. points = []
  30. labels = []
  31. new_mask = new_masks[i]
  32. nonzeros = torch.nonzero(new_mask, as_tuple=False)
  33. n_nonzero = len(nonzeros)
  34. if n_nonzero >= forground:
  35. indices = np.random.choice(
  36. np.arange(n_nonzero), size=forground, replace=False
  37. ).tolist()
  38. # raise ValueError(nonzeros[:, [0, 1]][indices])
  39. points.append(nonzeros[:, [1,0]][indices])
  40. labels.append(torch.ones(forground))
  41. else:
  42. if n_nonzero > 0:
  43. points.append(nonzeros)
  44. labels.append(torch.ones(n_nonzero))
  45. new_background += forground - n_nonzero
  46. # print(points, new_background)
  47. zeros = torch.nonzero(1 - masks[i], as_tuple=False)
  48. n_zero = len(zeros)
  49. indices = np.random.choice(
  50. np.arange(n_zero), size=new_background, replace=False
  51. ).tolist()
  52. points.append(zeros[:, [1, 0]][indices])
  53. labels.append(torch.zeros(new_background))
  54. points = torch.cat(points, dim=0)
  55. labels = torch.cat(labels, dim=0)
  56. all_points.append(points)
  57. all_labels.append(labels)
  58. all_points = torch.stack(all_points, dim=0)
  59. all_labels = torch.stack(all_labels, dim=0)
  60. return all_points, all_labels
# Module-level device string. Neither create_prompt_simple nor
# create_prompt_main reads it; both take the device from their input tensor.
# NOTE(review): presumably imported by other modules - confirm before removing.
device = "cuda:0"
  63. def create_prompt_main(probabilities):
  64. probabilities = probabilities.sigmoid()
  65. # Thresholding function
  66. def threshold(tensor, thresh):
  67. return (tensor > thresh).float()
  68. # Morphological operations
  69. def morphological_op(tensor, operation, kernel_size):
  70. kernel = torch.ones(1, 1, kernel_size[0], kernel_size[1]).to(tensor.device)
  71. if kernel_size[0] % 2 == 0:
  72. padding = [(k - 1) // 2 for k in kernel_size]
  73. extra_pad = [0, 2, 0, 2]
  74. else:
  75. padding = [(k - 1) // 2 for k in kernel_size]
  76. extra_pad = [0, 0, 0, 0]
  77. if operation == 'erode':
  78. tensor = F.conv2d(F.pad(tensor, extra_pad), kernel, padding=padding).clamp(max=1)
  79. elif operation == 'dilate':
  80. tensor = F.max_pool2d(F.pad(tensor, extra_pad), kernel_size, stride=1, padding=padding).clamp(max=1)
  81. if kernel_size[0] % 2 == 0:
  82. tensor = tensor[:, :, :tensor.shape[2] - 1, :tensor.shape[3] - 1]
  83. return tensor.squeeze(1)
  84. # Foreground prompts
  85. th_O = threshold(probabilities, 0.5)
  86. M_f = morphological_op(morphological_op(th_O, 'erode', (10, 10)), 'dilate', (5, 5))
  87. foreground_indices = torch.nonzero(M_f.squeeze(0), as_tuple=False)
  88. n_for = 2 if len(foreground_indices) >= 2 else len(foreground_indices)
  89. n_back = 4 - n_for
  90. # Background prompts
  91. M_b1 = 1 - morphological_op(threshold(probabilities, 0.5), 'dilate', (10, 10))
  92. M_b2 = 1 - threshold(probabilities, 0.4)
  93. M_b2 = M_b2.squeeze(1)
  94. M_b = M_b1 * M_b2
  95. M_b = M_b.squeeze(0)
  96. background_indices = torch.nonzero(M_b, as_tuple=False)
  97. if n_for > 0:
  98. indices = torch.concat([foreground_indices[np.random.choice(np.arange(len(foreground_indices)), size=n_for)],
  99. background_indices[np.random.choice(np.arange(len(background_indices)), size=n_back)]
  100. ])
  101. values = torch.tensor([1] * n_for + [0] * n_back)
  102. else:
  103. indices = background_indices[np.random.choice(np.arange(len(background_indices)), size=4)]
  104. values = torch.tensor([0] * 4)
  105. # raise ValueError(indices, values)
  106. return indices.unsqueeze(0), values.unsqueeze(0)