| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- import torch
- import torch.nn as nn
- import numpy as np
- import torch.nn.functional as F
-
def create_prompt_simple(masks, forground=2, background=2):
    """Sample foreground/background point prompts from binary masks.

    For each mask, foreground points are drawn from the mask eroded with a
    9x9 all-ones window when that eroded region holds at least ``forground``
    pixels, and from the raw mask otherwise. If even the raw mask has fewer
    than ``forground`` pixels, all of them are used and the shortfall is made
    up with extra background points, so every mask yields exactly
    ``forground + background`` points.

    Args:
        masks: Floating-point tensor, expected shape (N, 1, H, W); assumed to
            hold 0/1 values (background is taken as ``1 - mask``).
        forground: Number of foreground points requested per mask.
        background: Number of background points requested per mask.

    Returns:
        A tuple ``(points, labels)``:
        ``points`` has shape (N, forground + background, 2) and holds pixel
        coordinates in (x, y) = (column, row) order, foreground points first;
        ``labels`` has shape (N, forground + background) with 1.0 for
        foreground and 0.0 for background points. Labels are allocated on the
        default device, independent of ``masks.device``.

    Raises:
        ValueError: If background points are needed but a mask has no
            background pixel at all.
    """
    kernel_size = 9
    # Summing each 9x9 neighbourhood and floor-dividing by the window area
    # leaves 1 only where the whole window is foreground, i.e. a binary
    # erosion. Zero padding makes windows that overlap the border fail.
    window = torch.ones(
        1, 1, kernel_size, kernel_size, device=masks.device, dtype=masks.dtype
    )
    neighbourhood_sum = F.conv2d(masks, window, padding=kernel_size // 2)
    eroded_masks = neighbourhood_sum.squeeze(1) // (kernel_size**2)

    masks = masks.squeeze(1)
    # Per-mask switch: 1 -> sample from the eroded mask, 0 -> from the raw one.
    use_eroded = (
        eroded_masks.sum(dim=(1, 2), keepdim=True) >= forground
    ).to(masks.dtype)
    new_masks = eroded_masks * use_eroded + masks * (1 - use_eroded)

    all_points = []
    all_labels = []
    for sample_mask, full_mask in zip(new_masks, masks):
        fg_pixels = torch.nonzero(sample_mask, as_tuple=False)
        if len(fg_pixels) >= forground:
            chosen = np.random.choice(
                len(fg_pixels), size=forground, replace=False
            ).tolist()
            fg_pixels = fg_pixels[chosen]
        # Otherwise keep every available foreground pixel and top up the
        # background quota so the total number of points stays constant.
        n_fg = len(fg_pixels)
        n_bg = background + (forground - n_fg)

        bg_pixels = torch.nonzero(1 - full_mask, as_tuple=False)
        n_bg_pixels = len(bg_pixels)
        if n_bg > 0 and n_bg_pixels == 0:
            raise ValueError(
                "mask has no background pixels to sample prompts from"
            )
        # Sampling without replacement needs at least n_bg candidates; with
        # fewer, allow repeats so every mask still yields the same number of
        # points (required by torch.stack below).
        chosen = np.random.choice(
            n_bg_pixels, size=n_bg, replace=n_bg_pixels < n_bg
        ).tolist()
        bg_pixels = bg_pixels[chosen]

        # torch.nonzero yields (row, col); flip both groups to (x, y).
        points = torch.cat(
            [fg_pixels[:, [1, 0]], bg_pixels[:, [1, 0]]], dim=0
        )
        labels = torch.cat([torch.ones(n_fg), torch.zeros(n_bg)], dim=0)
        all_points.append(points)
        all_labels.append(labels)

    all_points = torch.stack(all_points, dim=0)
    all_labels = torch.stack(all_labels, dim=0)
    return all_points, all_labels
-
-
-
-
# Module-level device string. No code visible in this section reads it: both
# prompt functions take the device from their input tensors instead.
# NOTE(review): presumably kept for external importers — confirm before removing.
device = "cuda:0"
def create_prompt_main(probabilities):
    """Derive four point prompts from predicted mask logits.

    The logits are passed through a sigmoid and thresholded. Up to two
    foreground points are sampled from the 0.5-thresholded mask after a
    10x10 erosion followed by a 5x5 dilation. The remaining points (so that
    four are returned in total) are sampled from pixels that are both outside
    the 10x10-dilated 0.5 mask and below a probability of 0.4. Sampling is
    done with replacement via ``np.random.choice``.

    Args:
        probabilities: Logit tensor; assumed to be (1, 1, H, W) — the code
            squeezes the batch and channel dimensions before indexing, so a
            larger batch is not supported. TODO confirm against the caller.

    Returns:
        A tuple ``(points, labels)``: ``points`` has shape (1, 4, 2) and
        ``labels`` has shape (1, 4) with 1 for foreground and 0 for
        background, foreground points first.
        NOTE(review): points are in (row, col) order as produced by
        ``torch.nonzero``, whereas ``create_prompt_simple`` in this file
        returns (x, y); confirm which order the consumer expects.

    Raises:
        ValueError: If no pixel qualifies as background.
    """
    probabilities = probabilities.sigmoid()

    def threshold(tensor, thresh):
        """Binarise ``tensor``: 1.0 where it exceeds ``thresh``, else 0.0."""
        return (tensor > thresh).float()

    def morphological_op(tensor, operation, kernel_size):
        """Erode or dilate a binary (N, 1, H, W) map with a rectangular window."""
        kernel = torch.ones(
            1, 1, kernel_size[0], kernel_size[1], device=tensor.device
        )
        padding = [(k - 1) // 2 for k in kernel_size]
        even = kernel_size[0] % 2 == 0
        # An even window cannot be centred: with the symmetric padding above
        # the output would come out one pixel short. Two extra zero
        # rows/columns on the bottom/right make it one pixel too large
        # instead, and the surplus row/column is cropped off below.
        extra_pad = [0, 2, 0, 2] if even else [0, 0, 0, 0]
        padded = F.pad(tensor, extra_pad)

        if operation == 'erode':
            # A pixel survives only when every pixel under the window is set,
            # i.e. the window sum equals the window area. Pixels outside the
            # image count as background because of the zero padding.
            window_sum = F.conv2d(padded, kernel, padding=padding)
            result = (window_sum >= kernel.numel() - 0.5).to(tensor.dtype)
        elif operation == 'dilate':
            result = F.max_pool2d(
                padded, kernel_size, stride=1, padding=padding
            ).clamp(max=1)
        else:
            raise ValueError(f"unknown morphological operation: {operation!r}")

        if even:
            result = result[:, :, :result.shape[2] - 1, :result.shape[3] - 1]
        return result

    # Foreground candidates: interior of the confident mask.
    th_O = threshold(probabilities, 0.5)
    M_f = morphological_op(
        morphological_op(th_O, 'erode', (10, 10)), 'dilate', (5, 5)
    )
    foreground_indices = torch.nonzero(
        M_f.squeeze(1).squeeze(0), as_tuple=False
    )
    n_for = min(2, len(foreground_indices))
    n_back = 4 - n_for

    # Background candidates: away from the mask and with low probability.
    M_b1 = 1 - morphological_op(th_O, 'dilate', (10, 10))
    M_b2 = 1 - threshold(probabilities, 0.4)
    M_b = (M_b1 * M_b2).squeeze(1).squeeze(0)
    background_indices = torch.nonzero(M_b, as_tuple=False)
    if len(background_indices) == 0:
        raise ValueError("no background pixels available for prompt sampling")

    picked = []
    if n_for > 0:
        fg_choice = np.random.choice(len(foreground_indices), size=n_for)
        picked.append(foreground_indices[fg_choice])
    bg_choice = np.random.choice(len(background_indices), size=n_back)
    picked.append(background_indices[bg_choice])

    indices = torch.cat(picked)
    values = torch.tensor([1] * n_for + [0] * n_back)
    return indices.unsqueeze(0), values.unsqueeze(0)
-
-
|