| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- import torch
- import torch.nn as nn
- import numpy as np
- import torch.nn.functional as F
-
-
def distance_to_edge(point, image_shape):
    """Return the distance, in pixels, from ``point`` to the nearest border.

    Args:
        point: ``(y, x)`` pair, i.e. (row index, column index).
        image_shape: ``(height, width)`` of the image.

    Returns:
        The smallest of the distances to the top, bottom, left and right
        borders. A pixel located in the first or last row/column is at
        distance 0.
    """
    y, x = point
    height, width = image_shape
    # The last valid index along an axis is ``size - 1``. Measuring the
    # bottom/right distance from it keeps all four borders symmetric
    # (previously the last row/column reported 1 while the first reported 0).
    return min(y, height - 1 - y, x, width - 1 - x)
-
def _pick_rows(candidates, count):
    """Randomly select ``count`` rows from the ``(K, 2)`` tensor ``candidates``."""
    if count <= 0:
        return candidates[:0]
    total = len(candidates)
    # Draw distinct rows whenever enough exist; only repeat rows when the
    # request exceeds the number of candidates.
    chosen = np.random.choice(total, size=count, replace=total < count)
    index = torch.as_tensor(chosen, dtype=torch.long, device=candidates.device)
    return candidates[index]


def _edge_then_random(foreground_points, shape, count):
    """Return the foreground pixel nearest the border, then ``count - 1`` random ones."""
    height, width = shape
    # One host transfer for all candidates instead of one per pixel.
    coords = foreground_points.cpu().numpy()
    rows, cols = coords[:, 0], coords[:, 1]
    distances = np.minimum(
        np.minimum(rows, height - 1 - rows),
        np.minimum(cols, width - 1 - cols),
    )
    edge_point = foreground_points[int(np.argmin(distances))].unsqueeze(0)
    return torch.cat([edge_point, _pick_rows(foreground_points, count - 1)], dim=0)


def _split_background(prob_mask, pool, count):
    """Pick ``count`` background pixels, split between the two score bands."""
    near = torch.nonzero((prob_mask < 0) & (prob_mask > -15), as_tuple=False)
    far = torch.nonzero(prob_mask < -15, as_tuple=False)
    n_far = count // 2
    n_near = count - n_far
    # If a required band has no pixels, sample from every negative pixel
    # rather than failing.
    if (n_near and len(near) == 0) or (n_far and len(far) == 0):
        return _pick_rows(pool, count)
    return torch.cat([_pick_rows(near, n_near), _pick_rows(far, n_far)], dim=0)


def sample_prompt(probabilities, forground=2, background=2):
    """Sample point prompts (coordinates and labels) from per-image score maps.

    For every image, exactly ``forground + background`` points are returned so
    that the per-image results can be stacked.

    * If the map's maximum exceeds 0.01, foreground points are taken from
      pixels ``> 0``: the pixel closest to the image border first, then random
      foreground pixels. Background points are taken from pixels ``< 0``, split
      between the band ``(-15, 0)`` and the band ``< -15``.
    * If fewer foreground pixels exist than requested, all of them are used
      and the remaining slots are filled with background points.
    * If the maximum does not exceed 0.01, every point is a background point
      drawn from pixels ``< 0``.
    * If the map has no negative pixel, every slot is filled with foreground
      points.

    Args:
        probabilities: Tensor whose dim 1 is squeezed away, leaving one 2-D
            map per batch entry (presumably ``(N, 1, H, W)``; the negative
            thresholds suggest raw scores/logits - confirm with the caller).
        forground: Number of foreground (label 1) points per image.
        background: Number of background (label 0) points per image.

    Returns:
        ``(points, labels)`` where ``points`` is an integer tensor of shape
        ``(N, forground + background, 2)`` holding ``(x, y)`` = (column, row)
        coordinates, and ``labels`` is a float tensor of shape
        ``(N, forground + background)`` with 1 for foreground and 0 for
        background. Foreground points come first.

    Raises:
        ValueError: If a map offers neither foreground nor background pixels.
    """
    n_total = forground + background
    all_points = []
    all_labels = []

    for prob_mask in probabilities.squeeze(1):
        background_pool = torch.nonzero(prob_mask < 0, as_tuple=False)
        has_foreground = bool(torch.max(prob_mask) > 0.01)
        if has_foreground:
            foreground_points = torch.nonzero(prob_mask > 0, as_tuple=False)
        else:
            foreground_points = background_pool[:0]
        n_available = len(foreground_points)

        if len(background_pool) == 0:
            if n_available == 0:
                raise ValueError(
                    "sample_prompt: mask has no foreground or background pixels to sample"
                )
            # Nothing can be labelled background: use foreground for all slots.
            n_for = n_total
        else:
            n_for = min(forground, n_available)
        n_back = n_total - n_for

        if n_for == 0:
            fg_rows = foreground_points[:0]
        elif n_for == n_available and n_available < forground:
            # Too few foreground pixels: keep every one of them.
            fg_rows = foreground_points
        else:
            fg_rows = _edge_then_random(foreground_points, prob_mask.shape, n_for)

        if has_foreground:
            bg_rows = _split_background(prob_mask, background_pool, n_back)
        else:
            bg_rows = _pick_rows(background_pool, n_back)

        # torch.nonzero yields (row, col); emit (x, y) for foreground AND
        # background so both kinds of point share one coordinate convention.
        points = torch.cat([fg_rows, bg_rows], dim=0)[:, [1, 0]]
        labels = torch.cat([torch.ones(n_for), torch.zeros(n_back)], dim=0)

        all_points.append(points)
        all_labels.append(labels)

    return torch.stack(all_points, dim=0), torch.stack(all_labels, dim=0)
-
-
-
# Module-level device string. main_prompt does not read it; it is kept
# unchanged in case other code imports it.
device = "cuda:0"


def _morph_pad(mask, kernel_size):
    """Zero-pad the last two dims so a stride-1 window keeps the H x W size."""
    k_h, k_w = kernel_size
    # F.pad takes (left, right, top, bottom). For even sizes the extra cell
    # goes after the centre, i.e. the window spans [-(k-1)//2, +k//2].
    return F.pad(mask, [(k_w - 1) // 2, k_w // 2, (k_h - 1) // 2, k_h // 2])


def _morph_dilate(mask, kernel_size):
    """Binary dilation: a pixel is 1 if any pixel in its window is 1."""
    return F.max_pool2d(_morph_pad(mask, kernel_size), kernel_size, stride=1)


def _morph_erode(mask, kernel_size):
    """Binary erosion: a pixel is 1 only if every pixel in its window is 1."""
    # Min-pooling written as negated max-pooling. The zero padding counts as
    # background, so the mask is also eroded at the image border.
    return -F.max_pool2d(-_morph_pad(mask, kernel_size), kernel_size, stride=1)


def main_prompt(probabilities):
    """Sample four point prompts from a predicted mask.

    The input is passed through a sigmoid and thresholded at 0.5. Foreground
    candidates are the pixels of that mask after a 10x10 erosion followed by a
    5x5 dilation; background candidates are pixels outside the 10x10-dilated
    mask whose probability is not above 0.4. Up to two foreground points are
    drawn (with replacement) and the remaining slots are filled with
    background points. If there are no background candidates, all four points
    are foreground.

    Args:
        probabilities: Pre-sigmoid scores. The squeezes on dims 1 and 0 imply
            a ``(1, 1, H, W)`` tensor (single image) - confirm with callers.

    Returns:
        ``(indices, values)``: ``indices`` has shape ``(1, 4, 2)`` with rows
        in the ``(row, col)`` order produced by ``torch.nonzero``; ``values``
        has shape ``(1, 4)`` with 1 for foreground and 0 for background,
        foreground first.

    Raises:
        ValueError: If there are neither foreground nor background candidates.
    """
    n_total = 4
    max_foreground = 2

    probabilities = probabilities.sigmoid()
    mask = (probabilities > 0.5).float()

    # Foreground candidates: eroding first keeps prompts well inside the
    # predicted mask (the previous 'erode' summed the window and clamped at 1,
    # which grows the mask instead of shrinking it).
    m_f = _morph_dilate(_morph_erode(mask, (10, 10)), (5, 5))
    foreground_indices = torch.nonzero(m_f.squeeze(1).squeeze(0), as_tuple=False)

    # Background candidates: clear of the dilated mask AND low probability.
    m_b1 = 1 - _morph_dilate(mask, (10, 10))
    m_b2 = 1 - (probabilities > 0.4).float()
    m_b = (m_b1 * m_b2).squeeze(1).squeeze(0)
    background_indices = torch.nonzero(m_b, as_tuple=False)

    n_fg_available = len(foreground_indices)
    n_bg_available = len(background_indices)
    if n_bg_available == 0:
        if n_fg_available == 0:
            raise ValueError(
                "main_prompt: no foreground or background candidates to sample"
            )
        # No pixel qualifies as background: use foreground for every slot.
        n_for = n_total
    else:
        n_for = min(max_foreground, n_fg_available)
    n_back = n_total - n_for

    picked = []
    if n_for > 0:
        chosen = np.random.choice(n_fg_available, size=n_for)
        picked.append(
            foreground_indices[
                torch.as_tensor(chosen, dtype=torch.long, device=foreground_indices.device)
            ]
        )
    if n_back > 0:
        chosen = np.random.choice(n_bg_available, size=n_back)
        picked.append(
            background_indices[
                torch.as_tensor(chosen, dtype=torch.long, device=background_indices.device)
            ]
        )

    indices = torch.cat(picked, dim=0)
    values = torch.tensor([1] * n_for + [0] * n_back)
    return indices.unsqueeze(0), values.unsqueeze(0)
-
|