Browse Source

Thesis

master
Amirhossein Bagheri 2 months ago
parent
commit
cf80b33684
72 changed files with 8539 additions and 0 deletions
  1. 170
    0
      Adaptive_Bias/dummyrun.py
  2. 180
    0
      Adaptive_Bias/resunet_focal.py
  3. 0
    0
      CCT/__init__.py
  4. 149
    0
      CCT/cct.py
  5. 258
    0
      CCT/decoders.py
  6. 32
    0
      CCT/ramps.py
  7. 412
    0
      CCT/utils.py
  8. 18
    0
      README.md
  9. 15
    0
      ReadMe.txt
  10. 50
    0
      advent/advent.py
  11. 246
    0
      advent/advent_ent_minimization_ssl.py
  12. 77
    0
      advent/deeplabv2.py
  13. 76
    0
      advent/deeplabv2_activations.py
  14. 82
    0
      advent/deeplabv2_dropout.py
  15. 99
    0
      advent/entropy_minimization.py
  16. 93
    0
      advent/entropy_minimization_1.py
  17. 161
    0
      advent/entropy_minimization_semi_supervised.py
  18. 108
    0
      advent/entropy_minimization_tumor_augmentation.py
  19. 52
    0
      advent/losses.py
  20. 47
    0
      advent/msc.py
  21. 47
    0
      advent/msc_confident.py
  22. 148
    0
      advent/resnet.py
  23. 79
    0
      base_line/ConvUnext/convUnext_focal.py
  24. 84
    0
      base_line/Convnext/convnext_focal.py
  25. 80
    0
      base_line/Deepvlab3/deepvlab3_focal.py
  26. 0
    0
      base_line/Resunet/Heavy/Models/__init__.py
  27. 180
    0
      base_line/Resunet/Heavy/Models/resunet.py
  28. 0
    0
      base_line/Resunet/Heavy/__init__.py
  29. 50
    0
      base_line/Resunet/Heavy/resunet_ddsm.py
  30. 79
    0
      base_line/Resunet/Heavy/resunet_focal.py
  31. 107
    0
      base_line/Resunet/Heavy/resunet_focal_ent.py
  32. 96
    0
      base_line/Resunet/Heavy/resunet_topk_focal.py
  33. 0
    0
      base_line/Resunet/__init__.py
  34. 80
    0
      base_line/Resunet/resunet_ce.py
  35. 80
    0
      base_line/Resunet/resunet_ce_mean.py
  36. 80
    0
      base_line/Resunet/resunet_dice.py
  37. 79
    0
      base_line/Resunet/resunet_focal.py
  38. 107
    0
      base_line/Resunet/resunet_focal_ent.py
  39. 94
    0
      base_line/Resunet/resunet_tilted.py
  40. 85
    0
      base_line/Resunet/resunet_topk.py
  41. 87
    0
      base_line/Resunet/resunet_topk_focal.py
  42. 85
    0
      base_line/Resunet/resunet_tvfocal.py
  43. 174
    0
      base_line/UNET/Model/unet.py
  44. 79
    0
      base_line/UNET/resunet_focal.py
  45. 156
    0
      base_line/res_inception.py
  46. 80
    0
      base_line/resunet.py
  47. 0
    0
      base_line/transunet.py
  48. 396
    0
      segmentation/Baseline/TransUNet/transunet_focal.py
  49. 401
    0
      segmentation/Baseline/TransUNet/transunet_topk.py
  50. 61
    0
      segmentation/Baseline/UNet/unet.py
  51. 169
    0
      segmentation/Baseline/UNet/unet_model.py
  52. 101
    0
      segmentation/CPS/cps_resunet.py
  53. 36
    0
      segmentation/CPS/utils.py
  54. 97
    0
      segmentation/ICT/ict.py
  55. 132
    0
      segmentation/ICT/ict_resunet.py
  56. 70
    0
      segmentation/ICT/loss.py
  57. 157
    0
      segmentation/ICT/unet.py
  58. 154
    0
      segmentation/URPC/loss.py
  59. 351
    0
      segmentation/URPC/resunet.py
  60. 80
    0
      segmentation/URPC/urpc.py
  61. 92
    0
      segmentation/URPC/urpc_resunet.py
  62. 153
    0
      segmentation/URPC/utils.py
  63. 111
    0
      segmentation/UniMatch/Models/deeplabv3plus.py
  64. 154
    0
      segmentation/UniMatch/Models/resnet.py
  65. 228
    0
      segmentation/UniMatch/Models/resunet.py
  66. 174
    0
      segmentation/UniMatch/Models/unet.py
  67. 147
    0
      segmentation/UniMatch/fixmatch.py
  68. 153
    0
      segmentation/UniMatch/fixmatch_resunet.py
  69. 40
    0
      segmentation/UniMatch/transform.py
  70. 160
    0
      segmentation/UniMatch/unimatch.py
  71. 187
    0
      segmentation/UniMatch/unimatch_resunet.py
  72. 194
    0
      uasmt/UASMT.py

+ 170
- 0
Adaptive_Bias/dummyrun.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.generalized_dice import dice_loss
from ....utils.losses.tverskyLoss import tversky_loss
from ..base_line.Resunet.Heavy.Models.resunet import Encoder
from ..base_line.Resunet.Heavy.Models.resunet import Decoder


class BIAS_MLP(Model):
def __init__(self) -> None:
super(BIAS_MLP, self).__init__()
self.MLP = nn.ModuleList([
nn.Conv2d(2,16,1),
nn.ReLU(),
nn.Conv2d(16,8,1),
nn.ReLU(),
nn.Conv2d(8,2,1),
])
def forward(self,x):
out = x
for module in self.MLP:
out = module(out)
return out



class RESUNET(Model):
def __init__(self,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3, params = None) -> None:
super(RESUNET, self).__init__()
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")
self.MLP = BIAS_MLP()
self.encoder = Encoder(params)
self.decoder = Decoder(params)
self.epochNUM = 0
self.forward_call = 0
self.mode = True

def train(self,mode : bool=True):
super().train(mode)
self.epochNUM += 1
self.mode = mode
if mode:
self.forward_call = 0

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
if self.mode:
if (not self.mode) or (int(self.forward_call) % 50 == 0) and(self.forward_call > 10):
print("changing")
print(list(self.encoder.parameters())[0][0])
network_output = self.decoder(self.encoder(mammo_x))
network_output_soft = network_output.softmax(dim=1) #torch.exp(network_output).to(mammo_x.device) # softmax output of model

bias_output = self.MLP(network_output) + network_output
bias_output = bias_output.softmax(dim=1)

output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)


dummy_mask_l = mask_channels.clone().to(mammo_x.device)
dummy_net_out_l = bias_output.clone().to(mammo_x.device)
dummy_mask_l[loss_channels == 0] = 0.0
dummy_shape_l = dummy_mask_l.shape
dummy_mask_l = torch.cat((dummy_mask_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1)
dummy_net_out_l = torch.cat((dummy_net_out_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1)
focal_loss_l = balanced_focal_cross_entropy_loss_semi(
dummy_net_out_l,
dummy_mask_l,
focal_gamma=self.gamma)
with torch.no_grad():
dummy_mask_u = (bias_output >= 0.5).clone().to(mammo_x.device) # TODO sanity check
dummy_net_out_u = bias_output.clone().to(mammo_x.device)
dummy_mask_u[loss_channels == 1.0] = 0.0
dummy_shape_u = dummy_mask_u.shape
dummy_mask_u = torch.cat((dummy_mask_u,torch.ones((dummy_shape_u[0],1,dummy_shape_u[2],dummy_shape_u[2])).to(mammo_x.device)), 1)
dummy_net_out_u = torch.cat((dummy_net_out_u,torch.ones((dummy_shape_u[0],1,dummy_shape_u[2],dummy_shape_u[2])).to(mammo_x.device)), 1)
focal_loss_u = balanced_focal_cross_entropy_loss_semi(
dummy_net_out_u,
dummy_mask_u,
focal_gamma=self.gamma)
# now two loss for eq 3 are here focal_loss_u & focal_loss_l

# choose balance set
output['loss'] = focal_loss_u + focal_loss_l
output['MYloss'] = output["loss"]
output['focalloss_l'] = focal_loss_l
output['focalloss_u'] = focal_loss_u
output['suploss'] = focal_loss_l
output['unsuploss'] = focal_loss_u
output['org_pixel_labels'] = mammo_loss_and_gt
else:
with torch.no_grad(): # TODO check sanity
#self.encoder.
print("torch no grad")
features = self.encoder(mammo_x)
print(list(self.encoder.parameters())[0][0])
network_output = self.decoder(features)
network_output_soft = network_output.softmax(dim=1)
output['pixel_probs'] = network_output_soft
main_shape = mask.shape
network_output_shape = network_output.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)


dummy_mask_l = mask_channels.clone().to(mammo_x.device)
dummy_net_out_l = network_output_soft.clone().to(mammo_x.device)
balance_mask = loss_channels.clone().to(mammo_x.device)
#balance_mask[loss_channels == 0] = 2 # problem here sometimes all of it are supeprvised so we apply it in all of image.
main_shape = balance_mask.shape
number_1 = torch.sum(balance_mask == 1).to(mammo_x.device)
number_0 = torch.sum(balance_mask == 0).to(mammo_x.device)
random_matrix = torch.rand(main_shape).to(mammo_x.device)
#print(number_0, number_1)
if number_1 > number_0: #choosing balanced version
balance_mask[(balance_mask == 1) & (random_matrix >= number_0/number_1)] = 2
else:
balance_mask[(balance_mask == 0) & (random_matrix >= number_1/number_0)] = 2

dummy_mask_l[balance_mask == 2] = 0.0
dummy_shape_l = dummy_mask_l.shape
dummy_mask_l = torch.cat((dummy_mask_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1)
dummy_net_out_l = torch.cat((dummy_net_out_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1)
focal_loss_l = balanced_focal_cross_entropy_loss_semi(
dummy_net_out_l,
dummy_mask_l,
focal_gamma=self.gamma)
output['loss'] = focal_loss_l
output['MYloss'] = output["loss"]
output['focalloss_l'] = focal_loss_l
output['focalloss_u'] = focal_loss_l * 0
output['suploss'] = focal_loss_l
output['unsuploss'] = focal_loss_l * 0
output['org_pixel_labels'] = mammo_loss_and_gt



self.forward_call += 1
return output

+ 180
- 0
Adaptive_Bias/resunet_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.generalized_dice import dice_loss
from ....utils.losses.tverskyLoss import tversky_loss
from ....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import Encoder
from ....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import Decoder


class BIAS_MLP(Model):
def __init__(self) -> None:
super(BIAS_MLP, self).__init__()
self.MLP = nn.ModuleList([
nn.Conv2d(2,16,1),
nn.ReLU(),
nn.Conv2d(16,8,1),
nn.ReLU(),
nn.Conv2d(8,2,1),
])
def forward(self,x):
out = x
for module in self.MLP:
out = module(out)
return out



class RESUNET(Model):
def __init__(self,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3, params = None) -> None:
super(RESUNET, self).__init__()
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")
self.MLP = BIAS_MLP()
self.encoder = Encoder(params)
self.decoder = Decoder(params)
self.epochNUM = 0
self.forward_call = 0
self.mode = True

def train(self,mode : bool=True):
super().train(mode)
self.epochNUM += 1
self.mode = mode
if mode:
self.forward_call = 0

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
if (self.mode) and int(self.forward_call) % 2 == 0:
# print("second")
# print(list(self.encoder.parameters())[0][0])
network_output = self.decoder(self.encoder(mammo_x))
network_output_soft = network_output.softmax(dim=1) #torch.exp(network_output).to(mammo_x.device) # softmax output of model

bias_output = self.MLP(network_output) + network_output
bias_output = bias_output.softmax(dim=1)

output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)


dummy_mask_l = mask_channels.clone().to(mammo_x.device)
dummy_net_out_l = bias_output.clone().to(mammo_x.device)
dummy_mask_l[loss_channels == 0] = 0.0
dummy_shape_l = dummy_mask_l.shape
dummy_mask_l = torch.cat((dummy_mask_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1)
dummy_net_out_l = torch.cat((dummy_net_out_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1)
focal_loss_l = balanced_focal_cross_entropy_loss_semi(
dummy_net_out_l,
dummy_mask_l,
focal_gamma=self.gamma)
with torch.no_grad():
dummy_mask_u = (bias_output >= 0.5).clone().to(mammo_x.device) # TODO sanity check
dummy_net_out_u = bias_output.clone().to(mammo_x.device)
dummy_mask_u[loss_channels == 1.0] = 0.0
dummy_shape_u = dummy_mask_u.shape
dummy_mask_u = torch.cat((dummy_mask_u,torch.ones((dummy_shape_u[0],1,dummy_shape_u[2],dummy_shape_u[2])).to(mammo_x.device)), 1)
dummy_net_out_u = torch.cat((dummy_net_out_u,torch.ones((dummy_shape_u[0],1,dummy_shape_u[2],dummy_shape_u[2])).to(mammo_x.device)), 1)
focal_loss_u = balanced_focal_cross_entropy_loss_semi(
dummy_net_out_u,
dummy_mask_u,
focal_gamma=self.gamma)
# now two loss for eq 3 are here focal_loss_u & focal_loss_l

# choose balance set
output['loss'] = focal_loss_u + focal_loss_l
output['MYloss'] = output["loss"]
output['focalloss_l'] = focal_loss_l
output['focalloss_u'] = focal_loss_u
output['suploss'] = focal_loss_l
output['unsuploss'] = focal_loss_u
output['org_pixel_labels'] = mammo_loss_and_gt
elif self.mode:
with torch.no_grad(): # TODO check sanity
#self.encoder.
# print("first")
features = self.encoder(mammo_x)
# print(list(self.encoder.parameters())[0][0])
network_output = self.decoder(features)
network_output_soft = network_output.softmax(dim=1)
output['pixel_probs'] = network_output_soft
main_shape = mask.shape
network_output_shape = network_output.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)


dummy_mask_l = mask_channels.clone().to(mammo_x.device)
dummy_net_out_l = network_output_soft.clone().to(mammo_x.device)
balance_mask = loss_channels.clone().to(mammo_x.device)
#balance_mask[loss_channels == 0] = 2 # problem here sometimes all of it are supeprvised so we apply it in all of image.
main_shape = balance_mask.shape
number_1 = torch.sum(balance_mask == 1).to(mammo_x.device)
number_0 = torch.sum(balance_mask == 0).to(mammo_x.device)
random_matrix = torch.rand(main_shape).to(mammo_x.device)
print(number_0, number_1)
if number_1 > number_0: #choosing balanced version
balance_mask[(balance_mask == 1) & (random_matrix >= number_0/number_1)] = 2
else:
balance_mask[(balance_mask == 0) & (random_matrix >= number_1/number_0)] = 2

dummy_mask_l[balance_mask == 2] = 0.0
dummy_shape_l = dummy_mask_l.shape
dummy_mask_l = torch.cat((dummy_mask_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1)
dummy_net_out_l = torch.cat((dummy_net_out_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1)
focal_loss_l = balanced_focal_cross_entropy_loss_semi(
dummy_net_out_l,
dummy_mask_l,
focal_gamma=self.gamma)
output['loss'] = focal_loss_l
output['MYloss'] = output["loss"]
output['focalloss_l'] = focal_loss_l
output['focalloss_u'] = focal_loss_l * 0
output['suploss'] = focal_loss_l
output['unsuploss'] = focal_loss_l * 0
output['org_pixel_labels'] = mammo_loss_and_gt
else:
network_output = self.decoder(self.encoder(mammo_x))
network_output_soft = network_output.softmax(dim=1) #torch.exp(network_output).to(mammo_x.device) # softmax output of mode
output['pixel_probs'] = network_output_soft
output['loss'] = torch.tensor(0)
output['MYloss'] = torch.tensor(0)
output['focalloss_l'] = torch.tensor(0)
output['focalloss_u'] = torch.tensor(0)
output['suploss'] = torch.tensor(0)
output['unsuploss'] = torch.tensor(0)
output['org_pixel_labels'] = mammo_loss_and_gt


self.forward_call += 1
return output

+ 0
- 0
CCT/__init__.py View File


+ 149
- 0
CCT/cct.py View File

from copy import deepcopy
import math
from random import uniform
import mlassistant
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.generalized_dice import dice_loss
from ....utils.losses.tverskyLoss import tversky_loss
from ....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import Encoder
from ....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import Decoder
from .utils import *
from .decoders import *

class CCT(Model):
def __init__(self,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3, params = None , Epochs = 200, iter = 200) -> None:
super(CCT, self).__init__()
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")
self.encoder = Encoder(params)
self.main_decoder = Decoder(params)
self.conf = Config["model"]
self.conft = Config

rampup_ends = int(self.conft['ramp_up'] * Epochs)
cons_w_unsup = consistency_weight(final_w=self.conft['unsupervised_w'], iters_per_epoch=iter,
rampup_ends=rampup_ends)

self.unsup_loss_w = cons_w_unsup

self.unsuper_loss = softmax_mse_loss
self.sup_loss_w = self.conf['supervised_w']
self.softmax_temp = self.conf['softmax_temp']
self.sup_type = self.conf['sup_loss']

# Use weak labels
self.use_weak_lables = self.conft['use_weak_lables']
self.weakly_loss_w = self.conft['weakly_loss_w']
# pair wise loss (sup mat)
self.aux_constraint = self.conf['aux_constraint']
self.aux_constraint_w = self.conf['aux_constraint_w']
# confidence masking (sup mat)
self.confidence_th = self.conf['confidence_th']
self.confidence_masking = self.conf['confidence_masking']

upscale = 8
num_out_ch = 512
decoder_in_ch = num_out_ch
num_classes = 2

# The auxilary decoders
vat_decoder = [VATDecoder(upscale, decoder_in_ch, num_classes, xi=self.conf['xi'],
eps=self.conf['eps']) for _ in range(self.conf['vat'])]
drop_decoder = [DropOutDecoder(upscale, decoder_in_ch, num_classes,
drop_rate=self.conf['drop_rate'], spatial_dropout=self.conf['spatial'])
for _ in range(self.conf['drop'])]
cut_decoder = [CutOutDecoder(upscale, decoder_in_ch, num_classes, erase=self.conf['erase'])
for _ in range(self.conf['cutout'])]
context_m_decoder = [ContextMaskingDecoder(upscale, decoder_in_ch, num_classes)
for _ in range(self.conf['context_masking'])]
object_masking = [ObjectMaskingDecoder(upscale, decoder_in_ch, num_classes)
for _ in range(self.conf['object_masking'])]
feature_drop = [FeatureDropDecoder(upscale, decoder_in_ch, num_classes)
for _ in range(self.conf['feature_drop'])]
feature_noise = [FeatureNoiseDecoder(upscale, decoder_in_ch, num_classes,
uniform_range=self.conf['uniform_range'])
for _ in range(self.conf['feature_noise'])]

# self.aux_decoders = nn.ModuleList([*vat_decoder, *drop_decoder, *cut_decoder,
# *context_m_decoder, *object_masking, *feature_drop, *feature_noise])
self.aux_decoders = nn.ModuleList([*drop_decoder,*vat_decoder,*feature_noise,
*context_m_decoder, *object_masking, *feature_drop])
self.iter = 0
self.mode = True


def train(self,mode : bool=True):
super().train(mode)
self.mode = mode
self.iter = 0
print("epoch:",mlassistant.context.epoch.EpochContext.epoch)
def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
self.iter = 1
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
curr_iter= self.iter
epoch= mlassistant.context.epoch.EpochContext.epoch
network_output = self.main_decoder(self.encoder(mammo_x))
network_output_soft = network_output.softmax(dim=1)
output['pixel_probs'] = network_output_soft
output_ul = F.interpolate(network_output_soft, size=(256,256), mode='bilinear')


# Get auxiliary predictions
outputs_ul = [aux_decoder(self.encoder(mammo_x)[-1], output_ul.detach()) for aux_decoder in self.aux_decoders]
targets = F.softmax(output_ul.detach(), dim=1)

# Compute unsupervised loss
loss_unsup = sum([self.unsuper_loss(inputs=u, targets=targets, \
conf_mask=self.confidence_masking, threshold=self.confidence_th, use_softmax=False)
for u in outputs_ul])
main_shape = mask.shape
network_output_shape = network_output.shape
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)
loss_unsup = (loss_unsup / len(outputs_ul))

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)

loss_sup = focal_loss
weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)


output['loss'] = loss_sup * self.sup_loss_w + loss_unsup * weight_u
output['MYloss'] = output["loss"]
output['focalloss'] = focal_loss
output['suploss'] = loss_sup
output['unsuploss'] = loss_unsup
output['sup_loss_wloss'] = torch.tensor(self.sup_loss_w)
output['lossweight_u'] = torch.tensor(weight_u)
output['org_pixel_labels'] = mammo_loss_and_gt

self.iter += 1
return output

+ 258
- 0
CCT/decoders.py View File

import math , time
import torch
import torch.nn.functional as F
from torch import nn
from itertools import chain
import contextlib
import random
import numpy as np
import cv2
from torch.distributions.uniform import Uniform


def icnr(x, scale=2, init=nn.init.kaiming_normal_):
"""
Checkerboard artifact free sub-pixel convolution
https://arxiv.org/abs/1707.02937
"""
ni,nf,h,w = x.shape
ni2 = int(ni/(scale**2))
k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
k = k.contiguous().view(ni2, nf, -1)
k = k.repeat(1, 1, scale**2)
k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
x.data.copy_(k)


class PixelShuffle(nn.Module):
"""
Real-Time Single Image and Video Super-Resolution
https://arxiv.org/abs/1609.05158
"""
def __init__(self, n_channels, scale):
super(PixelShuffle, self).__init__()
self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1)
icnr(self.conv.weight)
self.shuf = nn.PixelShuffle(scale)
self.relu = nn.ReLU(inplace=True)

def forward(self,x):
x = self.shuf(self.relu(self.conv(x)))
return x


def upsample(in_channels, out_channels, upscale, kernel_size=3):
# A series of x 2 upsamling until we get to the upscale we want
layers = []
conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu')
layers.append(conv1x1)
for i in range(int(math.log(upscale, 2))):
layers.append(PixelShuffle(out_channels, scale=2))
return nn.Sequential(*layers)


class MainDecoder(nn.Module):
def __init__(self, upscale, conv_in_ch, num_classes):
super(MainDecoder, self).__init__()
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)

def forward(self, x):
x = self.upsample(x)
return x


class DropOutDecoder(nn.Module):
def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True):
super(DropOutDecoder, self).__init__()
self.dropout = nn.Dropout2d(p=drop_rate) if spatial_dropout else nn.Dropout(drop_rate)
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)

def forward(self, x, _):
x = self.upsample(self.dropout(x))
return x


class FeatureDropDecoder(nn.Module):
def __init__(self, upscale, conv_in_ch, num_classes):
super(FeatureDropDecoder, self).__init__()
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)

def feature_dropout(self, x):
attention = torch.mean(x, dim=1, keepdim=True)
max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)
threshold = max_val * np.random.uniform(0.7, 0.9)
threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
drop_mask = (attention < threshold).float()
return x.mul(drop_mask)

def forward(self, x, _):
x = self.feature_dropout(x)
x = self.upsample(x)
return x


class FeatureNoiseDecoder(nn.Module):
def __init__(self, upscale, conv_in_ch, num_classes, uniform_range=0.3):
super(FeatureNoiseDecoder, self).__init__()
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
self.uni_dist = Uniform(-uniform_range, uniform_range)

def feature_based_noise(self, x):
noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)
x_noise = x.mul(noise_vector) + x
return x_noise

def forward(self, x, _):
x = self.feature_based_noise(x)
x = self.upsample(x)
return x



def _l2_normalize(d):
# Normalizing per batch axis
d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
return d


def get_r_adv(x, decoder, it=1, xi=1e-1, eps=10.0):
"""
Virtual Adversarial Training
https://arxiv.org/abs/1704.03976
"""
x_detached = x.detach()
with torch.no_grad():
pred = F.softmax(decoder(x_detached), dim=1)

d = torch.rand(x.shape).sub(0.5).to(x.device)
d = _l2_normalize(d)

for _ in range(it):
d.requires_grad_()
pred_hat = decoder(x_detached + xi * d)
logp_hat = F.log_softmax(pred_hat, dim=1)
adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
adv_distance.backward()
d = _l2_normalize(d.grad)
decoder.zero_grad()

r_adv = d * eps
return r_adv


class VATDecoder(nn.Module):
def __init__(self, upscale, conv_in_ch, num_classes, xi=1e-1, eps=10.0, iterations=1):
super(VATDecoder, self).__init__()
self.xi = xi
self.eps = eps
self.it = iterations
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)

def forward(self, x, _):
r_adv = get_r_adv(x, self.upsample, self.it, self.xi, self.eps)
x = self.upsample(x + r_adv)
return x



def guided_cutout(output, upscale, resize, erase=0.4, use_dropout=False):
if len(output.shape) == 3:
masks = (output > 0).float()
else:
masks = (output.argmax(1) > 0).float()

if use_dropout:
p_drop = random.randint(3, 6)/10
maskdroped = (F.dropout(masks, p_drop) > 0).float()
maskdroped = maskdroped + (1 - masks)
maskdroped.unsqueeze_(0)
maskdroped = F.interpolate(maskdroped, size=resize, mode='nearest')

masks_np = []
for mask in masks:
mask_np = np.uint8(mask.cpu().numpy())
mask_ones = np.ones_like(mask_np)
try: # Version 3.x
_, contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
except: # Version 4.x
contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

polys = [c.reshape(c.shape[0], c.shape[-1]) for c in contours if c.shape[0] > 50]
for poly in polys:
min_w, max_w = poly[:, 0].min(), poly[:, 0].max()
min_h, max_h = poly[:, 1].min(), poly[:, 1].max()
bb_w, bb_h = max_w-min_w, max_h-min_h
rnd_start_w = random.randint(0, int(bb_w*(1-erase)))
rnd_start_h = random.randint(0, int(bb_h*(1-erase)))
h_start, h_end = min_h+rnd_start_h, min_h+rnd_start_h+int(bb_h*erase)
w_start, w_end = min_w+rnd_start_w, min_w+rnd_start_w+int(bb_w*erase)
mask_ones[h_start:h_end, w_start:w_end] = 0
masks_np.append(mask_ones)
masks_np = np.stack(masks_np)

maskcut = torch.from_numpy(masks_np).float().unsqueeze_(1)
maskcut = F.interpolate(maskcut, size=resize, mode='nearest')

if use_dropout:
return maskcut.to(output.device), maskdroped.to(output.device)
return maskcut.to(output.device)


class CutOutDecoder(nn.Module):
def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True, erase=0.4):
super(CutOutDecoder, self).__init__()
self.erase = erase
self.upscale = upscale
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)

def forward(self, x, pred=None):
maskcut = guided_cutout(pred, upscale=self.upscale, erase=self.erase, resize=(x.size(2), x.size(3)))
x = x * maskcut
x = self.upsample(x)
return x


def guided_masking(x, output, upscale, resize, return_msk_context=True):
if len(output.shape) == 3:
masks_context = (output > 0).float().unsqueeze(1)
else:
masks_context = (output.argmax(1) > 0).float().unsqueeze(1)
masks_context = F.interpolate(masks_context, size=resize, mode='nearest')

x_masked_context = masks_context * x
if return_msk_context:
return x_masked_context

masks_objects = (1 - masks_context)
x_masked_objects = masks_objects * x
return x_masked_objects


class ContextMaskingDecoder(nn.Module):
def __init__(self, upscale, conv_in_ch, num_classes):
super(ContextMaskingDecoder, self).__init__()
self.upscale = upscale
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)

def forward(self, x, pred=None):
x_masked_context = guided_masking(x, pred, resize=(x.size(2), x.size(3)),
upscale=self.upscale, return_msk_context=True)
x_masked_context = self.upsample(x_masked_context)
return x_masked_context


class ObjectMaskingDecoder(nn.Module):
def __init__(self, upscale, conv_in_ch, num_classes):
super(ObjectMaskingDecoder, self).__init__()
self.upscale = upscale
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)

def forward(self, x, pred=None):
x_masked_obj = guided_masking(x, pred, resize=(x.size(2), x.size(3)),
upscale=self.upscale, return_msk_context=False)
x_masked_obj = self.upsample(x_masked_obj)

return x_masked_obj

+ 32
- 0
CCT/ramps.py View File

import numpy as np

def sigmoid_rampup(current, rampup_length):
if rampup_length == 0:
return 1.0
current = np.clip(current, 0.0, rampup_length)
phase = 1.0 - current / rampup_length
return float(np.exp(-5.0 * phase * phase))

def linear_rampup(current, rampup_length):
assert current >= 0 and rampup_length >= 0
if current >= rampup_length:
return 1.0
return current / rampup_length

def cosine_rampup(current, rampup_length):
if rampup_length == 0:
return 1.0
current = np.clip(current, 0.0, rampup_length)
return 1 - float(.5 * (np.cos(np.pi * current / rampup_length) + 1))

def log_rampup(current, rampup_length):
if rampup_length == 0:
return 1.0
current = np.clip(current, 0.0, rampup_length)
return float(1- np.exp(-5.0 * current / rampup_length))
def exp_rampup(current, rampup_length):
if rampup_length == 0:
return 1.0
current = np.clip(current, 0.0, rampup_length)
return float(np.exp(5.0 * (current / rampup_length - 1)))

+ 412
- 0
CCT/utils.py View File

Config = {
"name": "CCT",
"experim_name": "CCT",
"n_gpu": 1,
"n_labeled_examples": 1464,
"diff_lrs": True,
"ramp_up": 0.1,
"unsupervised_w": 30,
"ignore_index": 255,
"lr_scheduler": "Poly",
"use_weak_lables":False,
"weakly_loss_w": 0.4,
"pretrained": True,

"model":{
"supervised": False,
"semi": True,
"supervised_w": 1,

"sup_loss": "CE",
"un_loss": "MSE",

"softmax_temp": 1,
"aux_constraint": False,
"aux_constraint_w": 1,
"confidence_masking": False,
"confidence_th": 0.5,

"drop": 6,
"drop_rate": 0.5,
"spatial": True,
"cutout": 6,
"erase": 0.4,
"vat": 2,
"xi": 1e-6,
"eps": 2.0,

"context_masking": 2,
"object_masking": 2,
"feature_drop": 6,

"feature_noise": 6,
"uniform_range": 0.3
},


"optimizer": {
"type": "SGD",
"args":{
"lr": 1e-2,
"weight_decay": 1e-4,
"momentum": 0.9
}
},


# "train_supervised": {
# "data_dir": "VOCtrainval_11-May-2012",
# "batch_size": 10,
# "crop_size": 320,
# "shuffle": True,
# "base_size": 400,
# "scale": True,
# "augment": True,
# "flip": True,
# "rotate": False,
# "blur": False,
# "split": "train_supervised",
# "num_workers": 8
# },

# "train_unsupervised": {
# "data_dir": "VOCtrainval_11-May-2012",
# "weak_labels_output": "pseudo_labels/result/pseudo_labels",
# "batch_size": 10,
# "crop_size": 320,
# "shuffle": True,
# "base_size": 400,
# "scale": True,
# "augment": True,
# "flip": True,
# "rotate": False,
# "blur": False,
# "split": "train_unsupervised",
# "num_workers": 8
# },

# "val_loader": {
# "data_dir": "VOCtrainval_11-May-2012",
# "batch_size": 1,
# "val": True,
# "split": "val",
# "shuffle": False,
# "num_workers": 4
# },

# "trainer": {
# "epochs": 80,
# "save_dir": "saved/",
# "save_period": 5,
# "monitor": "max Mean_IoU",
# "early_stop": 10,
# "tensorboardX": True,
# "log_dir": "saved/",
# "log_per_iter": 20,

# "val": True,
# "val_per_epochs": 5
# }
}
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn


def sigmoid_rampup(current, rampup_length):
if rampup_length == 0:
return 1.0
current = np.clip(current, 0.0, rampup_length)
phase = 1.0 - current / rampup_length
return float(np.exp(-5.0 * phase * phase))

class consistency_weight(object):
"""
ramp_types = ['sigmoid_rampup', 'linear_rampup', 'cosine_rampup', 'log_rampup', 'exp_rampup']
"""
def __init__(self, final_w, iters_per_epoch, rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup'):
self.final_w = final_w
self.iters_per_epoch = iters_per_epoch
self.rampup_starts = rampup_starts * iters_per_epoch
self.rampup_ends = rampup_ends * iters_per_epoch
self.rampup_length = (self.rampup_ends - self.rampup_starts)
self.rampup_func = sigmoid_rampup
self.current_rampup = 0

def __call__(self, epoch, curr_iter):
cur_total_iter = self.iters_per_epoch * epoch + curr_iter
if cur_total_iter < self.rampup_starts:
return 0
self.current_rampup = self.rampup_func(cur_total_iter - self.rampup_starts, self.rampup_length)
return self.final_w * self.current_rampup


def CE_loss(input_logits, target_targets, ignore_index, temperature=1):
return F.cross_entropy(input_logits/temperature, target_targets, ignore_index=ignore_index)

# for FocalLoss
def softmax_helper(x):
# copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
rpt = [1 for _ in range(len(x.size()))]
rpt[1] = x.size(1)
x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
e_x = torch.exp(x - x_max)
return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)

def get_alpha(supervised_loader):
# get number of classes
num_labels = 0
for image_batch, label_batch in supervised_loader:
label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
l_unique = torch.unique(label_batch.data)
list_unique = [element.item() for element in l_unique.flatten()]
num_labels = max(max(list_unique),num_labels)
num_classes = num_labels + 1
# count class occurrences
alpha = [0 for i in range(num_classes)]
for image_batch, label_batch in supervised_loader:
label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
l_unique = torch.unique(label_batch.data)
list_unique = [element.item() for element in l_unique.flatten()]
l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480])
list_count = [count.item() for count in l_unique_count.flatten()]
for index in list_unique:
alpha[index] += list_count[list_unique.index(index)]
return alpha

# for FocalLoss
def softmax_helper(x):
# copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
rpt = [1 for _ in range(len(x.size()))]
rpt[1] = x.size(1)
x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
e_x = torch.exp(x - x_max)
return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)


class FocalLoss(nn.Module):
"""
copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
Focal_Loss= -1*alpha*(1-pt)*log(pt)
:param num_class:
:param alpha: (tensor) 3D or 4D the scalar factor for this criterion
:param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
focus on hard misclassified example
:param smooth: (float,double) smooth value when cross entropy
:param balance_index: (int) balance class index, should be specific when alpha is float
:param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
"""

def __init__(self, apply_nonlin=None, ignore_index = None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
super(FocalLoss, self).__init__()
self.apply_nonlin = apply_nonlin
self.alpha = alpha
self.gamma = gamma
self.balance_index = balance_index
self.smooth = smooth
self.size_average = size_average

if self.smooth is not None:
if self.smooth < 0 or self.smooth > 1.0:
raise ValueError('smooth value should be in [0,1]')

def forward(self, logit, target):
if self.apply_nonlin is not None:
logit = self.apply_nonlin(logit)
num_class = logit.shape[1]

if logit.dim() > 2:
# N,C,d1,d2 -> N,C,m (m=d1*d2*...)
logit = logit.view(logit.size(0), logit.size(1), -1)
logit = logit.permute(0, 2, 1).contiguous()
logit = logit.view(-1, logit.size(-1))
target = torch.squeeze(target, 1)
target = target.view(-1, 1)
valid_mask = None
if self.ignore_index is not None:
valid_mask = target != self.ignore_index
target = target * valid_mask
alpha = self.alpha

if alpha is None:
alpha = torch.ones(num_class, 1)
elif isinstance(alpha, (list, np.ndarray)):
assert len(alpha) == num_class
alpha = torch.FloatTensor(alpha).view(num_class, 1)
alpha = alpha / alpha.sum()
alpha = 1/alpha # inverse of class frequency
elif isinstance(alpha, float):
alpha = torch.ones(num_class, 1)
alpha = alpha * (1 - self.alpha)
alpha[self.balance_index] = self.alpha

else:
raise TypeError('Not support alpha type')
if alpha.device != logit.device:
alpha = alpha.to(logit.device)

idx = target.cpu().long()

one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
# to resolve error in idx in scatter_
idx[idx==225]=0
one_hot_key = one_hot_key.scatter_(1, idx, 1)
if one_hot_key.device != logit.device:
one_hot_key = one_hot_key.to(logit.device)

if self.smooth:
one_hot_key = torch.clamp(
one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
pt = (one_hot_key * logit).sum(1) + self.smooth
logpt = pt.log()

gamma = self.gamma

alpha = alpha[idx]
alpha = torch.squeeze(alpha)
loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
if valid_mask is not None:
loss = loss * valid_mask.squeeze()
if self.size_average:
loss = loss.mean()
else:
loss = loss.sum()
return loss


class abCE_loss(nn.Module):
"""
Annealed-Bootstrapped cross-entropy loss
"""
def __init__(self, iters_per_epoch, epochs, num_classes, weight=None,
reduction='mean', thresh=0.7, min_kept=1, ramp_type='log_rampup'):
super(abCE_loss, self).__init__()
self.weight = torch.FloatTensor(weight) if weight is not None else weight
self.reduction = reduction
self.thresh = thresh
self.min_kept = min_kept
self.ramp_type = ramp_type
if ramp_type is not None:
self.rampup_func = getattr(ramps, ramp_type)
self.iters_per_epoch = iters_per_epoch
self.num_classes = num_classes
self.start = 1/num_classes
self.end = 0.9
self.total_num_iters = (epochs - (0.6 * epochs)) * iters_per_epoch

def threshold(self, curr_iter, epoch):
cur_total_iter = self.iters_per_epoch * epoch + curr_iter
current_rampup = self.rampup_func(cur_total_iter, self.total_num_iters)
return current_rampup * (self.end - self.start) + self.start

def forward(self, predict, target, ignore_index, curr_iter, epoch):
batch_kept = self.min_kept * target.size(0)
prob_out = F.softmax(predict, dim=1)
tmp_target = target.clone()
tmp_target[tmp_target == ignore_index] = 0
prob = prob_out.gather(1, tmp_target.unsqueeze(1))
mask = target.contiguous().view(-1, ) != ignore_index
sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort()

if self.ramp_type is not None:
thresh = self.threshold(curr_iter=curr_iter, epoch=epoch)
else:
thresh = self.thresh

min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0
threshold = max(min_threshold, thresh)
loss_matrix = F.cross_entropy(predict, target,
weight=self.weight.to(predict.device) if self.weight is not None else None,
ignore_index=ignore_index, reduction='none')
loss_matirx = loss_matrix.contiguous().view(-1, )
sort_loss_matirx = loss_matirx[mask][sort_indices]
select_loss_matrix = sort_loss_matirx[sort_prob < threshold]
if self.reduction == 'sum' or select_loss_matrix.numel() == 0:
return select_loss_matrix.sum()
elif self.reduction == 'mean':
return select_loss_matrix.mean()
else:
raise NotImplementedError('Reduction Error!')



def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
assert inputs.requires_grad == True and targets.requires_grad == False
assert inputs.size() == targets.size() # (batch_size * num_classes * H * W)
inputs = F.softmax(inputs, dim=1)
if use_softmax:
targets = F.softmax(targets, dim=1)

if conf_mask:
loss_mat = F.mse_loss(inputs, targets, reduction='none')
mask = (targets.max(1)[0] > threshold)
loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
return loss_mat.mean()
else:
return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size


def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
assert inputs.requires_grad == True and targets.requires_grad == False
assert inputs.size() == targets.size()
input_log_softmax = F.log_softmax(inputs, dim=1)
if use_softmax:
targets = F.softmax(targets, dim=1)
if conf_mask:
loss_mat = F.kl_div(input_log_softmax, targets, reduction='none')
mask = (targets.max(1)[0] > threshold)
loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
return loss_mat.sum() / mask.shape.numel()
else:
return F.kl_div(input_log_softmax, targets, reduction='mean')


def softmax_js_loss(inputs, targets, **_):
assert inputs.requires_grad == True and targets.requires_grad == False
assert inputs.size() == targets.size()
epsilon = 1e-5

M = (F.softmax(inputs, dim=1) + targets) * 0.5
kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean')
kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean')
return (kl1 + kl2) * 0.5



def pair_wise_loss(unsup_outputs, size_average=True, nbr_of_pairs=8):
"""
Pair-wise loss in the sup. mat.
"""
if isinstance(unsup_outputs, list):
unsup_outputs = torch.stack(unsup_outputs)

# Only for a subset of the aux outputs to reduce computation and memory
unsup_outputs = unsup_outputs[torch.randperm(unsup_outputs.size(0))]
unsup_outputs = unsup_outputs[:nbr_of_pairs]

temp = torch.zeros_like(unsup_outputs) # For grad purposes
for i, u in enumerate(unsup_outputs):
temp[i] = F.softmax(u, dim=1)
mean_prediction = temp.mean(0).unsqueeze(0) # Mean over the auxiliary outputs
pw_loss = ((temp - mean_prediction)**2).mean(0) # Variance
pw_loss = pw_loss.sum(1) # Sum over classes
if size_average:
return pw_loss.mean()
return pw_loss.sum()

+ 18
- 0
README.md View File

# Enhancing_Injury_Segmentation_in_Breast_Mammograms_Through_Semi-Supervised_Learning # Enhancing_Injury_Segmentation_in_Breast_Mammograms_Through_Semi-Supervised_Learning



In order to run the Base lines tables in the Thesis you just need to go to base_line folder. Every one of the methods has a folder.
ResUnet has Implementation with different loss functions in it. You can use these models as your models for retrieving previous results for your special use. You can also use other SSL methods same way as before.


For you to train your model you just would need to use forward pass of the model. You will see the below function prototype in all the methods. Now we will explain this function.

```python
def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor)
```

mammo_x is actually the [Batch, size, size] which is x-ray image that we use as input data.
mammo_loss_and_gt is [Batch, 2, size, size]. mammo_loss_and_gt[:, 1, :, :] is the mask for our x-ray, referred as GT in Thesis. mammo_loss_and_gt[:, 0, :, :] is actually the DCM mask mentioned in the thesis.

By providing the inputs for forward function you can use all the models code easily in your projects.
Good Luck!

+ 15
- 0
ReadMe.txt View File

In order to run the Base lines tables in the Thesis you just need to go to base_line folder. Every one of the methods has a folder.
ResUnet has Implementation with different loss functions in it. You can use these models as your models for retrieving previous results for your special use. You can also use other SSL methods same way as before.

For you to train your model you just would need to use forward pass of the model. You will see the below function prototype in all the methods. Now we will explain this function.


def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor)

mammo_x is actually the [Batch, size, size] which is x-ray image that we use as input data.
mammo_loss_and_gt is [Batch, 2, size, size]. mammo_loss_and_gt[:, 1, :, :] is the mask for our x-ray, referred as GT in Thesis. mammo_loss_and_gt[:, 0, :, :] is actually the DCM mask mentioned in the thesis.

By providing the inputs for forward function you can use all the models code easily in your projects.
Good Luck!

+ 50
- 0
advent/advent.py View File

from turtle import forward
from mlassistant.core import Model
from torch import nn
from mlassistant.gan.models import GAN, BaseGenerator, BaseDiscriminator

from .entropy_minimization import EntropyMinimization


class ConvBlock(nn.Module):

def __init__(self, num_classes, ndf) -> None:
super().__init__()
self.conv2d = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
self.activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)

def forward(self, x):
x = self.conv2d(x)
return self.activation(x)

class Discriminator(BaseDiscriminator): # the discriminator

def __init__(self, num_classes, ndf=64) -> None:
super().__init__()
self.net = nn.Sequential(
ConvBlock(num_classes, ndf),
ConvBlock(ndf, 2 * ndf),
ConvBlock(2 * ndf, 4 * ndf),
ConvBlock(4 * ndf, 8 * ndf),
ConvBlock(8 * ndf , 1),
)

def forward(self, x):
return x


class Generator(BaseGenerator): # segnet

def __init__(self, conf):
super().__init__(conf)
self.net = EntropyMinimization()

def forward(self, x):
segmentation_net_output = self.net(x)
return segmentation_net_output


class AdventGan(GAN): # the gan

def __init__(self, d_name: str):
super().__init__(Generator(), {d_name: Discriminator()})

+ 246
- 0
advent/advent_ent_minimization_ssl.py View File

import torch
import torch as tf
from torch.nn import functional as F
import torch.nn as nn
from .deeplabv2 import DeepLabV2
from .losses import *
from mlassistant.core import Model
import sys
from ...full_segmentation.utils import sum_roi_4ch_to_1ch
from torchvision.utils import draw_segmentation_masks
from ....utils.mask_processes.resize_binarize_roi import resize_binarize_roi_torch
from ....enums import *
from .msc import MSC
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.dice import dice_loss
images_path = './../../images/'
import time
class EntropyMinimization(Model):

def train(self,mode : bool=True):
super().train(mode)
self.epochNUM += 1
print(self.epochNUM)

def __init__(self, num_classes=2, _lambda=1.0, input_size=64 , gamma=2.0,ent_factor = 0.001, **kwargs):
super().__init__()
self.base_network = DeepLabV2S_ResNet101_MSC(n_classes=num_classes)
self.num_classes = num_classes
self._lambda = _lambda
self.batch_count = 0
self.input_size = input_size
self.gamma = gamma
# self.activation = nn.ReLU()
self.softmax = nn.Softmax(dim=1)
self.epochNUM = 0
self.ent_factor = ent_factor
self.time_c = time.time()
def interpolate(self,shapes,arr):
segmentation_label = arr
output = dict()

if len(segmentation_label.shape) == 2:
segmentation_label = segmentation_label.unsqueeze(0)
segmentation_labels = []
segmentation_dict = dict()
for i, shape in enumerate(shapes):
start_dim = int(len(shape) - 3)
key = '_'.join([str(i) for i in shape])
if key not in segmentation_dict:
segmentation = F.interpolate(
segmentation_label.unsqueeze(1),
size=(shape[-2], shape[-1]),
mode="bilinear",
align_corners=False).squeeze()
if len(segmentation.shape) == 2:
segmentation = segmentation.unsqueeze(0)
output['org_pixel_labels_' + str(i)] = segmentation.clone()
segmentation = torch.flatten(segmentation, start_dim=start_dim)
segmentation_dict[key] = segmentation
else: # already interpolated!
segmentation = segmentation_dict[key]
segmentation_labels.append(segmentation)
final_segmentation_label = torch.cat(segmentation_labels, dim=start_dim)
return final_segmentation_label, output


def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor):
"""


"""

#print(mammo_roi_x.shape)
#mammo_loss_and_gt = torch.cat((mammo_roi_x[:,0,:,:].unsqueeze(1), mammo_roi_x[:,0,:,:].unsqueeze(1)), 1)
#print(mammo_loss_and_gt.shape)
#mammo_loss_and_gt =
#start = time.time()
#print("between two forwards = ", start - self.time_c, flush= True)
output = dict()
#print()
#print(mammo_x.shape,flush=True)
if mammo_loss_and_gt is not None:
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"

loss_type_lable = mammo_loss_and_gt[:,0,:,:] # 1 means supervised and zero means unsupervised
mask = mammo_loss_and_gt[:,1,:,:] # 1 means cancer and zero means background
#deel_time = time.time()
logits , network_output, shapes = self.base_network(mammo_x)
#print("base model took= ",time.time() - deel_time, flush=True)
output['pixel_probs'] = self.softmax(logits)
network_output_soft = self.softmax(network_output)

loss_type, o1 = self.interpolate(shapes, loss_type_lable)
mask, o2 = self.interpolate(shapes, mask)

main_shape = loss_type.shape
network_output_shape = network_output_soft.shape
dup_loss_type = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1])
mask = mask.unsqueeze(dim = 1)
dup_mask = torch.cat((1-mask, mask), 1)

output = {**output, **o1, **o2}
supervised_loss = torch.tensor(0.0).to(mammo_x.device)

yhat = (network_output_soft * dup_loss_type).unsqueeze(-1)
y = (dup_mask * dup_loss_type).unsqueeze(-1)


focal_mask = dup_mask.clone().to(mammo_x.device)
focal_net_out = torch.exp(torch.log(network_output_soft).to(mammo_x.device)).to(mammo_x.device)
focal_mask[dup_loss_type == 0] = 0.0
#focal_net_out[dup_loss_type == 0] = 0.0
focal_shape = focal_mask.shape
focal_mask = torch.cat((focal_mask,torch.ones((focal_shape[0],1,focal_shape[2])).to(mammo_x.device)), 1)
focal_net_out = torch.cat((focal_net_out,torch.ones((focal_shape[0],1,focal_shape[2])).to(mammo_x.device)), 1)
supervised_loss += balanced_focal_cross_entropy_loss_semi(
focal_net_out.unsqueeze(-1),
focal_mask.unsqueeze(-1),
focal_gamma=self.gamma)


entropy_loss_val = entropy_loss_normalized(network_output_soft)* (1-loss_type)
entropy_loss_val = torch.mean(entropy_loss_val.flatten()[((1-loss_type) == 1).flatten()])

net_out_sup = yhat.flatten()[(dup_loss_type == 1).flatten()]
mask_sup = mask.flatten()[(loss_type == 1).flatten()]
dice_loss_val = dice_loss((net_out_sup.reshape(2,-1).unsqueeze(dim = 0).unsqueeze(dim = -1)),mask_sup.unsqueeze(dim = 0).unsqueeze(dim = -1).to(torch.int64))
output['suploss'] = supervised_loss
output['focalloss'] = supervised_loss
output['dice_loss'] = dice_loss_val
output["loss"] = supervised_loss + self.ent_factor * entropy_loss_val

output['org_pixel_labels'] = mammo_loss_and_gt
output['ent_loss'] = entropy_loss_val
output['ce_loss'] = F.binary_cross_entropy(yhat,y)
output['ce_0_loss'] = F.binary_cross_entropy(yhat * (1-y) + y,y) # y = 1 -> y / y = 0 -> yhat
output['ce_1_loss'] = F.binary_cross_entropy(yhat * y,y) # y = 0 -> y / y = 1 -> yhat
#print(output)

else:
output['loss'] = torch.tensor(0.0)
output['org_pixel_labels'] = mammo_loss_and_gt
output['ent_loss'] = torch.tensor(0.0)
output['ce_loss'] = torch.tensor(0.0)
output['ce_0_loss'] = torch.tensor(0.0)
output['ce_1_loss'] = torch.tensor(0.0)
output['dice_loss'] = torch.tensor(0.0)
output['suploss'] = torch.tensor(0.0)
output['focalloss'] = torch.tensor(0.0)

#print("whole forward took =", time.time() - start, flush = True)
#self.time_c = time.time()
return output


def DeepLabV2S_ResNet101_MSC(n_classes):
return MSC(
base=DeepLabV2(n_classes=n_classes, n_blocks=[3, 8, 20, 3], atrous_rates=[3, 6, 9, 12]),
scales=[0.5, 0.75])


def balanced_focal_cross_entropy_localization_loss(
probs:torch.Tensor, gt: torch.Tensor, focal_gamma: float = 1) -> torch.Tensor:
"""
Calculates class balanced cross entropy loss weighted with the style of Focal loss.
mean_over_classes(1/mean(focal_weights) * (focal_weights * cross_entropy_loss))
focal_weight: (1 - p_of_the_related_class)^gamma

Args:
probs (torch.Tensor): Probabilities assigned by the model of shape B C ...
gt (torch.Tensor): The ground truth containing the number of the class, whether of shape B ... or B C ..., make sure the ... matches in both probs and gt
focal_gamma (float): The power factor used in the weight of focal loss. Default is one which means the typical balanced cross entropy.

Returns:
torch.Tensor: The calculated
"""
localization_mask = (gt == 0 ).to(torch.int64)
probs = probs * localization_mask # just compute loss on outside regions
gt = torch.zeros(gt.shape).to(torch.int64).to(probs.device) # make int zero
# if channel is one in the pixel probs, convert it to the binary mode!
if probs.shape[1] == 1:
probs = torch.cat([1 - probs, probs], dim=1)

assert gt.max() <= probs.shape[1] - 1, f'Expected {probs.shape[1]} classes according to pixel probs\' shape but found {gt.max()} in GT mask'
if len(gt.shape) == len(probs.shape):
assert (gt.shape[1] == 1) or (gt.shape[1] == probs.shape[1]), f'Expected the channel dim to have either one channel or the same as probs while it is of shape {gt.shape} and probs is of shape {probs.shape}'
if gt.shape[1] == 1:
gt = gt[:, 0, ...]
else:
gt = torch.argmax(gt, dim=1) # transferring one-hot to count data

assert len(gt.shape) == (len(probs.shape) - 1), f'Expected GT labels to be of shape B ... and pixel probs of shape B C ..., but received {gt.shape} and {probs.shape} instead.'

# if probabilities and ground truth have different shapes, resize the ground truth
if probs.shape[-2:] != gt.shape[-2:]:

# convert gt to one-hot it before interpolation to prevent mistakes
gt = F.one_hot(gt.long(), probs.shape[1]) # B H W C
gt = torch.permute(gt, [0, -1, 1, 2])

# binarize
gt = resize_binarize_roi_torch(gt.float(), probs.shape[-2:])

# convert one-hot to numbers with argmax
gt = torch.argmax(gt, dim=1) # Eliminating the channel
gt = gt.long() # B H W

# flattening and bringing class channel o index 0
probs = probs.transpose(0, 1).flatten(1) # C N
gt = gt.flatten(0) # N

c = probs.shape[0]

gt_related_prob = probs[gt, torch.arange(gt.shape[0])]

losses = []

for ci in range(c):
c_probs = gt_related_prob[gt == ci]

if torch.numel(c_probs) > 0:
if focal_gamma == 1:
losses.append(F.binary_cross_entropy(
c_probs,
torch.ones_like(c_probs)))
else:
w = torch.pow(torch.abs(1 - c_probs.detach()), focal_gamma)
losses.append(
(1.0 / (torch.mean(w) + 1e-4)) *
F.binary_cross_entropy(
c_probs,
torch.ones_like(c_probs),
weight=w))

return torch.mean(torch.stack(losses))


+ 77
- 0
advent/deeplabv2.py View File

#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2017-11-19

from __future__ import absolute_import, print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from .resnet import ConvBlock, PoolBlock, ResnetBlock,ResnetLayer , ResNet


class AtrousSpatialPyramidPooling(nn.Module):
"""
(ASPP)
"""

def __init__(self, in_channels, out_channels, rates):
super(AtrousSpatialPyramidPooling, self).__init__()
self.conv_layers = nn.ModuleList()
for i, rate in enumerate(rates):
conv_layer = nn.Conv2d(in_channels, out_channels, 3, 1, padding=rate, dilation=rate, bias=True)
torch.nn.init.normal_(conv_layer.weight , 0 , 0.01)
torch.nn.init.normal_(conv_layer.bias , 0 , 0.01)
self.conv_layers.append(conv_layer)

def forward(self, x):
out = torch.cat([layer(x) for layer in self.conv_layers],dim=1)
return out

class DeepLabV2(nn.Module):
def __init__(self, n_classes, n_blocks,atrous_rates):
super(DeepLabV2, self).__init__()
ch = [32 * 2 ** p for p in range(6)]
self.layer1= PoolBlock(ch[0])
self.layer2= ResnetLayer(n_blocks[0], ch[0], ch[2], 1, 1)
self.layer3= ResnetLayer(n_blocks[1], ch[2], ch[3], 2, 1)
self.layer4= ResnetLayer(n_blocks[2], ch[3], ch[4], 2, 1)
self.atrous = AtrousSpatialPyramidPooling(ch[4], n_classes, atrous_rates)
self.up_sample = nn.Upsample(scale_factor=2)
self.conv2d_transpose1 = nn.ConvTranspose2d(in_channels=8,out_channels=4,kernel_size=3,stride=2)
self.conv2d_transpose2 = nn.ConvTranspose2d(in_channels=4,out_channels=4,kernel_size=3,stride=2)
self.conv2d = nn.Conv2d(in_channels=4,out_channels=2,kernel_size=3,stride=1)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.atrous(x)
x = self.up_sample(x)
x = self.conv2d_transpose1(x)
x = self.conv2d_transpose2(x)
x = self.conv2d(x)
return x
def freeze_bn(self):
for m in self.modules():
if isinstance(m, ConvBlock.BATCH_NORM):
m.eval()

# if __name__ == "__main__":
# model = DeepLabV2(
# n_classes=21, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]
# )
# model.eval()
# image = torch.randn(1, 3, 513, 513)

# print(model)
# print("input:", image.shape)
# print("output:", model(image).shape)

+ 76
- 0
advent/deeplabv2_activations.py View File

#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2017-11-19

from __future__ import absolute_import, print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from .resnet import ConvBlock, PoolBlock, ResnetBlock,ResnetLayer , ResNet


class AtrousSpatialPyramidPooling(nn.Module):
"""
(ASPP)
"""

def __init__(self, in_channels, out_channels, rates):
super(AtrousSpatialPyramidPooling, self).__init__()
self.conv_layers = nn.ModuleList()
for i, rate in enumerate(rates):
conv_layer = nn.Conv2d(in_channels, out_channels, 3, 1, padding=rate, dilation=rate, bias=True)
torch.nn.init.normal_(conv_layer.weight , 0 , 0.01)
torch.nn.init.normal_(conv_layer.bias , 0 , 0.01)
self.conv_layers.append(conv_layer)

def forward(self, x):
out = torch.cat([layer(x) for layer in self.conv_layers],dim=1)
return out

class DeepLabV2(nn.Module):
def __init__(self, n_classes, n_blocks,atrous_rates):
super(DeepLabV2, self).__init__()
ch = [32 * 2 ** p for p in range(6)]
self.layer1= PoolBlock(ch[0])
self.layer2= ResnetLayer(n_blocks[0], ch[0], ch[2], 1, 1)
self.layer3= ResnetLayer(n_blocks[1], ch[2], ch[3], 2, 1)
self.layer4= ResnetLayer(n_blocks[2], ch[3], ch[4], 2, 1)
self.atrous = AtrousSpatialPyramidPooling(ch[4], n_classes, atrous_rates)
self.up_sample = nn.Upsample(scale_factor=2)
self.conv2d_transpose = nn.ConvTranspose2d(in_channels=8,out_channels=4,kernel_size=3,stride=2)
self.conv2d_transpose = nn.ConvTranspose2d(in_channels=4,out_channels=4,kernel_size=3,stride=2)
self.conv2d = nn.Conv2d(in_channels=4,out_channels=2,kernel_size=3,stride=1)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.atrous(x)
x = self.up_sample(x)
x = self.conv2d_transpose(x)
x = self.conv2d(x)
return x
def freeze_bn(self):
for m in self.modules():
if isinstance(m, ConvBlock.BATCH_NORM):
m.eval()

# if __name__ == "__main__":
# model = DeepLabV2(
# n_classes=21, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]
# )
# model.eval()
# image = torch.randn(1, 3, 513, 513)

# print(model)
# print("input:", image.shape)
# print("output:", model(image).shape)

+ 82
- 0
advent/deeplabv2_dropout.py View File

#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2017-11-19

from __future__ import absolute_import, print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Dropout
from .resnet import ConvBlock, PoolBlock, ResnetBlock,ResnetLayer , ResNet


class AtrousSpatialPyramidPooling(nn.Module):
"""
(ASPP)
"""

def __init__(self, in_channels, out_channels, rates):
super(AtrousSpatialPyramidPooling, self).__init__()
self.conv_layers = nn.ModuleList()
for i, rate in enumerate(rates):
conv_layer = nn.Conv2d(in_channels, out_channels, 3, 1, padding=rate, dilation=rate, bias=True)
torch.nn.init.normal_(conv_layer.weight , 0 , 0.01)
torch.nn.init.normal_(conv_layer.bias , 0 , 0.01)
self.conv_layers.append(conv_layer)

def forward(self, x):
out = torch.cat([layer(x) for layer in self.conv_layers],dim=1)
return out

class DeepLabV2(nn.Module):
def __init__(self, n_classes, n_blocks,atrous_rates):
super(DeepLabV2, self).__init__()
ch = [32 * 2 ** p for p in range(6)]
self.dropout = Dropout(0.1)
self.layer1= PoolBlock(ch[0])
self.layer2= ResnetLayer(n_blocks[0], ch[0], ch[2], 1, 1)
self.layer3= ResnetLayer(n_blocks[1], ch[2], ch[3], 2, 1)
self.layer4= ResnetLayer(n_blocks[2], ch[3], ch[4], 2, 1)
self.atrous = AtrousSpatialPyramidPooling(ch[4], n_classes, atrous_rates)
self.up_sample = nn.Upsample(scale_factor=2)
self.conv2d_transpose1 = nn.ConvTranspose2d(in_channels=8,out_channels=4,kernel_size=3,stride=2)
self.conv2d_transpose2 = nn.ConvTranspose2d(in_channels=4,out_channels=4,kernel_size=3,stride=2)
self.conv2d = nn.Conv2d(in_channels=4,out_channels=2,kernel_size=3,stride=1)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.dropout(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.dropout(x)
x = self.atrous(x)
x = self.up_sample(x)
x = self.dropout(x)
x = self.conv2d_transpose1(x)
x = self.conv2d_transpose2(x)
x = self.dropout(x)
x = self.conv2d(x)
return x
def freeze_bn(self):
for m in self.modules():
if isinstance(m, ConvBlock.BATCH_NORM):
m.eval()

# if __name__ == "__main__":
# model = DeepLabV2(
# n_classes=21, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24]
# )
# model.eval()
# image = torch.randn(1, 3, 513, 513)

# print(model)
# print("input:", image.shape)
# print("output:", model(image).shape)

+ 99
- 0
advent/entropy_minimization.py View File

import torch
import torch.nn as nn
from .deeplabv2 import DeepLabV2
from .losses import *
from mlassistant.core import Model
from ...full_segmentation.utils import sum_roi_4ch_to_1ch
from torchvision.utils import draw_segmentation_masks
from ....enums import *
from typing import List
from .msc import MSC
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss

images_path = './../../images/'


class EntropyMinimization(Model):

def __init__(self,
num_classes=2,
_lambda=1.0,
input_size=64,
gamma=1.0,
multi_head: torch.nn.Module = MSC,
main_model: torch.nn.Module = DeepLabV2,
atrous_rates: List[int] = [3, 6, 9, 12],
**kwargs):
super().__init__()
self.multi_head = multi_head
self.main_model = main_model
self.atrous_rates = atrous_rates
self.base_network = self.DeepLabV2S_ResNet101_MSC(n_classes=num_classes)
self.num_classes = num_classes
self._lambda = _lambda
self.batch_count = 0
self.input_size = input_size
self.gamma = gamma
# self.activation = nn.ReLU()
self.softmax = nn.Softmax(dim=1)

def forward(self,
mammo_x: torch.Tensor,
mammo_roi_x_exist: torch.Tensor,
mammo_roi_x: torch.Tensor = None):
output = dict()
logits, network_output, shapes = self.base_network(mammo_x)
output['pixel_probs'] = self.softmax(logits)
if mammo_roi_x is not None:
mammo_roi_x_exist = mammo_roi_x_exist.bool()
segmentation_label = sum_roi_4ch_to_1ch(mammo_roi_x,
ROIAggregation.mass_vs_others).squeeze()
print(segmentation_label.shape)
if len(segmentation_label.shape) == 2:
segmentation_label = segmentation_label.unsqueeze(0)
segmentation_labels = []
segmentation_dict = dict()
for i, shape in enumerate(shapes):
start_dim = int(len(shape) - 3)
key = '_'.join([str(i) for i in shape])
if key not in segmentation_dict:
segmentation = F.interpolate(
segmentation_label.unsqueeze(1),
size=(shape[-2], shape[-1]),
mode="bilinear",
align_corners=False).squeeze()
if len(segmentation.shape) == 2:
segmentation = segmentation.unsqueeze(0)
output['org_pixel_labels_' + str(i)] = segmentation.clone()
segmentation = torch.flatten(segmentation, start_dim=start_dim)
segmentation_dict[key] = segmentation
else: # already interpolated!
segmentation = segmentation_dict[key]
segmentation_labels.append(segmentation)
final_segmentation_label = torch.cat(segmentation_labels, dim=start_dim)
# ent_loss = torch.tensor(0.0).to(mammo_x.device)
segmentation_loss = torch.tensor(0.0).to(mammo_x.device)
network_output = self.softmax(network_output)
# unsupervised_outputs = network_output[~mammo_roi_x_exist]
# if len(unsupervised_outputs) > 0:
# ent_loss = ent_loss + torch.mean(entropy_loss(unsupervised_outputs))
supervised_outputs, rois = network_output[mammo_roi_x_exist], final_segmentation_label[
mammo_roi_x_exist]
if len(supervised_outputs) > 0:
segmentation_loss = segmentation_loss + balanced_focal_cross_entropy_loss(
supervised_outputs.unsqueeze(-1),
rois.to(torch.int64).unsqueeze(-1),
focal_gamma=self.gamma)
output['org_pixel_labels'] = output['org_pixel_labels_0']
loss = segmentation_loss # + self._lambda * torch.mean(ent_loss)
output['loss'] = loss
else:
output['loss'] = torch.tensor(0.0)
output['org_pixel_labels'] = output['pixel_probs']
return output

def DeepLabV2S_ResNet101_MSC(self, n_classes):
return self.multi_head(
base=self.main_model(
n_classes=n_classes, n_blocks=[3, 4, 23, 3], atrous_rates=self.atrous_rates),
scales=[0.5, 0.75])

+ 93
- 0
advent/entropy_minimization_1.py View File

from turtle import forward
import torch
import torch.nn as nn
from .deeplabv2 import DeepLabV2
from .losses import *
from mlassistant.core import Model
from ...full_segmentation.utils import sum_roi_4ch_to_1ch
from torchvision.utils import draw_segmentation_masks
from ....enums import *
from .msc import MSC
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss
images_path = './../../images/'
class EntropyMinimization(Model):

def __init__(self, num_classes=2, _lambda=1.0, input_size=64, **kwargs):
super().__init__()
self.base_network = DeepLabV2S_ResNet101_MSC(n_classes=num_classes)
self.num_classes = num_classes
self._lambda = _lambda
self.batch_count = 0
self.input_size = input_size
# self.activation = nn.ReLU()
self.softmax = nn.Softmax(dim=1)

def forward(self,
mammo_x: torch.Tensor,
mammo_roi_x_exist: torch.Tensor,
mammo_roi_x: torch.Tensor = None):
output = dict()
mammo_roi_x_exist = mammo_roi_x_exist.bool()
mammo_x = mammo_x.repeat(1, 3, 1, 1)
segmentation_label = sum_roi_4ch_to_1ch(mammo_roi_x, ROIAggregation.mass_vs_others).squeeze()
logits , network_output, shapes = self.base_network(mammo_x)
# print(shapes)
segmentation_labels = []
segmentation_dict = dict()
for i,shape in enumerate(shapes):
key = '_'.join([str(i) for i in shape])
if key not in segmentation_dict:
segmentation = F.interpolate(
segmentation_label.unsqueeze(1),
size=(shape[-2], shape[-1]),
mode="bilinear",
align_corners=False).squeeze()
output['org_pixel_labels_' + str(i)] = segmentation.clone()
segmentation = torch.flatten(segmentation,start_dim=1)
segmentation_dict[key] = segmentation
else: # already interpolated!
segmentation = segmentation_dict[key]
segmentation_labels.append(segmentation)
final_segmentation_label = torch.cat(segmentation_labels, dim=1)
ent_loss = torch.tensor(0.0).to(mammo_x.device)
segmentation_loss = torch.tensor(0.0).to(mammo_x.device)
network_output = self.softmax(network_output)
unsupervised_outputs = network_output[~mammo_roi_x_exist]
if len(unsupervised_outputs) > 0:
ent_loss = ent_loss + torch.mean(entropy_loss(unsupervised_outputs))
output['pixel_probs'] = self.softmax(logits)
supervised_outputs, rois = network_output[mammo_roi_x_exist], final_segmentation_label[mammo_roi_x_exist]
if len(supervised_outputs) > 0:
segmentation_loss = segmentation_loss + balanced_focal_cross_entropy_loss(
supervised_outputs.unsqueeze(-1), rois.to(torch.int64).unsqueeze(-1), focal_gamma=1)
# for i in range(len(mammo_x)):
# if mammo_roi_x_exist[i]:
# for threshold in range(20):
# binary_threshold = (threshold + 1) * 0.05
# image = (mammo_x[i] * 255).to(torch.uint8).cpu()
# if len(image.shape)< 3:
# image = image.unsqueeze(0)
# # image
# new_label = (segmentation_label[i]>0.5).bool().unsqueeze(0).cpu()
# # new label
# new_prediction = (output['pixel_probs'][i][1].clone().detach()>binary_threshold).bool().unsqueeze(0).cpu()
# # new_pred
# masks = torch.cat([new_label,new_prediction]).cpu()
# colors = ["green","red"]
# alpha = 0.7
# final_image = draw_segmentation_masks(image, masks, alpha, colors).transpose(0,1).transpose(1,2)
# Image.fromarray(final_image.cpu().numpy()).save(
# os.path.join(images_path,
# str(self.batch_count) + '_' + str(i) + '_' + str(binary_threshold) + '_image_mask.png'))
output['org_pixel_labels'] = output['org_pixel_labels_2']
loss = segmentation_loss + self._lambda * torch.mean(ent_loss)
output['loss'] = loss
return output


def DeepLabV2S_ResNet101_MSC(n_classes):
return MSC(
base=DeepLabV2(n_classes=n_classes, n_blocks=[3, 8, 5, 3], atrous_rates=[3, 6, 9, 12]),
scales=[0.5, 8.0])

+ 161
- 0
advent/entropy_minimization_semi_supervised.py View File

import torch
import torch.nn as nn
from .deeplabv2 import DeepLabV2
from .losses import *
from mlassistant.core import Model
from ...full_segmentation.utils import sum_roi_4ch_to_1ch
from torchvision.utils import draw_segmentation_masks
from ....utils.mask_processes.resize_binarize_roi import resize_binarize_roi_torch
from ....enums import *
from .msc import MSC
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss
images_path = './../../images/'
class EntropyMinimization(Model):

def __init__(self, num_classes=2, _lambda=1.0, input_size=64 , gamma=1.0, **kwargs):
super().__init__()
self.base_network = DeepLabV2S_ResNet101_MSC(n_classes=num_classes)
self.num_classes = num_classes
self._lambda = _lambda
self.batch_count = 0
self.input_size = input_size
self.gamma = gamma
# self.activation = nn.ReLU()
self.softmax = nn.Softmax(dim=1)

def forward(self,
mammo_x: torch.Tensor,
mammo_roi_x_exist: torch.Tensor,
mammo_roi_x: torch.Tensor = None):
output = dict()
logits , network_output, shapes = self.base_network(mammo_x)
output['pixel_probs'] = self.softmax(logits)
if mammo_roi_x is not None:
mammo_roi_x_exist = mammo_roi_x_exist.bool()
segmentation_label = sum_roi_4ch_to_1ch(mammo_roi_x, ROIAggregation.mass_vs_others).squeeze()
if len(segmentation_label.shape) == 2:
segmentation_label = segmentation_label.unsqueeze(0)
segmentation_labels = []
segmentation_dict = dict()
for i,shape in enumerate(shapes):
start_dim = int(len(shape) - 3)
key = '_'.join([str(i) for i in shape])
if key not in segmentation_dict:
segmentation = F.interpolate(
segmentation_label.unsqueeze(1),
size=(shape[-2], shape[-1]),
mode="bilinear",
align_corners=False).squeeze()
if len(segmentation.shape) == 2:
segmentation = segmentation.unsqueeze(0)
output['org_pixel_labels_' + str(i)] = segmentation.clone()
segmentation = torch.flatten(segmentation,start_dim=start_dim)
segmentation_dict[key] = segmentation
else: # already interpolated!
segmentation = segmentation_dict[key]
segmentation_labels.append(segmentation)
final_segmentation_label = torch.cat(segmentation_labels, dim=start_dim)
# ent_loss = torch.tensor(0.0).to(mammo_x.device)
segmentation_loss = torch.tensor(0.0).to(mammo_x.device)
network_output = self.softmax(network_output)
unsupervised_outputs = network_output[~mammo_roi_x_exist]
if len(unsupervised_outputs) > 0:
ent_loss = ent_loss + torch.mean(entropy_loss(unsupervised_outputs))
# apply localization losses
supervised_outputs, localizations = network_output[mammo_roi_x_exist], final_segmentation_label[mammo_roi_x_exist]
if len(supervised_outputs) > 0:
segmentation_loss = segmentation_loss + balanced_focal_cross_entropy_loss(
supervised_outputs.unsqueeze(-1), localizations.to(torch.int64).unsqueeze(-1), focal_gamma=self.gamma)
output['org_pixel_labels'] = output['org_pixel_labels_0']
loss = segmentation_loss # + self._lambda * torch.mean(ent_loss)
output['loss'] = loss
else:
output['loss'] = torch.tensor(0.0)
output['org_pixel_labels'] = output['pixel_probs']
return output


def DeepLabV2S_ResNet101_MSC(n_classes):
return MSC(
base=DeepLabV2(n_classes=n_classes, n_blocks=[3, 8, 20, 3], atrous_rates=[3, 6, 9, 12]),
scales=[0.5, 0.75])


def balanced_focal_cross_entropy_localization_loss(
probs:torch.Tensor, gt: torch.Tensor, focal_gamma: float = 1) -> torch.Tensor:
"""
Calculates class balanced cross entropy loss weighted with the style of Focal loss.
mean_over_classes(1/mean(focal_weights) * (focal_weights * cross_entropy_loss))
focal_weight: (1 - p_of_the_related_class)^gamma

Args:
probs (torch.Tensor): Probabilities assigned by the model of shape B C ...
gt (torch.Tensor): The ground truth containing the number of the class, whether of shape B ... or B C ..., make sure the ... matches in both probs and gt
focal_gamma (float): The power factor used in the weight of focal loss. Default is one which means the typical balanced cross entropy.

Returns:
torch.Tensor: The calculated
"""
localization_mask = (gt == 0 ).to(torch.int64)
probs = probs * localization_mask # just compute loss on outside regions
gt = torch.zeros(gt.shape).to(torch.int64).to(probs.device) # make int zero
# if channel is one in the pixel probs, convert it to the binary mode!
if probs.shape[1] == 1:
probs = torch.cat([1 - probs, probs], dim=1)

assert gt.max() <= probs.shape[1] - 1, f'Expected {probs.shape[1]} classes according to pixel probs\' shape but found {gt.max()} in GT mask'
if len(gt.shape) == len(probs.shape):
assert (gt.shape[1] == 1) or (gt.shape[1] == probs.shape[1]), f'Expected the channel dim to have either one channel or the same as probs while it is of shape {gt.shape} and probs is of shape {probs.shape}'
if gt.shape[1] == 1:
gt = gt[:, 0, ...]
else:
gt = torch.argmax(gt, dim=1) # transferring one-hot to count data

assert len(gt.shape) == (len(probs.shape) - 1), f'Expected GT labels to be of shape B ... and pixel probs of shape B C ..., but received {gt.shape} and {probs.shape} instead.'

# if probabilities and ground truth have different shapes, resize the ground truth
if probs.shape[-2:] != gt.shape[-2:]:

# convert gt to one-hot it before interpolation to prevent mistakes
gt = F.one_hot(gt.long(), probs.shape[1]) # B H W C
gt = torch.permute(gt, [0, -1, 1, 2])

# binarize
gt = resize_binarize_roi_torch(gt.float(), probs.shape[-2:])

# convert one-hot to numbers with argmax
gt = torch.argmax(gt, dim=1) # Eliminating the channel
gt = gt.long() # B H W

# flattening and bringing class channel o index 0
probs = probs.transpose(0, 1).flatten(1) # C N
gt = gt.flatten(0) # N

c = probs.shape[0]

gt_related_prob = probs[gt, torch.arange(gt.shape[0])]

losses = []

for ci in range(c):
c_probs = gt_related_prob[gt == ci]

if torch.numel(c_probs) > 0:
if focal_gamma == 1:
losses.append(F.binary_cross_entropy(
c_probs,
torch.ones_like(c_probs)))
else:
w = torch.pow(torch.abs(1 - c_probs.detach()), focal_gamma)
losses.append(
(1.0 / (torch.mean(w) + 1e-4)) *
F.binary_cross_entropy(
c_probs,
torch.ones_like(c_probs),
weight=w))

return torch.mean(torch.stack(losses))


+ 108
- 0
advent/entropy_minimization_tumor_augmentation.py View File

import torch
import torch.nn as nn
from .deeplabv2 import DeepLabV2
from .losses import *
from mlassistant.core import Model
from ...full_segmentation.utils import sum_roi_4ch_to_1ch
from torchvision.utils import draw_segmentation_masks
from ....enums import *
from typing import List
from .msc import MSC
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss
from ....augmentations.random_tumor_augmentor import RandomTumorAugmentor
from ....augmentations.random_tumor_coloring import RandomTumorColoring

images_path = './../../images/'


class EntropyMinimization(Model):

def __init__(self,
num_classes=2,
_lambda=1.0,
input_size=64,
gamma=1.0,
multi_head: torch.nn.Module = MSC,
main_model: torch.nn.Module = DeepLabV2,
atrous_rates: List[int] = [3, 6, 9, 12],
**kwargs):
super().__init__()
self.multi_head = multi_head
self.main_model = main_model
self.augmentor = RandomTumorAugmentor(probability=0.5)
self.tumor_coloring = RandomTumorColoring(probability=0.5)
self.atrous_rates = atrous_rates
self.base_network = self.DeepLabV2S_ResNet101_MSC(n_classes=num_classes)
self.num_classes = num_classes
self._lambda = _lambda
self.batch_count = 0
self.input_size = input_size
self.gamma = gamma
# self.activation = nn.ReLU()
self.softmax = nn.Softmax(dim=1)

def forward(self,
mammo_x: torch.Tensor,
mammo_roi_x_exist: torch.Tensor,
mammo_y_breast_type:torch.Tensor,
mammo_roi_x: torch.Tensor):
output = dict()
if mammo_roi_x is not None:
segmentation_label = sum_roi_4ch_to_1ch(mammo_roi_x,
ROIAggregation.tumor_vs_norm).squeeze()
segmentation_label[segmentation_label>0] = 1
if len(segmentation_label.shape) == 2:
segmentation_label = segmentation_label.unsqueeze(0)
mammo_x ,segmentation_label = self.tumor_coloring.forward(mammo_x,segmentation_label)
for i in range(len(mammo_x)):
mammo_x[i] ,segmentation_label[i] = self.augmentor.forward(mammo_x[i], segmentation_label[i], int(mammo_y_breast_type[i]))
logits, network_output, shapes = self.base_network(mammo_x)
output['pixel_probs'] = self.softmax(logits)
if mammo_roi_x is not None:
mammo_roi_x_exist = mammo_roi_x_exist.bool()
segmentation_labels = []
segmentation_dict = dict()
for i, shape in enumerate(shapes):
start_dim = int(len(shape) - 3)
key = '_'.join([str(i) for i in shape])
if key not in segmentation_dict:
segmentation = F.interpolate(
segmentation_label.unsqueeze(1),
size=(shape[-2], shape[-1]),
mode="bilinear",
align_corners=False).squeeze()
if len(segmentation.shape) == 2:
segmentation = segmentation.unsqueeze(0)
output['org_pixel_labels_' + str(i)] = segmentation.clone()
segmentation = torch.flatten(segmentation, start_dim=start_dim)
segmentation_dict[key] = segmentation
else: # already interpolated!
segmentation = segmentation_dict[key]
segmentation_labels.append(segmentation)
final_segmentation_label = torch.cat(segmentation_labels, dim=start_dim)
# ent_loss = torch.tensor(0.0).to(mammo_x.device)
segmentation_loss = torch.tensor(0.0).to(mammo_x.device)
network_output = self.softmax(network_output)
# unsupervised_outputs = network_output[~mammo_roi_x_exist]
# if len(unsupervised_outputs) > 0:
# ent_loss = ent_loss + torch.mean(entropy_loss(unsupervised_outputs))
supervised_outputs, rois = network_output[mammo_roi_x_exist], final_segmentation_label[
mammo_roi_x_exist]
if len(supervised_outputs) > 0:
segmentation_loss = segmentation_loss + balanced_focal_cross_entropy_loss(
supervised_outputs.unsqueeze(-1),
rois.to(torch.int64).unsqueeze(-1),
focal_gamma=self.gamma)
output['org_pixel_labels'] = output['org_pixel_labels_0']
loss = segmentation_loss # + self._lambda * torch.mean(ent_loss)
output['loss'] = loss
else:
output['loss'] = torch.tensor(0.0)
output['org_pixel_labels'] = output['pixel_probs']
return output

def DeepLabV2S_ResNet101_MSC(self, n_classes):
return self.multi_head(
base=self.main_model(
n_classes=n_classes, n_blocks=[3, 4, 23, 3], atrous_rates=self.atrous_rates),
scales=[0.5, 0.75])

+ 52
- 0
advent/losses.py View File

# this model is based on this article: ADVENT: Adversarial Entropy Minimization for Domain
# Adaptation in Semantic Segmentation
# https://arxiv.org/abs/1811.12833
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable


def entropy_loss(v):
"""
Entropy loss for probabilistic prediction vectors
input: batch_size x channels x hw
output: batch_size x 1 x hw
"""
assert v.dim() == 3
n, c, hw = v.size()
denominator = n * hw * np.log2(c)
return -torch.sum(torch.mul(v, torch.log2(v+1e-30)),dim=1) / denominator

def entropy_loss_normalized(v):
"""
Entropy loss for probabilistic prediction vectors
input: batch_size x channels x hw
output: batch_size x 1 x hw
"""
assert v.dim() == 3
n, c, hw = v.size()
return -torch.sum(torch.mul(v, torch.log2(v+1e-30)),dim=1) / np.log2(c)


def cross_entropy_2d(predict, target):
"""
Args:
predict:(n, c, h, w)
target:(n, h, w)
"""
assert not target.requires_grad
assert predict.dim() == 4
assert target.dim() == 3
assert predict.size(0) == target.size(0), f"{predict.size(0)} vs {target.size(0)}"
assert predict.size(2) == target.size(1), f"{predict.size(2)} vs {target.size(1)}"
assert predict.size(3) == target.size(2), f"{predict.size(3)} vs {target.size(3)}"
n, c, h, w = predict.size()
target_mask = (target >= 0) * (target != 255)
target = target[target_mask]
if not target.data.dim():
return Variable(torch.zeros(1))
predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
loss = F.cross_entropy(predict, target, size_average=True)
return loss

+ 47
- 0
advent/msc.py View File

#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2018-03-26

import torch
import torch.nn as nn
import torch.nn.functional as F


class MSC(nn.Module):
"""
Multi-scale inputs
"""

def __init__(self, base, scales=None):
super(MSC, self).__init__()
self.base = base
if scales:
self.scales = scales
else:
self.scales = [0.5, 0.75]

def forward(self, x):
# Original
logits = self.base(x)
_, _, H, W = logits.shape
interp = lambda l: F.interpolate(
l, size=(H, W), mode="bilinear", align_corners=False
)

# Scaled
logits_pyramid = []
for scale in self.scales:
h = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
logits_pyramid.append(self.base(h))

# Pixel-wise max
logits_all = [logits] + [interp(l) for l in logits_pyramid]
logits_max = torch.max(torch.stack(logits_all), dim=0)[0]

shapes = [logits.shape] + [out.shape for out in logits_pyramid] + [logits_max.shape]
outputs = [logits] + logits_pyramid + [logits_max]
out = torch.cat([torch.flatten(output,start_dim=2) for output in outputs],dim=2)
return logits , out , shapes

+ 47
- 0
advent/msc_confident.py View File

#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2018-03-26

import torch
import torch.nn as nn
import torch.nn.functional as F


class MSC(nn.Module):
"""
Multi-scale inputs
"""

def __init__(self, base, scales=None):
super(MSC, self).__init__()
self.base = base
if scales:
self.scales = scales
else:
self.scales = [0.5, 0.75]

def forward(self, x):
# Original
logits = self.base(x)
_, _, H, W = logits.shape
interp = lambda l: F.interpolate(
l, size=(H, W), mode="bilinear", align_corners=False
)

# Scaled
logits_pyramid = []
for scale in self.scales:
h = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
logits_pyramid.append(self.base(h))

# Pixel-wise max
logits_all = [logits] + [interp(l) for l in logits_pyramid]
logits_max = torch.max(torch.stack(logits_all), dim=0)[0]
logits_confident = torch.min(torch.stack(logits_all), dim=0)[0]
shapes = [logits.shape] + [out.shape for out in logits_pyramid] + [logits_max.shape]
outputs = [logits] + logits_pyramid + [logits_confident]
out = torch.cat([torch.flatten(output,start_dim=2) for output in outputs],dim=2)
return logits , out , shapes

+ 148
- 0
advent/resnet.py View File

#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2017-11-19

from __future__ import absolute_import, print_function

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from mlassistant.core import Model

try:
from encoding.nn import SyncBatchNorm

_BATCH_NORM = SyncBatchNorm
except:
_BATCH_NORM = nn.BatchNorm2d

_BOTTLENECK_EXPANSION = 4


class ConvBlock(Model):
"""
Cascade of 2D convolution, batch norm, and ReLU.
"""

BATCH_NORM = _BATCH_NORM

def __init__(self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(
in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False)
self.bn = _BATCH_NORM(out_ch, eps=1e-5, momentum=1 - 0.999)
self.relu_activation = nn.ReLU()
self.relu = relu

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.relu:
x = self.relu_activation(x)
return x


class ResnetBlock(nn.Module):
"""
Bottleneck block of MSRA ResNet.
"""

def __init__(self, in_ch, out_ch, stride, dilation, downsample):
super(ResnetBlock, self).__init__()
mid_ch = out_ch // _BOTTLENECK_EXPANSION
self.reduce = ConvBlock(in_ch, mid_ch, 1, stride, 0, 1, True)
self.conv3x3 = ConvBlock(mid_ch, mid_ch, 3, 1, dilation, dilation, True)
self.increase = ConvBlock(mid_ch, out_ch, 1, 1, 0, 1, False)
self.shortcut = (
ConvBlock(in_ch, out_ch, 1, stride, 0, 1, False) if downsample else nn.Identity())

def forward(self, x):
h = self.reduce(x)
h = self.conv3x3(h)
h = self.increase(h)
h = h + self.shortcut(x)
return F.relu(h)


class ResnetLayer(nn.Module):
"""
Residual layer with multi grids
"""

def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None):
super(ResnetLayer, self).__init__()
self.layers = nn.ModuleList()
if multi_grids is None:
multi_grids = [1 for _ in range(n_layers)]
else:
assert n_layers == len(multi_grids)

for i in range(n_layers):
self.layers.append(
ResnetBlock(
in_ch=(in_ch if i == 0 else out_ch),
out_ch=out_ch,
stride=(stride if i == 0 else 1),
dilation=dilation * multi_grids[i],
downsample=(True if i == 0 else False)))

def forward(self, x):
for layer in self.layers:
x = layer(x)
return x


class PoolBlock(nn.Module):
"""
The 1st conv layer.
Note that the max pooling is different from both MSRA and FAIR ResNet.
"""

def __init__(self, out_ch):
super(PoolBlock, self).__init__()
self.conv1 = ConvBlock(1, out_ch, 7, 2, 3, 1)
self.pool = nn.MaxPool2d(3, 2, 1, ceil_mode=True)

def forward(self, x):
x = self.conv1(x)
x = self.pool(x)
return x


class ResNet(nn.Module):

def __init__(self, n_classes, n_blocks):
super(ResNet, self).__init__()
ch = [64 * 2**p for p in range(6)]
self.layer1 = PoolBlock(ch[0])
self.layer2 = ResnetLayer(n_blocks[0], ch[0], ch[2], 1, 1)
self.layer3 = ResnetLayer(n_blocks[1], ch[2], ch[3], 2, 1)
self.layer4 = ResnetLayer(n_blocks[2], ch[3], ch[4], 2, 1)
self.pool5 = nn.AdaptiveAvgPool2d(1)
self.flatten = nn.Flatten()
self.fc = nn.Linear(ch[5], n_classes)

def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.pool5(x)
x = self.flatten(x)
x = self.fc(x)
return x


if __name__ == "__main__":
model = ResNet(n_classes=1000, n_blocks=[3, 4, 23, 3])
model.eval()
image = torch.randn(1, 3, 224, 224)

print(model)
print("input:", image.shape)
print("output:", model(image).shape)

+ 79
- 0
base_line/ConvUnext/convUnext_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class ConvUnextv1(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(ConvUnextv1, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base(mammo_x)
network_output = nn.LogSoftmax(dim=1)(network_output).to(mammo_x.device) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = focal_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 84
- 0
base_line/Convnext/convnext_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class Convnext(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(Convnext, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
print("before model")
network_output = self.base._forward_features(mammo_x) # logsoftmax # B
print(network_output[0].shape)
print(network_output[1].shape)
print(network_output[2].shape)
print(network_output[3].shape)
print(network_output.shape)
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = focal_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 80
- 0
base_line/Deepvlab3/deepvlab3_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class DEEPVLAB3(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(DEEPVLAB3, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base(mammo_x.expand(mammo_x.shape[0],3,mammo_x.shape[2],mammo_x.shape[3]).to(mammo_x.device))["out"] # logsoftmax # B
network_output = nn.LogSoftmax(dim=1)(network_output).to(mammo_x.device) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = focal_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 0
- 0
base_line/Resunet/Heavy/Models/__init__.py View File


+ 180
- 0
base_line/Resunet/Heavy/Models/resunet.py View File

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform


def kaiming_normal_init_weight(model):
for m in model.modules():
if isinstance(m, nn.Conv3d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return model


def sparse_init_weight(model):
for m in model.modules():
if isinstance(m, nn.Conv3d):
torch.nn.init.sparse_(m.weight, sparsity=0.1)
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return model


class ConvBlock(nn.Module):
"""two convolution layers with batch norm and leaky relu"""

def __init__(self, in_channels, out_channels, dropout_p):
super(ConvBlock, self).__init__()
self.conv_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(),
nn.Dropout(dropout_p),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU()
)

def forward(self, x):
return self.conv_conv(x)


class DownBlock(nn.Module):
"""Downsampling followed by ConvBlock"""

def __init__(self, in_channels, out_channels, dropout_p):
super(DownBlock, self).__init__()
self._conv = ConvBlock(in_channels, out_channels, dropout_p)
self.maxpool = nn.MaxPool2d(2)

def forward(self, x, resx , res = True):
if res:
return self.maxpool(self._conv(x) + resx)
return self.maxpool(self._conv(x))


class UpBlock(nn.Module):
"""Upssampling followed by ConvBlock"""

def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
bilinear=True):
super(UpBlock, self).__init__()
self.bilinear = bilinear
if bilinear:
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
self.up = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(
in_channels1, in_channels2, kernel_size=2, stride=2)
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)
self.res = nn.Conv2d(in_channels2 * 2, out_channels,1)

def forward(self, x1, x2):
if self.bilinear:
x1 = self.conv1x1(x1)
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)

return self.conv(x) + self.res(x) # for resudial blocks


class Encoder(nn.Module):
def __init__(self, params):
super(Encoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
self.dropout = self.params['dropout']
assert (len(self.ft_chns) == 5)
self.in_conv = ConvBlock(
self.in_chns, self.ft_chns[0], self.dropout[0])
self.down1 = DownBlock(
self.ft_chns[0], self.ft_chns[1], self.dropout[1])
self.down2 = DownBlock(
self.ft_chns[1], self.ft_chns[2], self.dropout[2])
self.down3 = DownBlock(
self.ft_chns[2], self.ft_chns[3], self.dropout[3])
self.down4 = DownBlock(
self.ft_chns[3], self.ft_chns[4], self.dropout[4])
self.down_res = nn.ModuleList([nn.Conv2d(self.in_chns,self.ft_chns[0],1),
nn.Conv2d(self.ft_chns[0],self.ft_chns[1],1),
nn.Conv2d(self.ft_chns[1],self.ft_chns[2],1),
nn.Conv2d(self.ft_chns[2],self.ft_chns[3],1)])

def forward(self, x):
x0 = self.in_conv(x) + self.down_res[0](x)
x1 = self.down1(x0,self.down_res[1](x0))
x2 = self.down2(x1,self.down_res[2](x1))
x3 = self.down3(x2,self.down_res[3](x2))
x4 = self.down4(x3,None,False)
return [x0, x1, x2, x3, x4]


class Decoder(nn.Module):
def __init__(self, params):
super(Decoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
assert (len(self.ft_chns) == 5)

self.up1 = UpBlock(
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
self.up2 = UpBlock(
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
self.up3 = UpBlock(
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
self.up4 = UpBlock(
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
kernel_size=3, padding=1)

def forward(self, feature):
x0 = feature[0]
x1 = feature[1]
x2 = feature[2]
x3 = feature[3]
x4 = feature[4]

xup0 = self.up1(x4, x3)
xup1 = self.up2(xup0, x2)
xup2 = self.up3(xup1, x1)
xup3 = self.up4(xup2, x0)

output = self.out_conv(xup3)
return output


class RESUNet(nn.Module):
def __init__(self, in_chns, class_num):
super(RESUNet, self).__init__()

params = {'in_chns': in_chns,
'feature_chns': [64, 128, 128, 256, 512],
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
'class_num': class_num,
'bilinear': False,
'acti_func': 'relu'}

self.encoder = Encoder(params)
self.decoder = Decoder(params)

def forward(self, x, need_fp=False):
feature = self.encoder(x)
if need_fp:
outs = self.decoder([torch.cat((feat, nn.Dropout2d(0.5)(feat))) for feat in feature])
return outs.chunk(2)
output = self.decoder(feature)
return output

+ 0
- 0
base_line/Resunet/Heavy/__init__.py View File


+ 50
- 0
base_line/Resunet/Heavy/resunet_ddsm.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss
from ......utils.generalized_dice import dice_loss
from ......utils.losses.tverskyLoss import tversky_loss

from ......utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from ......utils.dynamic_num_generator.wrapper import dynamic_num_generator


class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_roi_x: torch.Tensor) -> ModelIO:

output = dict(loss=0.0)
network_output = self.base.forward(mammo_x).softmax(dim=1)
network_output_soft = network_output
output['pixel_probs'] = network_output_soft

focal_loss = balanced_focal_cross_entropy_loss(
network_output_soft,
mammo_roi_x,
focal_gamma=self.gamma)

output['loss'] = focal_loss
output['org_pixel_labels'] = mammo_roi_x

return output

+ 79
- 0
base_line/Resunet/Heavy/resunet_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ......utils.generalized_dice import dice_loss
from ......utils.losses.tverskyLoss import tversky_loss

from ......utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from ......utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base.forward(mammo_x).softmax(dim=1) # logsoftmax # B
network_output_soft = network_output #torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = focal_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 107
- 0
base_line/Resunet/Heavy/resunet_focal_ent.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ......utils.generalized_dice import dice_loss
from ......utils.losses.ent_losses import entropy_loss_normalized
from ......utils.losses.tverskyLoss import tversky_loss

from ......utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from ......utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 2 , smooth=1, alpha=0.7, beta=0.3 , ent_factor = lambda x : 0.1) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
self.ent_factor = ent_factor
print(self.gamma)
self.epochNUM = 1
print(self.epochNUM)

def train(self,mode : bool=True):
super().train(mode)
self.epochNUM += 1

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base.forward(mammo_x).softmax(dim=1) # logsoftmax # B
network_output_soft = network_output #torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
# dummy_yhat = network_output_soft * loss_channels
# dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
#dice = dice_loss(dummy_net_out,dummy_mask)
#tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
#ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
entropy_loss_val = entropy_loss_normalized(network_output_soft)
entropy_suploss_val = torch.mean(entropy_loss_val * loss_type)
entropy_unsuploss_val = torch.mean(entropy_loss_val * (1-loss_type))
sup_ratio = torch.sum(loss_type == 1)
all_cel = torch.prod(torch.tensor(loss_type.shape)).to(mammo_x.device)
entropy_sup = entropy_suploss_val * (all_cel/sup_ratio)
entropy_unsup = entropy_unsuploss_val * (all_cel/(all_cel - sup_ratio + 1e-30))
entropy_loss_val = torch.mean(entropy_loss_val)
entloss = torch.mean(torch.tensor([entropy_sup, entropy_unsup])).to(mammo_x.device)

coef = self.ent_factor(int(self.epochNUM/2))
#output['ce_loss'] = ce_loss
output['loss'] = focal_loss + coef * entropy_loss_val
output['totalloss'] = output['loss']
#output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = focal_loss
output['unsuploss'] = entropy_loss_val
output['coef_loss'] = torch.tensor(coef).to(mammo_x.device)
#output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
output["entallloss"] = entropy_loss_val
output["supentloss"] = entropy_sup
output["unsupentloss"] = entropy_unsup
output["entloss"] = entloss
return output

+ 96
- 0
base_line/Resunet/Heavy/resunet_topk_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ......utils.generalized_dice import dice_loss
from ......utils.losses.tverskyLoss import tversky_loss
from ......utils.losses.region_topk_focal_ce import calculate_topk_focal_ce_loss_per_region as topk

from ......utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from ......utils.dynamic_num_generator.wrapper import dynamic_num_generator






class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3, gskernel = None) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
self.gskernel = gskernel
self.epochNUM = 1
print(self.gamma)
print("gamma")

def train(self,mode : bool=True):
super().train(mode)
self.epochNUM += 1

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base.forward(mammo_x).softmax(dim=1) # logsoftmax # B
network_output_soft = network_output #torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft


main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_yhat,dummy_y)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
topk_loss = topk(model_preds=dummy_yhat,roi_mask = dummy_y[:,0,...], fill_ratio=0.5, focal_gamma=self.gamma)
coef = self.gskernel(int(self.epochNUM/2))


output['ce_loss'] = ce_loss
output['topk_loss'] = topk_loss
output['gscoef_loss'] = torch.tensor(coef).to(mammo_x.device)
output['loss'] = coef * topk_loss + (1-coef) * focal_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 0
- 0
base_line/Resunet/__init__.py View File


+ 80
- 0
base_line/Resunet/resunet_ce.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 2 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_1_loss'] = F.binary_cross_entropy(dummy_yhat * dummy_y,dummy_y)
output['ce_0_loss'] = F.binary_cross_entropy(dummy_yhat * (1 - dummy_y) + dummy_y, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = ce_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 80
- 0
base_line/Resunet/resunet_ce_mean.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 2 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_1_loss'] = F.binary_cross_entropy(dummy_yhat * dummy_y,dummy_y)
output['ce_0_loss'] = F.binary_cross_entropy(dummy_yhat * (1 - dummy_y) + dummy_y, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = output['ce_1_loss'] + output['ce_0_loss']
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 80
- 0
base_line/Resunet/resunet_dice.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 2 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_1_loss'] = F.binary_cross_entropy(dummy_yhat * dummy_y,dummy_y)
output['ce_0_loss'] = F.binary_cross_entropy(dummy_yhat * (1 - dummy_y) + dummy_y, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = dice
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = dice
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 79
- 0
base_line/Resunet/resunet_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = focal_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 107
- 0
base_line/Resunet/resunet_focal_ent.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.ent_losses import entropy_loss_normalized
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 2 , smooth=1, alpha=0.7, beta=0.3 , ent_factor = lambda x : 0.1) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
self.ent_factor = ent_factor
print(self.gamma)
self.epochNUM = 200
print(self.epochNUM)

def train(self,mode : bool=True):
super().train(mode)
self.epochNUM += 1

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
entropy_loss_val = entropy_loss_normalized(network_output_soft)
entropy_suploss_val = torch.mean(entropy_loss_val * loss_type)
entropy_unsuploss_val = torch.mean(entropy_loss_val * (1-loss_type))
sup_ratio = torch.sum(loss_type == 1)
all_cel = torch.prod(torch.tensor(loss_type.shape)).to(mammo_x.device)
entropy_sup = entropy_suploss_val * (all_cel/sup_ratio)
entropy_unsup = entropy_unsuploss_val * (all_cel/(all_cel - sup_ratio + 1e-30))
entropy_loss_val = torch.mean(entropy_loss_val)
entloss = torch.mean(torch.tensor([entropy_sup, entropy_unsup])).to(mammo_x.device)

coef = self.ent_factor(int(self.epochNUM/2))
output['ce_loss'] = ce_loss
output['loss'] = focal_loss + coef * entropy_loss_val
output['totalloss'] = output['loss']
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = focal_loss
output['unsuploss'] = entropy_loss_val
output['coef_loss'] = torch.tensor(coef).to(mammo_x.device)
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
output["entallloss"] = entropy_loss_val
output["supentloss"] = entropy_sup
output["unsupentloss"] = entropy_unsup
output["entloss"] = entloss
return output

+ 94
- 0
base_line/Resunet/resunet_tilted.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi_pixel_wise
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
t = -2 , smooth=1, alpha=0.7, beta=0.3 , gamma = 2) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.t = t
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.t , gamma)
self.gamma = gamma
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma).to(mammo_x.device)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
ce_per,losses = balanced_focal_cross_entropy_loss_semi_pixel_wise(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
ce_per = ce_per.to(mammo_x.device)
tilted = 1/self.t * torch.log(torch.mean(torch.exp(self.t * ce_per))).to(mammo_x.device)
til0 = 1/self.t * torch.log(torch.mean(torch.exp(self.t * losses[0].to(mammo_x.device)))).to(mammo_x.device) if len(losses) > 0 else torch.tensor(0)
til1 = 1/self.t * torch.log(torch.mean(torch.exp(self.t * losses[1].to(mammo_x.device)))).to(mammo_x.device) if len(losses) > 1 else torch.tensor(0)
til1,til0 = til1.to(mammo_x.device), til0.to(mammo_x.device)

output['ce_loss'] = ce_loss
output["tilted_loss"] = tilted
output['loss'] = torch.mean(torch.stack([til0,til1])).to(mammo_x.device) + focal_loss
output["til0_loss"] = til0
output["til1_loss"] = til1
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 85
- 0
base_line/Resunet/resunet_topk.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss
from .....utils.losses.region_topk_focal_ce import calculate_topk_focal_ce_loss_per_region as topk

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_yhat,dummy_y)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
topk_loss = topk(model_preds=dummy_yhat,roi_mask = dummy_y[:,0,...], fill_ratio=0.5, focal_gamma=self.gamma)


output['ce_loss'] = ce_loss
output['topk_loss'] = topk_loss
output['loss'] = topk_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 87
- 0
base_line/Resunet/resunet_topk_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss
from .....utils.losses.region_topk_focal_ce import calculate_topk_focal_ce_loss_per_region as topk

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator






class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_yhat,dummy_y)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
topk_loss = topk(model_preds=dummy_yhat,roi_mask = dummy_y[:,0,...], fill_ratio=0.5, focal_gamma=self.gamma)


output['ce_loss'] = ce_loss
output['topk_loss'] = topk_loss
output['loss'] = topk_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 85
- 0
base_line/Resunet/resunet_tvfocal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss
from .....utils.losses.focal_tversky import focal_tversky
from .....utils.losses.region_topk_focal_ce import calculate_topk_focal_ce_loss_per_region as topk

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
# focal_loss = balanced_focal_cross_entropy_loss_semi(
# dummy_net_out,
# dummy_mask,
# focal_gamma=self.gamma)
dice = dice_loss(dummy_yhat,dummy_y)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
#topk_loss = topk(model_preds=dummy_yhat,roi_mask = dummy_y[:,0,...], fill_ratio=0.5, focal_gamma=self.gamma)
tvfocal_loss = focal_tversky(dummy_y[:,1,...].flatten(),dummy_yhat[:,1,...].flatten())

output['ce_loss'] = ce_loss
output['tvfocal_loss'] = tvfocal_loss
output['loss'] = tvfocal_loss
output['diceloss'] = dice
#output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 174
- 0
base_line/UNET/Model/unet.py View File

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform


def kaiming_normal_init_weight(model):
for m in model.modules():
if isinstance(m, nn.Conv3d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return model


def sparse_init_weight(model):
for m in model.modules():
if isinstance(m, nn.Conv3d):
torch.nn.init.sparse_(m.weight, sparsity=0.1)
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return model


class ConvBlock(nn.Module):
"""two convolution layers with batch norm and leaky relu"""

def __init__(self, in_channels, out_channels, dropout_p):
super(ConvBlock, self).__init__()
self.conv_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(),
nn.Dropout(dropout_p),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU()
)

def forward(self, x):
return self.conv_conv(x)


class DownBlock(nn.Module):
"""Downsampling followed by ConvBlock"""

def __init__(self, in_channels, out_channels, dropout_p):
super(DownBlock, self).__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, dropout_p)

)

def forward(self, x):
return self.maxpool_conv(x)


class UpBlock(nn.Module):
"""Upssampling followed by ConvBlock"""

def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
bilinear=True):
super(UpBlock, self).__init__()
self.bilinear = bilinear
if bilinear:
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
self.up = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(
in_channels1, in_channels2, kernel_size=2, stride=2)
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

def forward(self, x1, x2):
if self.bilinear:
x1 = self.conv1x1(x1)
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)


class Encoder(nn.Module):
def __init__(self, params):
super(Encoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
self.dropout = self.params['dropout']
assert (len(self.ft_chns) == 5)
self.in_conv = ConvBlock(
self.in_chns, self.ft_chns[0], self.dropout[0])
self.down1 = DownBlock(
self.ft_chns[0], self.ft_chns[1], self.dropout[1])
self.down2 = DownBlock(
self.ft_chns[1], self.ft_chns[2], self.dropout[2])
self.down3 = DownBlock(
self.ft_chns[2], self.ft_chns[3], self.dropout[3])
self.down4 = DownBlock(
self.ft_chns[3], self.ft_chns[4], self.dropout[4])

def forward(self, x):
x0 = self.in_conv(x)
x1 = self.down1(x0)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
return [x0, x1, x2, x3, x4]


class Decoder(nn.Module):
def __init__(self, params):
super(Decoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
assert (len(self.ft_chns) == 5)

self.up1 = UpBlock(
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
self.up2 = UpBlock(
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
self.up3 = UpBlock(
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
self.up4 = UpBlock(
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
kernel_size=3, padding=1)

def forward(self, feature):
x0 = feature[0]
x1 = feature[1]
x2 = feature[2]
x3 = feature[3]
x4 = feature[4]

x = self.up1(x4, x3)
x = self.up2(x, x2)
x = self.up3(x, x1)
x = self.up4(x, x0)
output = self.out_conv(x)
return output


class UNet(nn.Module):
def __init__(self, in_chns, class_num):
super(UNet, self).__init__()

params = {'in_chns': in_chns,
'feature_chns': [64, 128, 128, 256, 512],
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
'class_num': class_num,
'bilinear': False,
'acti_func': 'relu'}

self.encoder = Encoder(params)
self.decoder = Decoder(params)

def forward(self, x, need_fp=False):
feature = self.encoder(x)
if need_fp:
outs = self.decoder([torch.cat((feat, nn.Dropout2d(0.5)(feat))) for feat in feature])
return outs.chunk(2)
output = self.decoder(feature)
return output

+ 79
- 0
base_line/UNET/resunet_focal.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
from .....utils.losses.tverskyLoss import tversky_loss

from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta
print(self.gamma)
print("gamma")

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = focal_loss
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = output['loss']
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 156
- 0
base_line/res_inception.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch import nn
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.generalized_dice import dice_loss

from ....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from ....utils.dynamic_num_generator.wrapper import dynamic_num_generator

class UncertaintyAwareStudentMeanTeacherNet(Model):
""" adapted from the paper 'Uncertainty-aware Self-ensembling Model for Semi-supervised
3D Left Atrium Segmentation'. link: https://arxiv.org/pdf/1907.07034.pdf

Args:
alpha (float): exponential moving average decay
total_num_iteration (int): total number of iteration this run will have
base_model (Model): base networks for both student and teacher
x_arg (str): argument name of `forward` indicating images, shape: B ...
x_roi_arg (str): argument name of `forward` indicating ground-truth ROIs, shape: B ...
x_annotated_arg (str): argument name of `forward` indicating being labeled or not for inputs, shape: B
uncertainty_ratio_thd (float): the initial ratio threshold for uncertainty pruning
std (float): standard deviation of gaussian noise to be applied to the input
mean (float): mean of gaussian noise to be applied to the input
T (int): number of perturbations for teacher net
lambda_coef (float): lambda coefficient for calculating unsupervised loss
use_supervised_in_consistency_checking (bool): if true then use supervised samples in unsupervised task (consistency checking)
"""
def __init__(self,
alpha: float,
base_model: Model,
std: float = math.sqrt(0.0001),
mean: float = 0.0,
T: int=1,
use_supervised_in_consistency_checking: bool = True,
uncertainty_thd_calculator=dynamic_num_generator(num_iters=200, func=init_num_generator2()),
unsup_coef_generator=dynamic_num_generator(num_iters=200, func=init_num_generator1()),
gamma = 2) -> None:
super(UncertaintyAwareStudentMeanTeacherNet, self).__init__()

assert alpha >= 0 and alpha <= 1, f'alpha should be in range [0, 1]; got {alpha} instead.'
assert T > 0, f'number of perturbations to be applied for teacher net should be at least one; got {T} instead.'

self._T: int = T
self._alpha: float = alpha
self._mean: float = mean
self._std: float = std
self._use_supervised_in_consistency_checking: bool = use_supervised_in_consistency_checking
self._uncertainty_thd_calculator = uncertainty_thd_calculator
self._unsup_coef_generator = unsup_coef_generator

self._student: Model = base_model
self._teacher: Model = deepcopy(self._student)
self.gamma = gamma

def _update_teacher_params(self,
_: Model,
__: torch.Tensor) -> None:
"""a backward hook to update teacher network parameters, using exponential moving
average method."""
for s_p, t_p in zip(self._student.parameters(), self._teacher.parameters()):
t_p.data = self._alpha * t_p.data + (1 - self._alpha) * s_p.data

def _apply_gaussian_noise(self,
input ) -> torch.Tensor:
inp = input
noise = torch.randn(inp.size()) * self._std + self._mean
noise = noise.to(inp.device)
inp = inp + (noise * inp)
input = inp
return input

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
if mammo_loss_and_gt is None:
output['loss'] = torch.tensor(0.0)
output['loss'] = torch.tensor(0.0)
output['mse_loss'] = torch.tensor(0.0)
output['suploss'] = torch.tensor(0.0)
output['focalloss'] = torch.tensor(0.0)
#output['org_pixel_labels'] = output['pixel_probs']
# output['ce_loss'] = torch.tensor(0.0)
# output['ce_0_loss'] = torch.tensor(0.0)
# output['ce_1_loss'] = torch.tensor(0.0)
# output['dice_loss'] = torch.tensor(0.0)
return output
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"

loss_type_lable = mammo_loss_and_gt[:,0,:,:] # 1 means supervised and zero means unsupervised
mask = mammo_loss_and_gt[:,1,:,:] # 1 means cancer and zero means backgroundru ru network_output = self._student._forward(mammo_x)
network_output = self._student._forward(mammo_x) # torch.nn.LogSoftmax has been applied on last layer
network_output_soft = torch.exp(network_output).to(mammo_x.device)
output['pixel_probs'] = network_output_soft
main_shape = loss_type_lable.shape
network_output_shape = network_output.shape
dup_loss_type = loss_type_lable.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
dup_mask = torch.cat((1-mask, mask), 1)
supervised_loss = torch.tensor(0.0).to(mammo_x.device)

focal_mask = dup_mask.clone().to(mammo_x.device)
focal_net_out = network_output_soft.clone().to(mammo_x.device)
focal_mask[dup_loss_type == 0] = 0.0
focal_shape = focal_mask.shape
focal_mask = torch.cat((focal_mask,torch.ones((focal_shape[0],1,focal_shape[2],focal_shape[2])).to(mammo_x.device)), 1)
focal_net_out = torch.cat((focal_net_out,torch.ones((focal_shape[0],1,focal_shape[2],focal_shape[2])).to(mammo_x.device)), 1)

supervised_loss += balanced_focal_cross_entropy_loss_semi(
focal_net_out,
focal_mask,
focal_gamma=self.gamma)

dice = dice_loss(focal_net_out,focal_mask)
with torch.no_grad():
teacher_outs: List[torch.Tensor] = list()
for _ in range(self._T):
teacher_outs.append(self._teacher._forward(
self._apply_gaussian_noise(mammo_x)))
out_t_u = torch.exp(torch.stack(teacher_outs, dim=2)).mean(dim=2) # B C H' W'
uncertainty = -1 * torch.sum(out_t_u * torch.log(out_t_u), 1) # B H' W'
mse = torch.sum((network_output_soft - out_t_u) ** 2, dim=1)
# consider only most certain pixels for mse loss
mse_ = torch.zeros_like(mse)
UThreshhold = self._uncertainty_thd_calculator()
mask_mse = uncertainty < UThreshhold
mse_[mask_mse] += mse[mask_mse]
active_pxs = torch.sum(mask_mse.long(), dim=[1, 2]) # B
mse_loss = torch.sum(mse_ / active_pxs[:, None, None], dim=[1, 2])
mse_loss = torch.mean(mse_loss)
mse_coef = self._unsup_coef_generator()
output['loss'] = mse_coef * mse_loss
output['mse_loss'] = mse_loss
output['mse_loss'] = mse_coef
output['Uthreshhold'] = UThreshhold
output['suploss'] = supervised_loss
output['focalloss'] = supervised_loss
output['org_pixel_labels'] = mammo_loss_and_gt
# output['ce_loss'] = torch.tensor(0.0)
# output['ce_0_loss'] = torch.tensor(0.0)
# output['ce_1_loss'] = torch.tensor(0.0)
# output['dice_loss'] = torch.tensor(0.0)
return output

+ 80
- 0
base_line/resunet.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
import time
from mlassistant.core import Model, ModelIO
import torch
from torch.nn import functional as F
from torch import nn
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.generalized_dice import dice_loss
from ....utils.losses.tverskyLoss import tversky_loss

from ....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from ....utils.dynamic_num_generator.wrapper import dynamic_num_generator





class RESUNET(Model):
def __init__(self,
base_model: Model,
gamma = 2 , smooth=1, alpha=0.7, beta=0.3) -> None:
super(RESUNET, self).__init__()

self.base = base_model
self.gamma = gamma
self.smooth = smooth
self.alpha = alpha
self.beta = beta

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self.base._forward(mammo_x) # logsoftmax # B
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft

main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
output['ce_1_loss'] = F.binary_cross_entropy(dummy_yhat * dummy_y,dummy_y)
output['ce_0_loss'] = F.binary_cross_entropy(dummy_yhat * (1 - dummy_y) + dummy_y, dummy_y)
output['ce_loss'] = ce_loss
output['loss'] = 1/2 *(dice + focal_loss)
output['diceloss'] = dice
output['focalloss'] = focal_loss
output['suploss'] = dice
output['tverskyloss'] = tversky
output['org_pixel_labels'] = mammo_loss_and_gt
return output

+ 0
- 0
base_line/transunet.py View File


+ 396
- 0
segmentation/Baseline/TransUNet/transunet_focal.py View File

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair

from mlassistant.core import ModelIO, Model
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ......utils.generalized_dice import dice_loss
from ......utils.losses.tverskyLoss import tversky_loss


def np2th(weights, conv=False):
"""Possibly convert HWIO to OIHW."""
if conv:
weights = weights.transpose([3, 2, 0, 1])
return torch.from_numpy(weights)


def swish(x):
return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
def __init__(self, config, vis):
super(Attention, self).__init__()
self.vis = vis
self.num_attention_heads = config['transformer']["num_heads"]
self.attention_head_size = int(config['hidden_size'] / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = Linear(config['hidden_size'], self.all_head_size)
self.key = Linear(config['hidden_size'], self.all_head_size)
self.value = Linear(config['hidden_size'], self.all_head_size)

self.out = Linear(config['hidden_size'], config['hidden_size'])
self.attn_dropout = Dropout(config['transformer']["attention_dropout_rate"])
self.proj_dropout = Dropout(config['transformer']["attention_dropout_rate"])

self.softmax = Softmax(dim=-1)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)

query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)

attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
weights = attention_probs if self.vis else None
attention_probs = self.attn_dropout(attention_probs)

context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
attention_output = self.out(context_layer)
attention_output = self.proj_dropout(attention_output)
return attention_output, weights


class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
self.fc1 = Linear(config['hidden_size'], config['transformer']["mlp_dim"])
self.fc2 = Linear(config['transformer']["mlp_dim"], config['hidden_size'])
self.act_fn = ACT2FN["gelu"]
self.dropout = Dropout(config['transformer']["dropout_rate"])

self._init_weights()

def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)

def forward(self, x):
x = self.fc1(x)
x = self.act_fn(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x


class Embeddings(nn.Module):
"""Construct the embeddings from patch, position embeddings.
"""
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
self.hybrid = None
self.config = config
img_size = _pair(img_size)

if config['patches'].get("grid") is not None: # ResNet
grid_size = config['patches']["grid"]
patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
self.hybrid = True
else:
patch_size = _pair(config['patches']["size"])
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
self.hybrid = False

self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config['hidden_size'],
kernel_size=patch_size,
stride=patch_size)
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config['hidden_size']))

self.dropout = Dropout(config['transformer']["dropout_rate"])


def forward(self, x):
if self.hybrid:
x, features = self.hybrid_model(x)
else:
features = None
x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2))
x = x.flatten(2)
x = x.transpose(-1, -2) # (B, n_patches, hidden)

embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, features


class Block(nn.Module):
def __init__(self, config, vis):
super(Block, self).__init__()
self.hidden_size = config['hidden_size']
self.attention_norm = LayerNorm(config['hidden_size'], eps=1e-6)
self.ffn_norm = LayerNorm(config['hidden_size'], eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config, vis)

def forward(self, x):
h = x
x = self.attention_norm(x)
x, weights = self.attn(x)
x = x + h

h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
return x, weights


class Encoder(nn.Module):
def __init__(self, config, vis):
super(Encoder, self).__init__()
self.vis = vis
self.layer = nn.ModuleList()
self.encoder_norm = LayerNorm(config['hidden_size'], eps=1e-6)
for _ in range(config['transformer']["num_layers"]):
layer = Block(config, vis)
self.layer.append(copy.deepcopy(layer))

def forward(self, hidden_states):
attn_weights = []
for layer_block in self.layer:
hidden_states, weights = layer_block(hidden_states)
if self.vis:
attn_weights.append(weights)
encoded = self.encoder_norm(hidden_states)
return encoded, attn_weights


class Transformer(nn.Module):
def __init__(self, config, img_size, vis):
super(Transformer, self).__init__()
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config, vis)

def forward(self, input_ids):
embedding_output, features = self.embeddings(input_ids)
encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
return encoded, attn_weights, features


class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)

bn = nn.BatchNorm2d(out_channels)

super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
skip_channels=0,
use_batchnorm=True,
):
super().__init__()
self.conv1 = Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.up = nn.UpsamplingBilinear2d(scale_factor=2)

def forward(self, x, skip=None):
x = self.up(x)
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x


class SegmentationHead(nn.Sequential):

def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
head_channels = 512
self.conv_more = Conv2dReLU(
config['hidden_size'],
head_channels,
kernel_size=3,
padding=1,
use_batchnorm=True,
)
decoder_channels = config['decoder_channels']
in_channels = [head_channels] + list(decoder_channels[:-1])
out_channels = decoder_channels

skip_channels=[0,0,0,0]

blocks = [
DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
]
self.blocks = nn.ModuleList(blocks)

def forward(self, hidden_states, features=None):
B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x = hidden_states.permute(0, 2, 1)
x = x.contiguous().view(B, hidden, h, w)
x = self.conv_more(x)
for i, decoder_block in enumerate(self.blocks):
if features is not None:
skip = features[i] if (i < self.config.n_skip) else None
else:
skip = None
x = decoder_block(x, skip=skip)
return x


class VisionTransformer(nn.Module):
def __init__(self, img_size=512, num_classes=2, zero_head=False, vis=False):
super(VisionTransformer, self).__init__()
config = self._get_b16_config()
self.num_classes = num_classes
self.zero_head = zero_head
self.transformer = Transformer(config, img_size, vis)
self.decoder = DecoderCup(config)
self.segmentation_head = SegmentationHead(
in_channels=config['decoder_channels'][-1],
out_channels=config['n_classes'],
kernel_size=3,
)
self.config = config

def forward(self, x):
if x.size()[1] == 1:
x = x.repeat(1,3,1,1)
x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
x = self.decoder(x, features)
logits = self.segmentation_head(x)
return F.softmax(logits, dim=1)
def _get_b16_config(self):
"""Returns the ViT-B/16 configuration."""
config = {}

config['patches'] = {'size': (16, 16)}
config['hidden_size'] = 768

config['transformer'] = {}
config['transformer']['mlp_dim'] = 3072
config['transformer']['num_heads'] = 12
config['transformer']['num_layers'] = 12
config['transformer']['attention_dropout_rate'] = 0.0
config['transformer']['dropout_rate'] = 0.1

config['decoder_channels'] = (512, 256, 128, 64)
config['n_classes'] = 2

return config


class TransUNet(Model):
"""
TransUNet architecture for baseline evaluation.
"""

def __init__(self):
super().__init__()
self.trans_unet = VisionTransformer()
self.trans_unet = torch.compile(self.trans_unet)

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type
out_prob_map = self.trans_unet(mammo_x)

main_shape = ground_truth.shape
out_prob_map_shape = out_prob_map.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out_prob_map.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels
dummy_y = mask_channels * loss_channels

dice = dice_loss(dummy_net_out,dummy_mask)
tversky = tversky_loss(inputs = dummy_yhat[:, 1, ...], targets = dummy_y[:, 1, ...], smooth=1, alpha = 0.7, beta=0.3)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)

output = {
'pixel_probs': out_prob_map,
'org_pixel_labels': mammo_loss_and_gt,
'loss': balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2),
'ce': ce_loss,
'dice_loss': dice,
'tversky': tversky
}

return output

+ 401
- 0
segmentation/Baseline/TransUNet/transunet_topk.py View File

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair

from mlassistant.core import ModelIO, Model
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ......utils.generalized_dice import dice_loss
from ......utils.losses.breast_topk_focal_ce import calculate_breast_topk_focal_ce_loss


def np2th(weights, conv=False):
"""Possibly convert HWIO to OIHW."""
if conv:
weights = weights.transpose([3, 2, 0, 1])
return torch.from_numpy(weights)


def swish(x):
return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
def __init__(self, config, vis):
super(Attention, self).__init__()
self.vis = vis
self.num_attention_heads = config['transformer']["num_heads"]
self.attention_head_size = int(config['hidden_size'] / self.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = Linear(config['hidden_size'], self.all_head_size)
self.key = Linear(config['hidden_size'], self.all_head_size)
self.value = Linear(config['hidden_size'], self.all_head_size)

self.out = Linear(config['hidden_size'], config['hidden_size'])
self.attn_dropout = Dropout(config['transformer']["attention_dropout_rate"])
self.proj_dropout = Dropout(config['transformer']["attention_dropout_rate"])

self.softmax = Softmax(dim=-1)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)

query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)

attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
weights = attention_probs if self.vis else None
attention_probs = self.attn_dropout(attention_probs)

context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
attention_output = self.out(context_layer)
attention_output = self.proj_dropout(attention_output)
return attention_output, weights


class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
self.fc1 = Linear(config['hidden_size'], config['transformer']["mlp_dim"])
self.fc2 = Linear(config['transformer']["mlp_dim"], config['hidden_size'])
self.act_fn = ACT2FN["gelu"]
self.dropout = Dropout(config['transformer']["dropout_rate"])

self._init_weights()

def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)

def forward(self, x):
x = self.fc1(x)
x = self.act_fn(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x


class Embeddings(nn.Module):
"""Construct the embeddings from patch, position embeddings.
"""
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
self.hybrid = None
self.config = config
img_size = _pair(img_size)

if config['patches'].get("grid") is not None: # ResNet
grid_size = config['patches']["grid"]
patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
self.hybrid = True
else:
patch_size = _pair(config['patches']["size"])
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
self.hybrid = False

self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config['hidden_size'],
kernel_size=patch_size,
stride=patch_size)
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config['hidden_size']))

self.dropout = Dropout(config['transformer']["dropout_rate"])


def forward(self, x):
if self.hybrid:
x, features = self.hybrid_model(x)
else:
features = None
x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2))
x = x.flatten(2)
x = x.transpose(-1, -2) # (B, n_patches, hidden)

embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, features


class Block(nn.Module):
def __init__(self, config, vis):
super(Block, self).__init__()
self.hidden_size = config['hidden_size']
self.attention_norm = LayerNorm(config['hidden_size'], eps=1e-6)
self.ffn_norm = LayerNorm(config['hidden_size'], eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config, vis)

def forward(self, x):
h = x
x = self.attention_norm(x)
x, weights = self.attn(x)
x = x + h

h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
return x, weights


class Encoder(nn.Module):
def __init__(self, config, vis):
super(Encoder, self).__init__()
self.vis = vis
self.layer = nn.ModuleList()
self.encoder_norm = LayerNorm(config['hidden_size'], eps=1e-6)
for _ in range(config['transformer']["num_layers"]):
layer = Block(config, vis)
self.layer.append(copy.deepcopy(layer))

def forward(self, hidden_states):
attn_weights = []
for layer_block in self.layer:
hidden_states, weights = layer_block(hidden_states)
if self.vis:
attn_weights.append(weights)
encoded = self.encoder_norm(hidden_states)
return encoded, attn_weights


class Transformer(nn.Module):
def __init__(self, config, img_size, vis):
super(Transformer, self).__init__()
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config, vis)

def forward(self, input_ids):
embedding_output, features = self.embeddings(input_ids)
encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
return encoded, attn_weights, features


class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)

bn = nn.BatchNorm2d(out_channels)

super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
skip_channels=0,
use_batchnorm=True,
):
super().__init__()
self.conv1 = Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.up = nn.UpsamplingBilinear2d(scale_factor=2)

def forward(self, x, skip=None):
x = self.up(x)
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x


class SegmentationHead(nn.Sequential):

def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
head_channels = 512
self.conv_more = Conv2dReLU(
config['hidden_size'],
head_channels,
kernel_size=3,
padding=1,
use_batchnorm=True,
)
decoder_channels = config['decoder_channels']
in_channels = [head_channels] + list(decoder_channels[:-1])
out_channels = decoder_channels

skip_channels=[0,0,0,0]

blocks = [
DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
]
self.blocks = nn.ModuleList(blocks)

def forward(self, hidden_states, features=None):
B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x = hidden_states.permute(0, 2, 1)
x = x.contiguous().view(B, hidden, h, w)
x = self.conv_more(x)
for i, decoder_block in enumerate(self.blocks):
if features is not None:
skip = features[i] if (i < self.config.n_skip) else None
else:
skip = None
x = decoder_block(x, skip=skip)
return x


class VisionTransformer(nn.Module):
def __init__(self, img_size=512, num_classes=2, zero_head=False, vis=False):
super(VisionTransformer, self).__init__()
config = self._get_b16_config()
self.num_classes = num_classes
self.zero_head = zero_head
self.transformer = Transformer(config, img_size, vis)
self.decoder = DecoderCup(config)
self.segmentation_head = SegmentationHead(
in_channels=config['decoder_channels'][-1],
out_channels=config['n_classes'],
kernel_size=3,
)
self.config = config

def forward(self, x):
if x.size()[1] == 1:
x = x.repeat(1,3,1,1)
x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
x = self.decoder(x, features)
logits = self.segmentation_head(x)
return F.softmax(logits, dim=1)
def _get_b16_config(self):
"""Returns the ViT-B/16 configuration."""
config = {}

config['patches'] = {'size': (16, 16)}
config['hidden_size'] = 768

config['transformer'] = {}
config['transformer']['mlp_dim'] = 3072
config['transformer']['num_heads'] = 12
config['transformer']['num_layers'] = 12
config['transformer']['attention_dropout_rate'] = 0.0
config['transformer']['dropout_rate'] = 0.1

config['decoder_channels'] = (512, 256, 128, 64)
config['n_classes'] = 2

return config


class TransUNet(Model):
"""
TransUNet architecture for baseline evaluation.
"""

def __init__(self):
super().__init__()
self.trans_unet = VisionTransformer()
self.iter_counter = 0

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
self.iter_counter += 1

loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type
out_prob_map = self.trans_unet(mammo_x)

main_shape = ground_truth.shape
out_prob_map_shape = out_prob_map.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out_prob_map.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels
dummy_y = mask_channels * loss_channels

dice = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)

coeff = (self.iter_counter // 200) / 50

output = {
'pixel_probs': out_prob_map,
'org_pixel_labels': mammo_loss_and_gt,
'loss': calculate_breast_topk_focal_ce_loss(dummy_yhat, 0.85 + 0.1 * coeff, dummy_y[:, [1], ...],
verbose=False, bottom_fill_ratio=0.2, thd_to_apply_neg_loss_on_pos_batch=0.1,
pos_skip_mask=(1 - loss_type).bool(), neg_skip_mask=(1 - loss_type).bool()),
'focal_loss': balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2),
'ce_loss': ce_loss,
'dice_loss': dice,
}

return output

+ 61
- 0
segmentation/Baseline/UNet/unet.py View File

import torch
import torch.nn.functional as F
import numpy as np
from .unet_model import UNet
from mlassistant.core import ModelIO, Model
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ......utils.generalized_dice import dice_loss



class UNetBaseline(Model):
"""
UNet architecture for baseline evaluation.
"""

def __init__(self):
super().__init__()

self.scale_channels = [64, 128, 128, 256, 512] # Filter numbers for each scale
self.shape = (512, 512) # Change the probability map shape according to the mammo scans shape
self.input_channels = 1
self.class_num = 2
self.unet = UNet(1, 2)

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type
out_prob_map = self.unet(mammo_x)

main_shape = ground_truth.shape
out_prob_map_shape = out_prob_map.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim=1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out_prob_map.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels
dummy_y = mask_channels * loss_channels

dice = dice_loss(dummy_net_out, dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2)

output = {
'pixel_probs': out_prob_map,
'org_pixel_labels': mammo_loss_and_gt,
'loss': focal_loss,
'ce_loss': ce_loss,
'dice_loss': dice,
}

return output

+ 169
- 0
segmentation/Baseline/UNet/unet_model.py View File

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform


def kaiming_normal_init_weight(model):
for m in model.modules():
if isinstance(m, nn.Conv3d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return model


def sparse_init_weight(model):
for m in model.modules():
if isinstance(m, nn.Conv3d):
torch.nn.init.sparse_(m.weight, sparsity=0.1)
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return model


class ConvBlock(nn.Module):
"""two convolution layers with batch norm and leaky relu"""

def __init__(self, in_channels, out_channels, dropout_p):
super(ConvBlock, self).__init__()
self.conv_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(),
nn.Dropout(dropout_p),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU()
)

def forward(self, x):
return self.conv_conv(x)


class DownBlock(nn.Module):
"""Downsampling followed by ConvBlock"""

def __init__(self, in_channels, out_channels, dropout_p):
super(DownBlock, self).__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, dropout_p)

)

def forward(self, x):
return self.maxpool_conv(x)


class UpBlock(nn.Module):
"""Upssampling followed by ConvBlock"""

def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
bilinear=True):
super(UpBlock, self).__init__()
self.bilinear = bilinear
if bilinear:
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
self.up = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(
in_channels1, in_channels2, kernel_size=2, stride=2)
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

def forward(self, x1, x2):
if self.bilinear:
x1 = self.conv1x1(x1)
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)


class Encoder(nn.Module):
def __init__(self, params):
super(Encoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
self.dropout = self.params['dropout']
assert (len(self.ft_chns) == 5)
self.in_conv = ConvBlock(
self.in_chns, self.ft_chns[0], self.dropout[0])
self.down1 = DownBlock(
self.ft_chns[0], self.ft_chns[1], self.dropout[1])
self.down2 = DownBlock(
self.ft_chns[1], self.ft_chns[2], self.dropout[2])
self.down3 = DownBlock(
self.ft_chns[2], self.ft_chns[3], self.dropout[3])
self.down4 = DownBlock(
self.ft_chns[3], self.ft_chns[4], self.dropout[4])

def forward(self, x):
x0 = self.in_conv(x)
x1 = self.down1(x0)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
return [x0, x1, x2, x3, x4]


class Decoder(nn.Module):
def __init__(self, params):
super(Decoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
assert (len(self.ft_chns) == 5)

self.up1 = UpBlock(
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
self.up2 = UpBlock(
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
self.up3 = UpBlock(
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
self.up4 = UpBlock(
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
kernel_size=3, padding=1)

def forward(self, feature):
x0 = feature[0]
x1 = feature[1]
x2 = feature[2]
x3 = feature[3]
x4 = feature[4]

x = self.up1(x4, x3)
x = self.up2(x, x2)
x = self.up3(x, x1)
x = self.up4(x, x0)
output = self.out_conv(x)
return output


class UNet(nn.Module):
def __init__(self, in_chns, class_num):
super(UNet, self).__init__()

params = {'in_chns': in_chns,
'feature_chns': [64, 128, 128, 256, 512],
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
'class_num': class_num,
'bilinear': False,
'acti_func': 'relu'}

self.encoder = Encoder(params)
self.decoder = Decoder(params)

def forward(self, x):
feature = self.encoder(x)
output = self.decoder(feature)
return output.softmax(dim=1)

+ 101
- 0
segmentation/CPS/cps_resunet.py View File

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

from .utils import init_weight, unsupervised_loss
from .....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import RESUNet as RESUNETSEGMENTOR

from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss


class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()

self.branch1 = RESUNETSEGMENTOR(1, 2)
self.branch2 = RESUNETSEGMENTOR(1, 2)

def forward(self, data, step=1):
if step == 1:
return self.branch1(data)
elif step == 2:
return self.branch2(data)


class CPS(Model):
"""
The main design for CPS method with ResUNet as the
base network.
"""

def __init__(self):
super().__init__()

self.iter_counter = 0
self.network = Network()

init_weight(self.network.branch1.business_layer, nn.init.kaiming_normal_,
nn.BatchNorm2d, 1e-5, 0.1,
mode='fan_in', nonlinearity='relu')
init_weight(self.network.branch2.business_layer, nn.init.kaiming_normal_,
nn.BatchNorm2d, 1e-5, 0.1,
mode='fan_in', nonlinearity='relu')

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
self.iter_counter += 1

pred1 = F.softmax(self.network(mammo_x, step=1), dim=1)
pred2 = F.softmax(self.network(mammo_x, step=2), dim=1)

unsup_loss = unsupervised_loss(pred1, pred2)

loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type

main_shape = ground_truth.shape
out_prob_map_shape = pred1.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out1 = pred1.clone().to(mammo_x.device)
dummy_net_out2 = pred2.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out1 = torch.cat((dummy_net_out1, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out2 = torch.cat((dummy_net_out2, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = pred1 * loss_channels
dummy_y = mask_channels * loss_channels

focal_loss1 = balanced_focal_cross_entropy_loss_semi(dummy_net_out1, dummy_mask, 4)
focal_loss2 = balanced_focal_cross_entropy_loss_semi(dummy_net_out2, dummy_mask, 4)
dice_ls = dice_loss(dummy_net_out1, dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)

weight = self._get_unsup_weight()

output = {
'pixel_probs': pred1,
'org_pixel_labels': mammo_loss_and_gt,
'loss': focal_loss1 + focal_loss2 + unsup_loss,
'ce_loss': ce_loss,
'dice_loss': dice_ls,
}

return output

def _get_unsup_weight(self):
if self.iter_counter // 200 >= 70:
return 1.0
term = 1 - (self.iter_counter // 200) / 70
return np.exp(-5.0 * term * term)


+ 36
- 0
segmentation/CPS/utils.py View File

import torch
import torch.nn as nn
import torch.nn.functional as F


def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
**kwargs):
for name, m in feature.named_modules():
if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
conv_init(m.weight, **kwargs)
elif isinstance(m, norm_layer):
m.eps = bn_eps
m.momentum = bn_momentum
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)


def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
**kwargs):
if isinstance(module_list, list):
for feature in module_list:
__init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
**kwargs)
else:
__init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
**kwargs)

def unsupervised_loss(probability_map1, probability_map2):
max1 = torch.max(probability_map1, dim=1)[1].long()
max2 = torch.max(probability_map2, dim=1)[1].long()

loss1 = F.binary_cross_entropy(probability_map1, max2)
loss2 = F.binary_cross_entropy(probability_map2, max1)

return loss1 + loss2

+ 97
- 0
segmentation/ICT/ict.py View File

from typing import Tuple
import torch
import torch.nn.functional as F
import numpy as np
from .loss import unsupervised_loss
from .unet import UNet
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss


class ICT(Model):
"""
The main design for ICT method with UNet as the
base network.
"""

def __init__(self):
super().__init__()

self.scale_channels = [64, 128, 128, 256, 512] # Filter numbers for each scale
self.shape = (512, 512) # Change the probability map shape according to the mammo scans shape
self.input_channels = 1
self.class_num = 2
self.iter_counter = 14400

self.main_unet = UNet(self.input_channels, self.scale_channels, self.class_num, False)
self.mean_teacher_unet = UNet(self.input_channels, self.scale_channels, self.class_num, True)

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
self.iter_counter += 1
self._update_mean_teacher()

combined_model_out, combined_mean_out = self._mix_up(mammo_x)
model_out = self.main_unet(mammo_x)

unsup_loss = unsupervised_loss(combined_mean_out, combined_model_out)

loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type

main_shape = ground_truth.shape
out_prob_map_shape = model_out.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = model_out.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = model_out * loss_channels
dummy_y = mask_channels * loss_channels

focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2)
dice_ls = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)

weight = 0.6 * ((self.iter_counter // 200) / 400)

output = {
'pixel_probs': model_out,
'org_pixel_labels': mammo_loss_and_gt,
'loss': focal_loss * (0.8 - weight) + unsup_loss * (0.2 + weight),
'ce_loss': ce_loss,
'dice_loss': dice_ls,
}

return output

def _mix_up(self, unlabeled_mammo_x: torch.Tensor) -> Tuple:
lambda_coeff = np.random.beta(1., 1.)
shuffled_indices = torch.from_numpy(np.random.permutation(unlabeled_mammo_x.shape[0])).to(unlabeled_mammo_x.device)

mixed_mammo_x = lambda_coeff * unlabeled_mammo_x + \
(1 - lambda_coeff) * unlabeled_mammo_x[shuffled_indices]

unlabeled_model_out = self.main_unet(mixed_mammo_x)
with torch.no_grad():
unlabeled_mean_out = self.mean_teacher_unet(unlabeled_mammo_x)
unlabeled_mean_out = lambda_coeff * unlabeled_mean_out + \
(1 - lambda_coeff) * unlabeled_mean_out[shuffled_indices]

return unlabeled_model_out, unlabeled_mean_out

def _update_mean_teacher(self, alpha=0.999):
alpha = min(1 - 1 / (self.iter_counter + 1), alpha)
for mean_param, param in zip(self.mean_teacher_unet.parameters(), self.main_unet.parameters()):
mean_param.data.mul_(alpha).add_(1 - alpha, param.data)


+ 132
- 0
segmentation/ICT/ict_resunet.py View File

from typing import Tuple
import torch
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

from .loss import unsupervised_loss
from .....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import RESUNet as RESUNETSEGMENTOR

from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss


class ICT(Model):
"""
The main design for ICT method with ResUNet as the
base network.
"""

def __init__(self):
super().__init__()

self.scale_channels = [64, 128, 128, 256, 512]
self.shape = (512, 512)
self.input_channels = 1
self.class_num = 2
self.iter_counter = 0

encoders_conv_kws = [
{"inp_channel": 1, "out_channel": 32, "n_conv_blocks": 2},
{"inp_channel": 33, "out_channel": 64, "n_conv_blocks": 2},
{"inp_channel": 65, "out_channel": 128, "n_conv_blocks": 3},
{"inp_channel": 129, "out_channel": 256, "n_conv_blocks": 3},
]

bridge_conv_kws = {"inp_channel": 257, "out_channel": 512}

decoders_conv_kws = [
{"inp_channel": 512, "enc_inp_channel": 256, "out_channel": 256},
{"inp_channel": 256, "enc_inp_channel": 128, "out_channel": 64},
{"inp_channel": 64, "enc_inp_channel": 64, "out_channel": 32},
{"inp_channel": 32, "enc_inp_channel": 32, "out_channel": 16},
]

self.main_resunet = ResunetSegmentor(
seg_num_labels=self.class_num,
concat_original_image=True,
encoders_conv_kws=encoders_conv_kws,
decoders_conv_kws=decoders_conv_kws,
bridge_conv_kws=bridge_conv_kws,
)

self.mean_teacher_resunet = deepcopy(self.main_resunet)

for param in self.mean_teacher_resunet.parameters():
param.detach_()

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
self.iter_counter += 1
self._update_mean_teacher()

combined_model_out, combined_mean_out = self._mix_up(mammo_x)
model_out = self.main_resunet(mammo_x, None)["pixel_probs"]

unsup_loss = unsupervised_loss(combined_mean_out, combined_model_out)

loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type

main_shape = ground_truth.shape
out_prob_map_shape = model_out.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = model_out.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = model_out * loss_channels
dummy_y = mask_channels * loss_channels

focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4)
dice_ls = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)

weight = self._get_unsup_weight()

output = {
'pixel_probs': model_out,
'org_pixel_labels': mammo_loss_and_gt,
'loss': focal_loss + unsup_loss, # * weight,
'ce_loss': ce_loss,
'dice_loss': dice_ls,
}

return output

def _mix_up(self, unlabeled_mammo_x: torch.Tensor) -> Tuple:
lambda_coeff = np.random.beta(1., 1.)
shuffled_indices = torch.from_numpy(np.random.permutation(unlabeled_mammo_x.shape[0])).to(unlabeled_mammo_x.device)

mixed_mammo_x = lambda_coeff * unlabeled_mammo_x + \
(1 - lambda_coeff) * unlabeled_mammo_x[shuffled_indices]

unlabeled_model_out = self.main_resunet(mixed_mammo_x, None)["pixel_probs"]
with torch.no_grad():
unlabeled_mean_out = self.mean_teacher_resunet(unlabeled_mammo_x, None)["pixel_probs"]
unlabeled_mean_out = lambda_coeff * unlabeled_mean_out + \
(1 - lambda_coeff) * unlabeled_mean_out[shuffled_indices]

return unlabeled_model_out, unlabeled_mean_out

def _update_mean_teacher(self, alpha=0.999):
alpha = min(1 - 1 / (self.iter_counter + 1), alpha)
for mean_param, param in zip(self.mean_teacher_resunet.parameters(), self.main_resunet.parameters()):
mean_param.data.mul_(alpha).add_(1 - alpha, param.data)

def _get_unsup_weight(self):
if self.iter_counter // 200 >= 70:
return 1.0
term = 1 - (self.iter_counter // 200) / 70
return np.exp(-5.0 * term * term)


+ 70
- 0
segmentation/ICT/loss.py View File

import torch
from torch.nn import BCELoss
from .....utils.dice import dice_loss


def supervised_loss(probability_map: torch.Tensor, ground_truth: torch.Tensor) -> torch.Tensor:
"""
Computes the total supervised loss of the batch.
The loss is consisted of dice loss and cross entropy loss.
:param probability_map: probability map output of the segmentation model
:param ground_truth: ground truth of the processed input
:return: loss of the labeled data
"""

binary_entropy_loss = BCELoss()
loss = (binary_entropy_loss(probability_map, ground_truth) +
dice_loss(probability_map.unsqueeze(dim=0), ground_truth.long())) / 2

return loss


def unsupervised_loss(mean_teacher_mixed_up: torch.Tensor, model_mixed_up_output: torch.Tensor) -> torch.Tensor:
"""
Computes the total unsupervised loss for the batch.
:param mean_teacher_mixed_up: probability map output of the mean teacher model
:param model_mixed_up_output: probability map output of the main model using the mixed up data
:return: loss of the unlabeled data
"""

loss = torch.mean((mean_teacher_mixed_up - model_mixed_up_output) ** 2)

return loss


def total_loss(labeled_out: torch.Tensor,
ground_truth: torch.Tensor,
unlabeled_mean_out: torch.Tensor,
unlabeled_model_out: torch.Tensor,
epoch: int) -> torch.Tensor:
"""
Computes the total loss consisting of both supervised and unsupervised losses.
:param labeled_x: probability maps belonging to the labeled dataset
:param unlabeled_x: probability maps belonging to the unlabeled dataset
:param ground_truth: ground truth belonging to the labeled dataset
:param epoch: current epoch of training
:return: total loss
"""

sup_loss = supervised_loss(labeled_out, ground_truth)
if unlabeled_mean_out is None:
unsup_loss = 0
else:
unsup_loss = unsupervised_loss(unlabeled_mean_out, unlabeled_model_out)

weight = unsupervised_weight(epoch)

loss = sup_loss * (0.8 - weight) + unsup_loss * (0.2 + weight)
return loss


def unsupervised_weight(epoch: int, max_epoch=200):
"""
Computes trade-off weight lambda for the unsupervised loss
:param epoch: current epoch
:param w_max: a coefficient used in the formula
:param max_epoch: maximum number of epochs to run
:return: trade-off weight lambda
"""

return 0.6 * (epoch / max_epoch)

+ 157
- 0
segmentation/ICT/unet.py View File

from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms


class ConvolutionBlock(nn.Module):
"""
A block with two convolutional layers followed by batch norm
layers and LeakyReLU as activation function.
"""

def __init__(self, input_channel: int, output_channel: int):
super(ConvolutionBlock, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1),
nn.BatchNorm2d(output_channel),
nn.LeakyReLU(),
nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1),
nn.BatchNorm2d(output_channel),
nn.LeakyReLU()
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)


class DownSamplingBlock(nn.Module):
"""
A down-sampling block with a max pooling layer followed by
a convolutional block.
"""

def __init__(self, input_channel: int, output_channel: int):
super(DownSamplingBlock, self).__init__()
self.down_sampler = nn.Sequential(
nn.MaxPool2d(2),
ConvolutionBlock(input_channel, output_channel)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_sampler(x)


class UpSamplingBlock(nn.Module):
"""
An up-sampling block with a 2d transpose convolution layer
for up-sampling followed by a ConvolutionBlock.
"""

def __init__(self, input_channel1, input_channel2, output_channel):
super(UpSamplingBlock, self).__init__()
self.up = nn.ConvTranspose2d(input_channel1, input_channel2, kernel_size=2, stride=2)
self.conv = ConvolutionBlock(input_channel2 * 2, output_channel)

def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
height = x1.size()[2] # Finding dimensions of the up-sampled image for cropping
width = x1.size()[3]
x2 = transforms.CenterCrop((height, width))(x2) # Cropping the image from the skip connection from the center
x = torch.cat([x2, x1], dim=1) # Concatenating both images together
return self.conv(x)


class DownSampler(nn.Module):
"""
The module used for down-sampling process consisted of
multiple down-sampling blocks and one convolutional block
at the first part.
"""

def __init__(self, input_channel: int, scale_channels: list):
super(DownSampler, self).__init__()
self.input_channel = input_channel
self.scale_channels = scale_channels

self.conv = ConvolutionBlock(self.input_channel, self.scale_channels[0])
self.down_sampler1 = DownSamplingBlock(self.scale_channels[0], self.scale_channels[1])
self.down_sampler2 = DownSamplingBlock(self.scale_channels[1], self.scale_channels[2])
self.down_sampler3 = DownSamplingBlock(self.scale_channels[2], self.scale_channels[3])
self.down_sampler4 = DownSamplingBlock(self.scale_channels[3], self.scale_channels[4])

def forward(self, x: torch.Tensor) -> List:
features = self.conv(x)
down_sampled1 = self.down_sampler1(features)
down_sampled2 = self.down_sampler2(down_sampled1)
down_sampled3 = self.down_sampler3(down_sampled2)
down_sampled4 = self.down_sampler4(down_sampled3)

return [features, down_sampled1, down_sampled2, down_sampled3, down_sampled4]


class UpSampler(nn.Module):
"""
The module used for up-sampling process consisted of
multiple up-sampling blocks and one convolutional layer
at the last part for the probability map.
"""

def __init__(self, scale_channels: list, class_num: int):
super(UpSampler, self).__init__()
self.scale_channels = scale_channels
self.class_num = class_num

self.up_sampler4 = UpSamplingBlock(scale_channels[4], scale_channels[3], scale_channels[3])
self.up_sampler3 = UpSamplingBlock(scale_channels[3], scale_channels[2], scale_channels[2])
self.up_sampler2 = UpSamplingBlock(scale_channels[2], scale_channels[1], scale_channels[1])
self.up_sampler1 = UpSamplingBlock(scale_channels[1], scale_channels[0], scale_channels[0])

self.probability_map_conv = nn.Sequential(
nn.Conv2d(self.scale_channels[0], self.class_num, kernel_size=1),
nn.Softmax(dim=1)
)

def forward(self, down_outputs: list) -> Tuple:
"""
NOTE: Numbers in front of the var names indicate the scale number.
"""
down_outputs0 = down_outputs.pop(0)
down_outputs1 = down_outputs.pop(0)
down_outputs2 = down_outputs.pop(0)
down_outputs3 = down_outputs.pop(0)
down_outputs4 = down_outputs.pop(0)

up_sampled4 = self.up_sampler4(down_outputs4, down_outputs3)
up_sampled3 = self.up_sampler3(up_sampled4, down_outputs2)
up_sampled2 = self.up_sampler2(up_sampled3, down_outputs1)
up_sampled1 = self.up_sampler1(up_sampled2, down_outputs0)
output = self.probability_map_conv(up_sampled1)

return output


class UNet(nn.Module):
"""
Implementation of UNet architecture.
"""

def __init__(self, input_channel: list, scale_channels: list, class_num: int, ema_flag: bool):
super(UNet, self).__init__()
self.down_sampler = DownSampler(input_channel, scale_channels)
self.up_sampler = UpSampler(scale_channels, class_num)

if ema_flag:
for param in self.down_sampler.parameters():
param.detach_()
for param in self.up_sampler.parameters():
param.detach_()

def forward(self, x: list) -> Tuple:
down_sampled_x = self.down_sampler(x)
up_sampled_x = self.up_sampler(down_sampled_x)
return up_sampled_x

+ 154
- 0
segmentation/URPC/loss.py View File

import torch
import numpy as np
from typing import Tuple
from torch.nn import BCELoss, KLDivLoss
from .....utils.dice import dice_loss


def single_supervised_loss(probability_map: torch.Tensor, ground_truth: torch.Tensor) -> torch.Tensor:
"""
Computes the supervised loss of a single scale of the model.
The loss is consisted of dice loss and cross entropy loss.
:param probability_map: probability map output of the segmentation model
:param ground_truth: ground truth of the processed input
:return: overall loss of the supervised data
"""

binary_entropy_loss = BCELoss()
loss = (binary_entropy_loss(probability_map, ground_truth) +
dice_loss(probability_map.unsqueeze(dim=0), ground_truth.long())) / 2
return loss


def supervised_loss(probability_map: list, ground_truth: torch.Tensor) -> torch.Tensor:
"""
Computes the total supervised loss of the batch.
The loss is consisted of dice loss and cross entropy loss.
:param probability_map: probability map outputs of the segmentation model from different scales
:param ground_truth: ground truth of the processed input
:return: mean loss of the supervised data over all scales
"""

probability_map0 = probability_map.pop(0)
probability_map1 = probability_map.pop(0)
probability_map2 = probability_map.pop(0)
probability_map3 = probability_map.pop(0)

total_loss = (single_supervised_loss(probability_map0, ground_truth) +
single_supervised_loss(probability_map1, ground_truth) +
single_supervised_loss(probability_map2, ground_truth) +
single_supervised_loss(probability_map3, ground_truth)) / 4

return total_loss


def uncertainty_minimization(probability_map: list, mean_prob_map: torch.Tensor) -> Tuple[torch.Tensor, Tuple]:
"""
Computes the uncertainty minimization term for the unsupervised loss.
:param mean_prob_map: mean probability map over all the scales
:param probability_map: probability map outputs of the segmentation model from different scales
:return: uncertainty minimization term and a tuple of the exponential weights
"""

kld_loss = KLDivLoss(reduction='none')

kl_loss_scale0 = torch.sum(kld_loss(torch.log(probability_map[0]),
mean_prob_map), dim=1, keepdim=True)
exp_kl_loss_scale0 = torch.exp(-1 * kl_loss_scale0)

kl_loss_scale1 = torch.sum(kld_loss(torch.log(probability_map[1]),
mean_prob_map), dim=1, keepdim=True)
exp_kl_loss_scale1 = torch.exp(-1 * kl_loss_scale1)

kl_loss_scale2 = torch.sum(kld_loss(torch.log(probability_map[2]),
mean_prob_map), dim=1, keepdim=True)
exp_kl_loss_scale2 = torch.exp(-1 * kl_loss_scale2)

kl_loss_scale3 = torch.sum(kld_loss(torch.log(probability_map[3]),
mean_prob_map), dim=1, keepdim=True)
exp_kl_loss_scale3 = torch.exp(-1 * kl_loss_scale3)

UM = (torch.mean(kl_loss_scale0) + torch.mean(kl_loss_scale1) +
torch.mean(kl_loss_scale2) + torch.mean(kl_loss_scale3)) / 4

return UM, (exp_kl_loss_scale0, exp_kl_loss_scale1, exp_kl_loss_scale2, exp_kl_loss_scale3)


def uncertainty_rectification(probability_map: list, mean_prob_map: torch.Tensor,
weights: tuple) -> torch.Tensor:
"""
Computes the uncertainty rectification term for the unsupervised loss.
:param weights: weights computed by the kl divergences from the UM term
:param mean_prob_map: mean probability map over all the scales
:param probability_map: probability map outputs of the segmentation model from different scales
:return: uncertainty rectification term
"""

mse_loss_scale0 = torch.mean(((probability_map[0] - mean_prob_map) ** 2) *
weights[0]) / torch.mean(weights[0])
mse_loss_scale1 = torch.mean(((probability_map[1] - mean_prob_map) ** 2) *
weights[1]) / torch.mean(weights[1])
mse_loss_scale2 = torch.mean(((probability_map[2] - mean_prob_map) ** 2) *
weights[2]) / torch.mean(weights[2])
mse_loss_scale3 = torch.mean(((probability_map[3] - mean_prob_map) ** 2) *
weights[3]) / torch.mean(weights[3])
UR = (mse_loss_scale0 + mse_loss_scale1 + mse_loss_scale2 + mse_loss_scale3) / 4

return UR


def unsupervised_loss(probability_map: list) -> torch.Tensor:
"""
Computes the total unsupervised loss for the batch.
:param probability_map: probability map outputs of the segmentation model from different scales
:return: mean loss of the unsupervised data over all scales
"""

mean_prob_map = (probability_map[0] + probability_map[1] + probability_map[2] + probability_map[3]) / 4
UM, weights = uncertainty_minimization(probability_map, mean_prob_map)
UR = uncertainty_rectification(probability_map, mean_prob_map, weights)
loss = UM + UR
return loss


def total_loss(x: list, ground_truth: torch.Tensor, region_mask: torch.Tensor, epoch: int) -> torch.Tensor:
"""
Computes the total loss consisting of both supervised and unsupervised losses.
:param labeled_x: probability maps belonging to the labeled dataset
:param unlabeled_x: probability maps belonging to the unlabeled dataset
:param ground_truth: ground truth belonging to the labeled dataset
:param region_mask: regional mask for specifying supervised regions
:param epoch: current epoch of training
:return: total loss
"""

# sup_loss = supervised_loss(labeled_x, ground_truth)
# if unlabeled_x is None:
# unsup_loss = 0
# else:
# unsup_loss = unsupervised_loss(unlabeled_x)
# weight = unsupervised_weight(epoch)
# loss = sup_loss * (0.8 - weight) + unsup_loss * (0.2 + weight)

scales_outputs_sup, scales_outputs_unsup = [], []
for scale_output in x:
scales_outputs_sup.append(scale_output * region_mask.to(torch.float64))
scales_outputs_unsup.append(scale_output * ~region_mask.to(torch.float64))

# sup_loss = supervised_loss(scales_outputs_sup, ground_truth * region_mask.to(torch.float64))
unsup_loss = unsupervised_loss(scales_outputs_unsup)
return unsup_loss


def unsupervised_weight(epoch: int, max_epoch=200):
"""
Computes trade-off weight lambda for the unsupervised loss
:param epoch: current epoch
:param w_max: a coefficient used in the formula
:param max_epoch: maximum number of epochs to run
:return: trade-off weight lambda
"""

return 0.6 * (epoch / max_epoch)

+ 351
- 0
segmentation/URPC/resunet.py View File

import torch.nn
from typing import Dict, List, Any
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss
from torch.nn import functional as F
from PIL import Image
from torchvision.utils import draw_segmentation_masks
import os
from mlassistant.core import Model
from mlassistant.context.dataloader import DataloaderContext
import numpy as np


from typing import Dict, List, Any, Callable, Optional
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet18
from mlassistant.core import Model
from ....segmentation.submodules.res_enc_block import ResEncoderBlock
from .....enums.roi.roi_aggregation_type import ROIAggregation
from ....full_segmentation.utils import sum_roi_4ch_to_1ch
from ....segmentation.utils import cal_patch_pixel_labels_by_patch_label_and_roi, \
create_patches_label_based_on_roi, AssembleDisassemble, \
reduce_resolution, select_non_boundaries
from ....segmentation.losses import cal_focal_loss, cal_weighted_focal_loss


class ResDecoderBlock(torch.nn.Module):
def __init__(self, inp_channel, enc_inp_channel, out_channel, n_conv_blocks=2,
conv_block=torch.nn.Conv2d, conv_block_kws={'kernel_size': 3, 'padding': 1},
activation=torch.nn.ReLU):
super(ResDecoderBlock, self).__init__()

self.up_sample = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

self.skip_connection = conv_block(inp_channel + enc_inp_channel, out_channel, **conv_block_kws)

modules = [
torch.nn.BatchNorm2d(inp_channel + enc_inp_channel),
activation(), # activation of the previous channel
conv_block(inp_channel + enc_inp_channel, out_channel, **conv_block_kws)]

for i in range(1, n_conv_blocks):
modules += [
torch.nn.BatchNorm2d(out_channel),
activation(),
conv_block(out_channel, out_channel, **conv_block_kws)]

self.seq = torch.nn.Sequential(*modules)

def forward(self, x, x_enc):
out = self.up_sample(x)
out = torch.cat([out, x_enc], dim=1)

return self.skip_connection(out) + self.seq(out)


class ResUNet(Model):
def __init__(self,
seg_num_labels: int,
concat_original_image: bool,
encoders_conv_kws: Optional[List[Dict[str, Any]]],
bridge_conv_kws: Optional[Dict[str, Any]],
decoders_conv_kws: List[Dict[str, Any]],
roi_agg_type: ROIAggregation = ROIAggregation.mass_vs_others,
remove_boundaries: bool = False,
pretrained_resnet18: bool = False,
loss=(lambda log_pred, gt: cal_weighted_focal_loss(log_pred, gt, 4)),
use_dropout: bool = False) -> None:

super(ResUNet, self).__init__()

assert encoders_conv_kws is not None or pretrained_resnet18, \
"for creating encoder part, either of number of channels or using pretrained version of resnet18 should be determined"
assert not concat_original_image or not pretrained_resnet18, \
"concatenating is currently not supported when using pretrained version"

num_encoders = 6 if pretrained_resnet18 else len(encoders_conv_kws)
self.halving_depth = num_encoders - len(decoders_conv_kws) + 1
self._roi_agg_type: ROIAggregation = roi_agg_type
self._remove_boundaries: bool = remove_boundaries
self._loss = loss
self._concat_original_image: bool = concat_original_image

self.encoders: nn.ModuleList = None
self.bridge: nn.Sequential = None
self._create_encoder(pretrained_resnet18, encoders_conv_kws)
self._create_bridge(bridge_conv_kws)

self.decoders = nn.ModuleList([ResDecoderBlock(**kwargs)
for kwargs in decoders_conv_kws])

last_channel = decoders_conv_kws[-1]['out_channel']
if not use_dropout:
self.decider = nn.Sequential(
nn.Conv2d(last_channel, seg_num_labels, kernel_size=1),
nn.LogSoftmax(dim=1))
else:
self.decider = nn.Sequential(
nn.Dropout(0.3),
nn.Conv2d(last_channel, seg_num_labels, kernel_size=1),
nn.LogSoftmax(dim=1))

def _create_encoder(self,
pretrained: bool,
encoders_conv_kws: Optional[List[Dict[str, Any]]]) -> None:
if pretrained:
resnet = resnet18(pretrained=True)
self._inp_channel = 3
self.encoders = nn.ModuleList([])
self.encoders.append(
nn.Sequential(
resnet.conv1,
resnet.bn1,
resnet.relu,
resnet.maxpool))
for l in [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]:
self.encoders.append(l)

else:
self._inp_channel = encoders_conv_kws[0]['inp_channel']
self.encoders = nn.ModuleList([ResEncoderBlock(**encoders_conv_kws[0])])
encoders_conv_kws.pop(0)
self.encoders += nn.ModuleList([nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2),
ResEncoderBlock(**kwargs)
)
for kwargs in encoders_conv_kws])


def _create_bridge(self,
bridge_conv_kws: Optional[List[Dict[str, Any]]]) -> None:
if bridge_conv_kws is None:
self.bridge = nn.AvgPool2d(kernel_size=2, stride=2),
else:
self.bridge = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2),
ResEncoderBlock(**bridge_conv_kws))


def _forward(self, x):
x = x.repeat_interleave(self._inp_channel, 1)
encoders_outs = []
out = x
x_ = x

for enc in self.encoders:
out = enc(out)
encoders_outs.append(out)

if self._concat_original_image:
x_ = F.interpolate(x_, out.shape[2:], mode='bilinear', align_corners=True)
out = torch.cat([out, x_], dim=1)

out = self.bridge(out)

for dec in self.decoders:
out = dec(out, encoders_outs.pop(-1))
return self.decider(out)


class ResunetSegmentor(ResUNet):

def __init__(self,
seg_num_labels: int,
concat_original_image: bool,
encoders_conv_kws: List[Dict[str, Any]],
bridge_conv_kws: Dict[str, Any],
decoders_conv_kws: List[Dict[str, Any]] ,
log_transform=False ,
is_base=True,
base_saving_dir=None ,
gamma=2):
"""
Args:
seg_num_labels (int): [description] = 2 (background vs nipple)
concat_original_image (bool): [description]
encoders_conv_kws (List[Dict[str, Any]]): [description]
bridge_conv_kws (Dict[str, Any]): [description]
decoders_conv_kws (List[Dict[str, Any]]): [description]
"""
self.is_base = is_base
self.batch_count = 0
super(ResunetSegmentor, self).__init__(
seg_num_labels,
concat_original_image,
encoders_conv_kws,
bridge_conv_kws,
decoders_conv_kws,
)
self.seg_num_labels = seg_num_labels
self.log_transform = log_transform
self.dropout_2 = torch.nn.Dropout(0.4)
self.dropout_1 = torch.nn.Dropout(0.2)
last_channel = decoders_conv_kws[-1]["out_channel"]
self.decider = torch.nn.Sequential(
torch.nn.Conv2d(last_channel, seg_num_labels, kernel_size=1),
torch.nn.LogSoftmax(dim=1),
)
self.gamma = gamma
self.base_model = None
self.base_saving_dir = base_saving_dir
self.debug = self.base_saving_dir is not None
if self.debug:
os.makedirs(self.base_saving_dir , exist_ok=True)
def set_base_model(self,base_model:Model):
self.base_model = base_model
return

def _forward(self, x):
encoders_outs = []
scales_outs = []
out = x
x_ = x
# print("encoder output shapes:")
for enc in self.encoders:
out = enc(out)
# print(out.shape)
encoders_outs.append(out)
out = self.dropout_1(out)
if self._concat_original_image:
x_ = F.interpolate(x_, out.shape[2:], mode='bilinear', align_corners=True)
out = torch.cat([out, x_], dim=1)
final_encoder_out = out
final_encoder_out = final_encoder_out.flatten(2)
out = self.dropout_2(out)
out = self.bridge(out)
# print("decoder output shapes:")
for dec in self.decoders:
# print(out.shape)
encoder_out = encoders_outs.pop(-1)
out = dec(out, encoder_out)
scales_outs.append(F.interpolate(out, (512, 512)))
out = self.dropout_1(out)

scales_outs[-1] = self.decider(out)

return scales_outs
def draw_segmentations(self,mammo_x , pixel_probs , ground_truths, prefix=''):
names = DataloaderContext.instance.dataloader.get_current_batch_samples_names()
for image,prediction , ground_truth, name in zip(mammo_x, pixel_probs , ground_truths, names):
image = (image * 255).to(torch.uint8)
if len(image.shape)< 3:
image = image.unsqueeze(0)
image = torch.repeat_interleave(image, 3, dim=0).cpu() # rgb
# image
new_prediction = (prediction.clone().detach()>0.5).bool().unsqueeze(0)
label = (ground_truth.clone().detach()>0.5).bool().unsqueeze(0)
# new_pred
masks = torch.cat([label]).cpu()
# print(masks.shape)
alpha = 0.4
final_image = draw_segmentation_masks(image, masks, alpha , colors=['green']).transpose(0,1).transpose(1,2)
name = name.replace('/' , '_')
Image.fromarray(final_image.cpu().numpy()).save(
os.path.join(self.base_saving_dir , f"{name}_{prefix}_label.png"))
masks = torch.cat([new_prediction]).cpu()
# print(masks.shape)
alpha = 0.4
final_image = draw_segmentation_masks(image, masks, alpha , colors=['red']).transpose(0,1).transpose(1,2)
Image.fromarray(final_image.cpu().numpy()).save(
os.path.join(self.base_saving_dir , f"{name}_{prefix}_pred.png"))
return
def draw_raw_segmentations(self,mammo_x , pixel_probs , ground_truths , prefix=""):
names = DataloaderContext.instance.dataloader.get_current_batch_samples_names()
for image, prediction , ground_truth, name in zip(mammo_x, pixel_probs , ground_truths, names):
# image
name = name.replace('/' , '_')
if len(prefix) >0:
name = f"{name}_{prefix}"
image = (image * 255).to(torch.uint8)
Image.fromarray(image.squeeze().cpu().numpy()).save(f"{self.base_saving_dir}/{name}.png")
# pred
prediction = (prediction.clone().detach().squeeze() * 255).to(torch.uint8)
Image.fromarray(prediction.cpu().numpy()).save(f"{self.base_saving_dir}/{name}_prediction.png")
# label
label = (ground_truth.clone().detach().squeeze() * 255).to(torch.uint8)
Image.fromarray(label.cpu().numpy()).save(f"{self.base_saving_dir}/{name}_label.png")
return

def draw_error_mask(self , mammo_x , probabilities , ground_truth):
names = DataloaderContext.instance.dataloader.get_current_batch_samples_names()
error_mask = torch.zeros(size = (len(mammo_x) , 3 , mammo_x.shape[-2] , mammo_x.shape[-1]))
prob_channels = probabilities.shape[1]
error_mask[:,0:prob_channels] = ground_truth - probabilities
error_mask[error_mask < 0] = 0
error_mask *= 255
error_mask = error_mask.to(torch.uint8).cpu().numpy()
for error_image,image ,name in zip(error_mask, mammo_x, names):
image = (image * 255).to(torch.uint8).squeeze()
name = name.replace('/' , '_')
Image.fromarray(image.cpu().numpy()).save(
os.path.join(self.base_saving_dir , f"{name}.png"))
Image.fromarray(error_image.transpose(1,2,0)).save(
os.path.join(self.base_saving_dir , f"{name}_error_mask.png"))
return

def forward(self,
mammo_x: torch.Tensor,
mammo_breast_mask: torch.Tensor):
batch_size = len(mammo_x)
predicted_mask = None
chosen_samples = [0 for i in range(batch_size)]
if not self.is_base:
predicted_mask = torch.exp(self.base_model._forward(mammo_x)).to(mammo_x.device)[:,1]
predicted_mask[predicted_mask>=0.5] = 1
predicted_mask[predicted_mask < 0.5] = 0
intersection = predicted_mask.squeeze() * mammo_breast_mask.squeeze()
union = ((predicted_mask.squeeze() + mammo_breast_mask.squeeze()) >0.5).to(torch.int)
for i in range(batch_size):
if (intersection[i].sum()/union[i].sum()) > 0.75:
chosen_samples[i] = 1
output = dict()
log_pixel_probs = self._forward(mammo_x)
for i in range(len(log_pixel_probs)):
log_pixel_probs[i] = torch.exp(log_pixel_probs[i]).to(mammo_x.device)

return log_pixel_probs

def apply_log_transform(self,pixel_probs):
pixel_probs = pixel_probs/pixel_probs.max()
pixel_probs = (pixel_probs * 255).to(torch.int16)
for i in range(0, len(pixel_probs)):
c = (255 / torch.log(1 + torch.max(pixel_probs[i][1]))).to(pixel_probs.device)
pixel_probs[i][1] = c * (torch.log(pixel_probs[i][1] + 1))
pixel_probs[i][0] = 255 - pixel_probs[i][1]
pixel_probs = pixel_probs.to(torch.float32)
pixel_probs = pixel_probs/pixel_probs.max()
return pixel_probs

def refine_segmentation_using_clustering(self, segmentation , image):
y,x = torch.where(segmentation >= 0.5)
pixels_values = image[segmentation >= 0.5].tolist()
median = np.median(np.array(pixels_values))
std = np.std(np.array(pixels_values))
pixels_values += [median-3*std , median+3*std]
self.clusterer.fit(np.array(pixels_values).reshape(-1,1))
pixel_labels = self.clusterer.labels_
true_positives_label = pixel_labels[-1] # pixel labels of closed to 1.0 pixels
false_positive_indices = [i for i in range(len(pixel_labels)) if pixel_labels[i] != true_positives_label]
false_positive_indices = torch.tensor(false_positive_indices[:-1]).to(image.device) # no need for last one
if len(false_positive_indices) >0:
false_positive_x = x[false_positive_indices]
false_positive_y = y[false_positive_indices]
segmentation[false_positive_y , false_positive_x] = 0
return segmentation

+ 80
- 0
segmentation/URPC/urpc.py View File

import torch
from torch import nn
import torch.nn.functional as F
from .loss import unsupervised_loss
from .utils import UpSampler, DownSampler
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss



class URPC(Model):
"""
The main design for URPC architecture with UNet as the
base network.
"""

def __init__(self):
super().__init__()

self.scale_channels = [64, 128, 128, 256, 512] # Filter numbers for each scale
self.shape = (512, 512) # Change the probability map shape according to the mammo scans shape
self.input_channels = 1
self.class_num = 2
self.iter_counter = 13400

self.down_sampler = DownSampler(self.input_channels, self.scale_channels)
self.up_sampler = UpSampler(self.scale_channels, self.class_num, self.shape)

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
self.iter_counter += 1
down_sampled_x = self.down_sampler(mammo_x)
out0, out1, out2, out3 = self.up_sampler(down_sampled_x)

loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type

main_shape = ground_truth.shape
out_prob_map_shape = out0.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out0.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out0 * loss_channels
dummy_y = mask_channels * loss_channels

focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2)
dice_ls = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)

weight = 0.6 * ((self.iter_counter // 200) / 400)
unsup_loss = unsupervised_loss([out0[:, [1], ...], out1[:, [1], ...],
out2[:, [1], ...], out3[:, [1], ...]])

output = {
'pixel_probs': out0,
'org_pixel_labels': mammo_loss_and_gt,
'loss': focal_loss * (0.8 - weight) + unsup_loss * (0.2 + weight),
'ce_loss': ce_loss,
'dice_loss': dice_ls,
}

return output


# python main.py semi_supervised.segmentation.entrypoint_urpc -phase train -save_dir /home/dadbeh/Results -device cuda:2 -console_file false
# python metric.py -L ../Results/TransUNetTopK/log.csv -M train_Seg_Sens_0 val_Seg_Sens_0 -M train_Seg_Sens_1 val_Seg_Sens_1 -M train_Seg_AvgSens val_Seg_AvgSens -M train_Obj_FP/P-1 val_Obj_FP/P-1 -M train_Obj_Sen-T-1 val_Obj_Sen-T-1 -M train_ls_Loss val_ls_Loss -M train_ls_focal_loss val_ls_focal_loss -M train_ls_ce_loss train_ls_ce_loss -M train_ls_dice_loss val_ls_dice_loss -M train_Obj_Prec-T-1 val_Obj_Prec-T-1
# python main.py semi_supervised.segmentation.entrypoint_urpc -phase saveModelOutput -samples_dir liveE15 -device cuda:2 -save_dir /home/dadbeh/Results -load_dir /home/dadbeh/TempResults -epoch 113 -console_file false
# python main.py semi_supervised.segmentation.entrypoint_urpc -phase saveModelOutput -samples_dir liveE15 -device cuda:2 -save_dir /home/dadbeh/Results/TopKFocalCEWeightResults -load_dir /home/dadbeh/Results/TopKFocalCEWeightResults -epoch 37 -console_file false

+ 92
- 0
segmentation/URPC/urpc_resunet.py View File

import torch
from torch import nn
import torch.nn.functional as F
from .loss import unsupervised_loss
from .resunet import ResunetSegmentor
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss


class URPC(Model):
"""
The main design for URPC architecture with UNet as the
base network.
"""

def __init__(self):
super().__init__()

self.shape = (512, 512) # Change the probability map shape according to the mammo scans shape
self.input_channels = 1
self.class_num = 2
self.iter_counter = 0

encoders_conv_kws = [
{"inp_channel": 1, "out_channel": 32, "n_conv_blocks": 2},
{"inp_channel": 33, "out_channel": 64, "n_conv_blocks": 2},
{"inp_channel": 65, "out_channel": 128, "n_conv_blocks": 3},
{"inp_channel": 129, "out_channel": 256, "n_conv_blocks": 3},
]

bridge_conv_kws = {"inp_channel": 257, "out_channel": 512}

decoders_conv_kws = [
{"inp_channel": 512, "enc_inp_channel": 256, "out_channel": 256},
{"inp_channel": 256, "enc_inp_channel": 128, "out_channel": 64},
{"inp_channel": 64, "enc_inp_channel": 64, "out_channel": 32},
{"inp_channel": 32, "enc_inp_channel": 32, "out_channel": 16},
]

self.base_network = ResunetSegmentor(
seg_num_labels=self.class_num,
concat_original_image=True,
encoders_conv_kws=encoders_conv_kws,
decoders_conv_kws=decoders_conv_kws,
bridge_conv_kws=bridge_conv_kws,
)

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
self.iter_counter += 1
out3, out2, out1, out0 = self.base_network(mammo_x, None)

loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type

main_shape = ground_truth.shape
out_prob_map_shape = out0.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out0.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out0 * loss_channels
dummy_y = mask_channels * loss_channels

focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4)
dice_ls = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)

weight = 0.6 * ((self.iter_counter // 200) / 200)
unsup_loss = unsupervised_loss([out0[:, [1], ...], out1[:, [1], ...],
out2[:, [1], ...], out3[:, [1], ...]])

output = {
'pixel_probs': out0,
'org_pixel_labels': mammo_loss_and_gt,
'loss': focal_loss * (0.8 - weight) + unsup_loss * (0.2 + weight),
'ce_loss': ce_loss,
'dice_loss': dice_ls,
}

return output

+ 153
- 0
segmentation/URPC/utils.py View File

from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms


class ConvolutionBlock(nn.Module):
"""
A block with two convolutional layers followed by batch norm
layers and LeakyReLU as activation function.
"""

def __init__(self, input_channel: int, output_channel: int):
super(ConvolutionBlock, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1),
nn.BatchNorm2d(output_channel),
nn.LeakyReLU(),
nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1),
nn.BatchNorm2d(output_channel),
nn.LeakyReLU()
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)


class DownSamplingBlock(nn.Module):
"""
A down-sampling block with a max pooling layer followed by
a convolutional block.
"""

def __init__(self, input_channel: int, output_channel: int):
super(DownSamplingBlock, self).__init__()
self.down_sampler = nn.Sequential(
nn.MaxPool2d(2),
ConvolutionBlock(input_channel, output_channel)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_sampler(x)


class UpSamplingBlock(nn.Module):
"""
An up-sampling block with a 2d transpose convolution layer
for up-sampling followed by a ConvolutionBlock.
"""

def __init__(self, input_channel1, input_channel2, output_channel):
super(UpSamplingBlock, self).__init__()
self.up = nn.ConvTranspose2d(input_channel1, input_channel2, kernel_size=2, stride=2)
self.conv = ConvolutionBlock(input_channel2 * 2, output_channel)

def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
height = x1.size()[2] # Finding dimensions of the up-sampled image for cropping
width = x1.size()[3]
x2 = transforms.CenterCrop((height, width))(x2) # Cropping the image from the skip connection from the center
x = torch.cat([x2, x1], dim=1) # Concatenating both images together
return self.conv(x)


class DownSampler(nn.Module):
"""
The module used for down-sampling process consisted of
multiple down-sampling blocks and one convolutional block
at the first part.
"""

def __init__(self, input_channel: int, scale_channels: list):
super(DownSampler, self).__init__()
self.input_channel = input_channel
self.scale_channels = scale_channels

self.conv = ConvolutionBlock(self.input_channel, self.scale_channels[0])
self.down_sampler1 = DownSamplingBlock(self.scale_channels[0], self.scale_channels[1])
self.down_sampler2 = DownSamplingBlock(self.scale_channels[1], self.scale_channels[2])
self.down_sampler3 = DownSamplingBlock(self.scale_channels[2], self.scale_channels[3])
self.down_sampler4 = DownSamplingBlock(self.scale_channels[3], self.scale_channels[4])

def forward(self, x: torch.Tensor) -> List:
features = self.conv(x)
down_sampled1 = self.down_sampler1(features)
down_sampled2 = self.down_sampler2(down_sampled1)
down_sampled3 = self.down_sampler3(down_sampled2)
down_sampled4 = self.down_sampler4(down_sampled3)

return [features, down_sampled1, down_sampled2, down_sampled3, down_sampled4]


class UpSampler(nn.Module):
"""
The module used for up-sampling process consisted of
multiple up-sampling blocks and one convolutional layer
at the last part for the probability map.
"""

def __init__(self, scale_channels: list, class_num: int, shape: tuple):
super(UpSampler, self).__init__()
self.scale_channels = scale_channels
self.class_num = class_num
self.shape = shape

self.up_sampler4 = UpSamplingBlock(scale_channels[4], scale_channels[3], scale_channels[3])
self.up_sampler3 = UpSamplingBlock(scale_channels[3], scale_channels[2], scale_channels[2])
self.up_sampler2 = UpSamplingBlock(scale_channels[2], scale_channels[1], scale_channels[1])
self.up_sampler1 = UpSamplingBlock(scale_channels[1], scale_channels[0], scale_channels[0])

self.probability_map_conv0 = nn.Sequential(
nn.Conv2d(self.scale_channels[0], self.class_num, kernel_size=1)
)
self.probability_map_conv1 = nn.Sequential(
nn.Conv2d(self.scale_channels[1], self.class_num, kernel_size=1)
)
self.probability_map_conv2 = nn.Sequential(
nn.Conv2d(self.scale_channels[2], self.class_num, kernel_size=1)
)
self.probability_map_conv3 = nn.Sequential(
nn.Conv2d(self.scale_channels[3], self.class_num, kernel_size=1)
)

def forward(self, down_outputs: list) -> Tuple:
"""
NOTE: Numbers in front of the var names indicate the scale number.
"""
down_outputs0 = down_outputs.pop(0)
down_outputs1 = down_outputs.pop(0)
down_outputs2 = down_outputs.pop(0)
down_outputs3 = down_outputs.pop(0)
down_outputs4 = down_outputs.pop(0)

up_sampled4 = self.up_sampler4(down_outputs4, down_outputs3)
up_sampled4 = F.interpolate(up_sampled4, (64, 64))
output3 = self.probability_map_conv3(up_sampled4)
interpolated_output3 = F.interpolate(output3, self.shape)

up_sampled3 = self.up_sampler3(up_sampled4, down_outputs2)
output2 = self.probability_map_conv2(up_sampled3)
interpolated_output2 = F.interpolate(output2, self.shape)

up_sampled2 = self.up_sampler2(up_sampled3, down_outputs1)
output1 = self.probability_map_conv1(up_sampled2)
interpolated_output1 = F.interpolate(output1, self.shape)

up_sampled1 = self.up_sampler1(up_sampled2, down_outputs0)
output0 = self.probability_map_conv0(up_sampled1)

return torch.softmax(output0, dim=1), torch.softmax(interpolated_output1, dim=1), \
torch.softmax(interpolated_output2, dim=1), torch.softmax(interpolated_output3, dim=1)

+ 111
- 0
segmentation/UniMatch/Models/deeplabv3plus.py View File

import torch
from torch import nn
import torch.nn.functional as F
import resnet as resnet


class DeepLabV3Plus(nn.Module):
def __init__(self):
super(DeepLabV3Plus, self).__init__()

self.backbone = resnet.__dict__['resnet101'](pretrained=False)

low_channels = 256
high_channels = 2048

self.head = ASPPModule(high_channels, [6, 12, 18])

self.reduce = nn.Sequential(nn.Conv2d(low_channels, 48, 1, bias=False),
nn.BatchNorm2d(48),
nn.ReLU(True))

self.fuse = nn.Sequential(nn.Conv2d(high_channels // 8 + 48, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True))

self.classifier = nn.Conv2d(256, cfg['nclass'], 1, bias=True)

def forward(self, x, need_fp=False):
h, w = x.shape[-2:]

feats = self.backbone.base_forward(x)
c1, c4 = feats[0], feats[-1]

if need_fp:
outs = self._decode(torch.cat((c1, nn.Dropout2d(0.5)(c1))),
torch.cat((c4, nn.Dropout2d(0.5)(c4))))
outs = F.interpolate(outs, size=(h, w), mode="bilinear", align_corners=True)
out, out_fp = outs.chunk(2)

return out, out_fp

out = self._decode(c1, c4)
out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)

return out

def _decode(self, c1, c4):
c4 = self.head(c4)
c4 = F.interpolate(c4, size=c1.shape[-2:], mode="bilinear", align_corners=True)

c1 = self.reduce(c1)

feature = torch.cat([c1, c4], dim=1)
feature = self.fuse(feature)

out = self.classifier(feature)

return out


def ASPPConv(in_channels, out_channels, atrous_rate):
block = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate,
dilation=atrous_rate, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True))
return block


class ASPPPooling(nn.Module):
def __init__(self, in_channels, out_channels):
super(ASPPPooling, self).__init__()
self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True))

def forward(self, x):
h, w = x.shape[-2:]
pool = self.gap(x)
return F.interpolate(pool, (h, w), mode="bilinear", align_corners=True)


class ASPPModule(nn.Module):
def __init__(self, in_channels, atrous_rates):
super(ASPPModule, self).__init__()
out_channels = in_channels // 8
rate1, rate2, rate3 = atrous_rates

self.b0 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True))
self.b1 = ASPPConv(in_channels, out_channels, rate1)
self.b2 = ASPPConv(in_channels, out_channels, rate2)
self.b3 = ASPPConv(in_channels, out_channels, rate3)
self.b4 = ASPPPooling(in_channels, out_channels)

self.project = nn.Sequential(nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True))

def forward(self, x):
feat0 = self.b0(x)
feat1 = self.b1(x)
feat2 = self.b2(x)
feat3 = self.b3(x)
feat4 = self.b4(x)
y = torch.cat((feat0, feat1, feat2, feat3, feat4), 1)
return self.project(y)

+ 154
- 0
segmentation/UniMatch/Models/resnet.py View File

import torch
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1):
super(Bottleneck, self).__init__()

norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups

self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
identity = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)

return out


class ResNet(nn.Module):

def __init__(self, block, layers, groups=1, width_per_group=64):
super(ResNet, self).__init__()

norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer

self.inplanes = 128
self.dilation = 1
replace_stride_with_dilation = [False, False, True]

if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
norm_layer(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
norm_layer(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)

layers = list()
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))

return nn.Sequential(*layers)

def base_forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

c1 = self.layer1(x)
c2 = self.layer2(c1)
c3 = self.layer3(c2)
c4 = self.layer4(c3)

return c1, c2, c3, c4


def _resnet(arch, block, layers, pretrained, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
pretrained_path = "pretrained/%s.pth" % arch
state_dict = torch.load(pretrained_path)
model.load_state_dict(state_dict, strict=False)
return model


def resnet50(pretrained=False, **kwargs):
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, **kwargs)

+ 228
- 0
segmentation/UniMatch/Models/resunet.py View File

import torch.nn
from typing import Dict, List, Any
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss
from torch.nn import functional as F
from PIL import Image
from torchvision.utils import draw_segmentation_masks
import os
from mlassistant.core import Model
from mlassistant.context.dataloader import DataloaderContext
import numpy as np


from typing import Dict, List, Any, Callable, Optional
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet18
from mlassistant.core import Model
from .....segmentation.submodules.res_enc_block import ResEncoderBlock
from .....segmentation.submodules.res_dec_block import ResDecoderBlock
from ......enums.roi.roi_aggregation_type import ROIAggregation
from .....full_segmentation.utils import sum_roi_4ch_to_1ch
from .....segmentation.utils import cal_patch_pixel_labels_by_patch_label_and_roi, \
create_patches_label_based_on_roi, AssembleDisassemble, \
reduce_resolution, select_non_boundaries
from .....segmentation.losses import cal_focal_loss, cal_weighted_focal_loss


class ResUNet(Model):
def __init__(self,
seg_num_labels: int,
concat_original_image: bool,
encoders_conv_kws: Optional[List[Dict[str, Any]]],
bridge_conv_kws: Optional[Dict[str, Any]],
decoders_conv_kws: List[Dict[str, Any]],
roi_agg_type: ROIAggregation = ROIAggregation.mass_vs_others,
remove_boundaries: bool = False,
pretrained_resnet18: bool = False,
loss=(lambda log_pred, gt: cal_weighted_focal_loss(log_pred, gt, 4)),
use_dropout: bool = False) -> None:

super(ResUNet, self).__init__()

assert encoders_conv_kws is not None or pretrained_resnet18, \
"for creating encoder part, either of number of channels or using pretrained version of resnet18 should be determined"
assert not concat_original_image or not pretrained_resnet18, \
"concatenating is currently not supported when using pretrained version"

num_encoders = 6 if pretrained_resnet18 else len(encoders_conv_kws)
self.halving_depth = num_encoders - len(decoders_conv_kws) + 1
self._roi_agg_type: ROIAggregation = roi_agg_type
self._remove_boundaries: bool = remove_boundaries
self._loss = loss
self._concat_original_image: bool = concat_original_image

self.encoders: nn.ModuleList = None
self.bridge: nn.Sequential = None
self._create_encoder(pretrained_resnet18, encoders_conv_kws)
self._create_bridge(bridge_conv_kws)

self.decoders = nn.ModuleList([ResDecoderBlock(**kwargs)
for kwargs in decoders_conv_kws])

last_channel = decoders_conv_kws[-1]['out_channel']
if not use_dropout:
self.decider = nn.Sequential(
nn.Conv2d(last_channel, seg_num_labels, kernel_size=1),
nn.LogSoftmax(dim=1))
else:
self.decider = nn.Sequential(
nn.Dropout(0.3),
nn.Conv2d(last_channel, seg_num_labels, kernel_size=1),
nn.LogSoftmax(dim=1))

def _create_encoder(self,
pretrained: bool,
encoders_conv_kws: Optional[List[Dict[str, Any]]]) -> None:
if pretrained:
resnet = resnet18(pretrained=True)
self._inp_channel = 3
self.encoders = nn.ModuleList([])
self.encoders.append(
nn.Sequential(
resnet.conv1,
resnet.bn1,
resnet.relu,
resnet.maxpool))
for l in [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]:
self.encoders.append(l)

else:
self._inp_channel = encoders_conv_kws[0]['inp_channel']
self.encoders = nn.ModuleList([ResEncoderBlock(**encoders_conv_kws[0])])
encoders_conv_kws.pop(0)
self.encoders += nn.ModuleList([nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2),
ResEncoderBlock(**kwargs)
)
for kwargs in encoders_conv_kws])


def _create_bridge(self,
bridge_conv_kws: Optional[List[Dict[str, Any]]]) -> None:
if bridge_conv_kws is None:
self.bridge = nn.AvgPool2d(kernel_size=2, stride=2),
else:
self.bridge = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2),
ResEncoderBlock(**bridge_conv_kws))


def _forward(self, x):
x = x.repeat_interleave(self._inp_channel, 1)
encoders_outs = []
out = x
x_ = x

for enc in self.encoders:
out = enc(out)
encoders_outs.append(out)

if self._concat_original_image:
x_ = F.interpolate(x_, out.shape[2:], mode='bilinear', align_corners=True)
out = torch.cat([out, x_], dim=1)

out = self.bridge(out)

for dec in self.decoders:
out = dec(out, encoders_outs.pop(-1))
return self.decider(out)


class ResunetSegmentor(ResUNet):

def __init__(self,
seg_num_labels: int,
concat_original_image: bool,
encoders_conv_kws: List[Dict[str, Any]],
bridge_conv_kws: Dict[str, Any],
decoders_conv_kws: List[Dict[str, Any]] ,
log_transform=False ,
is_base=True,
base_saving_dir=None ,
gamma=2):
"""
Args:
seg_num_labels (int): [description] = 2 (background vs nipple)
concat_original_image (bool): [description]
encoders_conv_kws (List[Dict[str, Any]]): [description]
bridge_conv_kws (Dict[str, Any]): [description]
decoders_conv_kws (List[Dict[str, Any]]): [description]
"""
self.is_base = is_base
self.batch_count = 0
super(ResunetSegmentor, self).__init__(
seg_num_labels,
concat_original_image,
encoders_conv_kws,
bridge_conv_kws,
decoders_conv_kws,
)
self.seg_num_labels = seg_num_labels
self.log_transform = log_transform
self.dropout_2 = torch.nn.Dropout(0.4)
self.dropout_1 = torch.nn.Dropout(0.2)
last_channel = decoders_conv_kws[-1]["out_channel"]
self.decider = torch.nn.Sequential(
torch.nn.Conv2d(last_channel, seg_num_labels, kernel_size=1),
torch.nn.LogSoftmax(dim=1),
)
self.gamma = gamma
self.base_model = None
self.base_saving_dir = base_saving_dir
self.debug = self.base_saving_dir is not None
if self.debug:
os.makedirs(self.base_saving_dir , exist_ok=True)
def set_base_model(self,base_model:Model):
self.base_model = base_model
return

def _forward(self, x, need_fp=False):
encoders_outs = []
out = x
x_ = x
for enc in self.encoders:
out = enc(out)
encoders_outs.append(out)
out = self.dropout_1(out)
if self._concat_original_image:
x_ = F.interpolate(x_, out.shape[2:], mode='bilinear', align_corners=True)
out = torch.cat([out, x_], dim=1)
final_encoder_out = out
final_encoder_out = final_encoder_out.flatten(2)

out = self.dropout_2(out)
out = self.bridge(out)
for dec in self.decoders:
encoder_out = encoders_outs.pop(-1)
if need_fp:
encoder_out = torch.cat((encoder_out, nn.Dropout2d(0.5)(encoder_out)))
out = dec(out, encoder_out)
out = self.dropout_1(out)

final_segmentation = self.decider(out)
if need_fp:
return final_segmentation.chunk(2)
return final_segmentation

def forward(self,
mammo_x: torch.Tensor,
need_fp=False):
batch_size = len(mammo_x)
predicted_mask = None
chosen_samples = [0 for i in range(batch_size)]

output = dict()
log_pixel_probs = self._forward(mammo_x, need_fp)
pixel_probs = torch.exp(log_pixel_probs).to(mammo_x.device)
return output

+ 174
- 0
segmentation/UniMatch/Models/unet.py View File

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform


def kaiming_normal_init_weight(model):
for m in model.modules():
if isinstance(m, nn.Conv3d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return model


def sparse_init_weight(model):
for m in model.modules():
if isinstance(m, nn.Conv3d):
torch.nn.init.sparse_(m.weight, sparsity=0.1)
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
return model


class ConvBlock(nn.Module):
"""two convolution layers with batch norm and leaky relu"""

def __init__(self, in_channels, out_channels, dropout_p):
super(ConvBlock, self).__init__()
self.conv_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(),
nn.Dropout(dropout_p),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU()
)

def forward(self, x):
return self.conv_conv(x)


class DownBlock(nn.Module):
"""Downsampling followed by ConvBlock"""

def __init__(self, in_channels, out_channels, dropout_p):
super(DownBlock, self).__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, dropout_p)

)

def forward(self, x):
return self.maxpool_conv(x)


class UpBlock(nn.Module):
"""Upssampling followed by ConvBlock"""

def __init__(self, in_channels1, in_channels2, out_channels, dropout_p,
bilinear=True):
super(UpBlock, self).__init__()
self.bilinear = bilinear
if bilinear:
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
self.up = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(
in_channels1, in_channels2, kernel_size=2, stride=2)
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

def forward(self, x1, x2):
if self.bilinear:
x1 = self.conv1x1(x1)
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)


class Encoder(nn.Module):
def __init__(self, params):
super(Encoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
self.dropout = self.params['dropout']
assert (len(self.ft_chns) == 5)
self.in_conv = ConvBlock(
self.in_chns, self.ft_chns[0], self.dropout[0])
self.down1 = DownBlock(
self.ft_chns[0], self.ft_chns[1], self.dropout[1])
self.down2 = DownBlock(
self.ft_chns[1], self.ft_chns[2], self.dropout[2])
self.down3 = DownBlock(
self.ft_chns[2], self.ft_chns[3], self.dropout[3])
self.down4 = DownBlock(
self.ft_chns[3], self.ft_chns[4], self.dropout[4])

def forward(self, x):
x0 = self.in_conv(x)
x1 = self.down1(x0)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
return [x0, x1, x2, x3, x4]


class Decoder(nn.Module):
def __init__(self, params):
super(Decoder, self).__init__()
self.params = params
self.in_chns = self.params['in_chns']
self.ft_chns = self.params['feature_chns']
self.n_class = self.params['class_num']
self.bilinear = self.params['bilinear']
assert (len(self.ft_chns) == 5)

self.up1 = UpBlock(
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
self.up2 = UpBlock(
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
self.up3 = UpBlock(
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
self.up4 = UpBlock(
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
kernel_size=3, padding=1)

def forward(self, feature):
x0 = feature[0]
x1 = feature[1]
x2 = feature[2]
x3 = feature[3]
x4 = feature[4]

x = self.up1(x4, x3)
x = self.up2(x, x2)
x = self.up3(x, x1)
x = self.up4(x, x0)
output = self.out_conv(x)
return output


class UNet(nn.Module):
def __init__(self, in_chns, class_num):
super(UNet, self).__init__()

params = {'in_chns': in_chns,
'feature_chns': [64, 128, 128, 256, 512],
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
'class_num': class_num,
'bilinear': False,
'acti_func': 'relu'}

self.encoder = Encoder(params)
self.decoder = Decoder(params)

def forward(self, x, need_fp=False):
feature = self.encoder(x)
if need_fp:
outs = self.decoder([torch.cat((feat, nn.Dropout2d(0.5)(feat))) for feat in feature])
return outs.chunk(2)
output = self.decoder(feature)
return output

+ 147
- 0
segmentation/UniMatch/fixmatch.py View File

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms

from copy import deepcopy
import numpy as np
import random

from .Models.unet import UNet
from .transform import random_rot_flip, random_rotate
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss


class DiceLoss(nn.Module):
def __init__(self, n_classes):
super(DiceLoss, self).__init__()
self.n_classes = n_classes

def _one_hot_encoder(self, input_tensor):
tensor_list = []
for i in range(self.n_classes):
temp_prob = input_tensor == i * torch.ones_like(input_tensor)
tensor_list.append(temp_prob)
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()

def _dice_loss(self, score, target, ignore):
target = target.float()
smooth = 1e-5
intersect = torch.sum(score[ignore != 1] * target[ignore != 1])
y_sum = torch.sum(target[ignore != 1] * target[ignore != 1])
z_sum = torch.sum(score[ignore != 1] * score[ignore != 1])
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
loss = 1 - loss
return loss

def forward(self, inputs, target, weight=None, softmax=False, ignore=None):
if softmax:
inputs = torch.softmax(inputs, dim=1)
target = self._one_hot_encoder(target)
if weight is None:
weight = [1] * self.n_classes
assert inputs.size() == target.size(), 'predict & target shape do not match'
class_wise_dice = []
loss = 0.0
for i in range(0, self.n_classes):
dice = self._dice_loss(inputs[:, i], target[:, i], ignore)
class_wise_dice.append(1.0 - dice.item())
loss += dice * weight[i]
return loss / self.n_classes


class FixMatchModel(nn.Module):
def __init__(self):
super(FixMatchModel, self).__init__()

self.base_model = UNet(1, 2)
self.criterion_dice = DiceLoss(n_classes=2)

def forward(self, x):
weak_img = self._get_weak_augs(x)
strong_img = self._get_strong_augs(weak_img)

pred_u_w = self.base_model(weak_img.to(x.device)).softmax(dim=1)
pred_u_s = self.base_model(strong_img.to(x.device)).softmax(dim=1)
mask_u_w = pred_u_w.argmax(dim=1)

print(pred_u_w.max(dim=1))
input()

loss_u_s = self.criterion_dice(pred_u_s, mask_u_w.unsqueeze(1).float())

return self.base_model(x).softmax(dim=1), loss_u_s

def _get_strong_augs(self, img):
img_s = deepcopy(img)

if random.random() < 0.8:
img_s = transforms.ColorJitter(0.25, 0.25, 0.25, 0.15)(img_s)
if random.random() < 0.5:
sigma = np.random.uniform(0.1, 2.0)
img_s = transforms.functional.gaussian_blur(img_s, kernel_size=(5, 5), sigma=sigma)

return img_s

def _get_weak_augs(self, img):
if random.random() > 0.5:
img = random_rot_flip(img)
elif random.random() > 0.5:
img = random_rotate(img)

if isinstance(img, torch.Tensor):
return img

return torch.from_numpy(img).float()


class FixMatch(Model):
"""
FixMatch architecture.
"""

def __init__(self):
super().__init__()
self.fix_match = FixMatchModel()

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type
out_prob_map, unsup_loss = self.fix_match(mammo_x)

main_shape = ground_truth.shape
out_prob_map_shape = out_prob_map.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out_prob_map.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels
dummy_y = mask_channels * loss_channels

dice = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4)

output = {
'pixel_probs': out_prob_map,
'org_pixel_labels': mammo_loss_and_gt,
'loss': (focal_loss + unsup_loss) / 2,
'focal_loss': focal_loss,
'ce_loss': ce_loss,
'dice_loss': dice,
}

return output

+ 153
- 0
segmentation/UniMatch/fixmatch_resunet.py View File

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms

from copy import deepcopy
import numpy as np
import random

from .....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import RESUNet as RESUNETSEGMENTOR
from .Models.unet import UNet
from .transform import random_rot_flip, random_rotate
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss


class DiceLoss(nn.Module):
def __init__(self, n_classes):
super(DiceLoss, self).__init__()
self.n_classes = n_classes

def _one_hot_encoder(self, input_tensor):
tensor_list = []
for i in range(self.n_classes):
temp_prob = input_tensor == i * torch.ones_like(input_tensor)
tensor_list.append(temp_prob)
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()

def _dice_loss(self, score, target, ignore):
target = target.float()
smooth = 1e-5
intersect = torch.sum(score[ignore != 1] * target[ignore != 1])
y_sum = torch.sum(target[ignore != 1] * target[ignore != 1])
z_sum = torch.sum(score[ignore != 1] * score[ignore != 1])
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
loss = 1 - loss
return loss

def forward(self, inputs, target, weight=None, softmax=False, ignore=None):
if softmax:
inputs = torch.softmax(inputs, dim=1)
target = self._one_hot_encoder(target)
if weight is None:
weight = [1] * self.n_classes
assert inputs.size() == target.size(), 'predict & target shape do not match'
class_wise_dice = []
loss = 0.0
for i in range(0, self.n_classes):
dice = self._dice_loss(inputs[:, i], target[:, i], ignore)
class_wise_dice.append(1.0 - dice.item())
loss += dice * weight[i]
return loss / self.n_classes


class FixMatchModel(nn.Module):
def __init__(self):
super(FixMatchModel, self).__init__()

self.base_model = RESUNETSEGMENTOR(1, 2)
self.criterion_dice = DiceLoss(n_classes=2)

def forward(self, x):
weak_img = self._get_weak_augs(x)
strong_img = self._get_strong_augs(weak_img)

pred_u_w = self.base_model(weak_img.to(x.device)).softmax(dim=1)
pred_u_s = self.base_model(strong_img.to(x.device)).softmax(dim=1)

indicator = []
for i in range(len(pred_u_w)):
if torch.max(x) >= 0.95:
indicator.append(i)
pred_u_w = pred_u_w[indicator, ...]
pred_u_s = pred_u_w[indicator, ...]
mask_u_w = pred_u_w.argmax(dim=1)
loss_u_s = self.criterion_dice(pred_u_s, mask_u_w.unsqueeze(1).float())

return self.base_model(x).softmax(dim=1), loss_u_s

def _get_strong_augs(self, img):
img_s = deepcopy(img)

if random.random() < 0.8:
img_s = transforms.ColorJitter(0.25, 0.25, 0.25, 0.15)(img_s)
if random.random() < 0.5:
sigma = np.random.uniform(0.1, 2.0)
img_s = transforms.functional.gaussian_blur(img_s, kernel_size=(5, 5), sigma=sigma)

return img_s

def _get_weak_augs(self, img):
if random.random() > 0.5:
img = random_rot_flip(img)
elif random.random() > 0.5:
img = random_rotate(img)

if isinstance(img, torch.Tensor):
return img

return torch.from_numpy(img).float()


class FixMatch(Model):
"""
FixMatch architecture.
"""

def __init__(self):
super().__init__()
self.fix_match = FixMatchModel()

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type
out_prob_map, unsup_loss = self.fix_match(mammo_x)

main_shape = ground_truth.shape
out_prob_map_shape = out_prob_map.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out_prob_map.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels
dummy_y = mask_channels * loss_channels

dice = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4)

output = {
'pixel_probs': out_prob_map,
'org_pixel_labels': mammo_loss_and_gt,
'loss': (focal_loss + unsup_loss) / 2,
'focal_loss': focal_loss,
'ce_loss': ce_loss,
'dice_loss': dice,
}

return output

+ 40
- 0
segmentation/UniMatch/transform.py View File

import numpy as np
import random
import torch
from torchvision import transforms
from scipy import ndimage


def random_rot_flip(img):
k = np.random.randint(0, 4)
img = transforms.functional.rotate(img, angle=k * 90)
axis = np.random.randint(0, 2)
img = torch.flip(img, [axis])
return img


def random_rotate(img):
angle = np.random.randint(-20, 20)
img = transforms.functional.rotate(img, angle=angle)
return img


def obtain_cutmix_box(img_size, p=0.5, size_min=0.02, size_max=0.4, ratio_1=0.3, ratio_2=1/0.3):
mask = torch.zeros(img_size, img_size)
if random.random() > p:
return mask

size = np.random.uniform(size_min, size_max) * img_size * img_size
while True:
ratio = np.random.uniform(ratio_1, ratio_2)
cutmix_w = int(np.sqrt(size / ratio))
cutmix_h = int(np.sqrt(size * ratio))
x = np.random.randint(0, img_size)
y = np.random.randint(0, img_size)

if x + cutmix_w <= img_size and y + cutmix_h <= img_size:
break

mask[y:y + cutmix_h, x:x + cutmix_w] = 1

return mask

+ 160
- 0
segmentation/UniMatch/unimatch.py View File

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms

from copy import deepcopy
import math
import numpy as np
import os
from PIL import Image
import random
from scipy.ndimage.interpolation import zoom
from scipy import ndimage

from .Models.unet import UNet
from .transform import random_rot_flip, random_rotate, obtain_cutmix_box
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss


class DiceLoss(nn.Module):
def __init__(self, n_classes):
super(DiceLoss, self).__init__()
self.n_classes = n_classes

def _one_hot_encoder(self, input_tensor):
tensor_list = []
for i in range(self.n_classes):
temp_prob = input_tensor == i * torch.ones_like(input_tensor)
tensor_list.append(temp_prob)
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()

def _dice_loss(self, score, target, ignore):
target = target.float()
smooth = 1e-5
intersect = torch.sum(score[ignore != 1] * target[ignore != 1])
y_sum = torch.sum(target[ignore != 1] * target[ignore != 1])
z_sum = torch.sum(score[ignore != 1] * score[ignore != 1])
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
loss = 1 - loss
return loss

def forward(self, inputs, target, weight=None, softmax=False, ignore=None):
if softmax:
inputs = torch.softmax(inputs, dim=1)
target = self._one_hot_encoder(target)
if weight is None:
weight = [1] * self.n_classes
assert inputs.size() == target.size(), 'predict & target shape do not match'
class_wise_dice = []
loss = 0.0
for i in range(0, self.n_classes):
dice = self._dice_loss(inputs[:, i], target[:, i], ignore)
class_wise_dice.append(1.0 - dice.item())
loss += dice * weight[i]
return loss / self.n_classes


class UniMatchModel(nn.Module):
def __init__(self):
super(UniMatchModel, self).__init__()

self.base_model = UNet(1, 2)
self.criterion_dice = DiceLoss(n_classes=2)

def forward(self, x):
weak_img = self._get_weak_augs(x)
strong_img1, strong_img2 = self._get_strong_augs(weak_img)

pred_u_w, pred_fp = self.base_model(weak_img.to(x.device), True)
pred_u_s1, pred_u_s2 = self.base_model(torch.cat((strong_img1.to(x.device), strong_img2.to(x.device)))).chunk(2)
mask_u_w = pred_u_w.argmax(dim=1)

loss_u_s1 = self.criterion_dice(pred_u_s1.softmax(dim=1),
mask_u_w.unsqueeze(1).float())
loss_u_s2 = self.criterion_dice(pred_u_s2.softmax(dim=1),
mask_u_w.unsqueeze(1).float())
loss_u_w_fp = self.criterion_dice(pred_fp.softmax(dim=1), mask_u_w.unsqueeze(1).float())
loss = loss_u_s1 * 0.25 + loss_u_s2 * 0.25 + loss_u_w_fp * 0.5

return self.base_model(x).softmax(dim=1), loss

def _get_strong_augs(self, img):
img_s1, img_s2 = deepcopy(img), deepcopy(img)

if random.random() < 0.8:
img_s1 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s1)
if random.random() < 0.5:
sigma = np.random.uniform(0.1, 2.0)
img_s1 = transforms.functional.gaussian_blur(img_s1, kernel_size=(5, 5), sigma=sigma)

if random.random() < 0.8:
img_s2 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s2)
if random.random() < 0.5:
sigma = np.random.uniform(0.1, 2.0)
img_s2 = transforms.functional.gaussian_blur(img_s2, kernel_size=(5, 5), sigma=sigma)

return img_s1, img_s2

def _get_weak_augs(self, img):
if random.random() > 0.5:
img = random_rot_flip(img)
elif random.random() > 0.5:
img = random_rotate(img)

if isinstance(img, torch.Tensor):
return img

return torch.from_numpy(img).float()


class UniMatch(Model):
"""
UniMatch architecture.
"""

def __init__(self):
super().__init__()
self.uni_match = UniMatchModel()

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type
out_prob_map, unsup_loss = self.uni_match(mammo_x)

main_shape = ground_truth.shape
out_prob_map_shape = out_prob_map.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out_prob_map.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels
dummy_y = mask_channels * loss_channels

dice = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4)

output = {
'pixel_probs': out_prob_map,
'org_pixel_labels': mammo_loss_and_gt,
'loss': (focal_loss + unsup_loss) / 2,
'focal_loss': focal_loss,
'ce_loss': ce_loss,
'dice_loss': dice,
}

return output

+ 187
- 0
segmentation/UniMatch/unimatch_resunet.py View File

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms

from .Models.resunet import ResunetSegmentor
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss

from copy import deepcopy
import math
import numpy as np
import os
from PIL import Image
import random
from scipy.ndimage.interpolation import zoom
from scipy import ndimage

from .transform import random_rot_flip, random_rotate, obtain_cutmix_box
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss


class DiceLoss(nn.Module):
def __init__(self, n_classes):
super(DiceLoss, self).__init__()
self.n_classes = n_classes

def _one_hot_encoder(self, input_tensor):
tensor_list = []
for i in range(self.n_classes):
temp_prob = input_tensor == i * torch.ones_like(input_tensor)
tensor_list.append(temp_prob)
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()

def _dice_loss(self, score, target, ignore):
target = target.float()
smooth = 1e-5
intersect = torch.sum(score[ignore != 1] * target[ignore != 1])
y_sum = torch.sum(target[ignore != 1] * target[ignore != 1])
z_sum = torch.sum(score[ignore != 1] * score[ignore != 1])
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
loss = 1 - loss
return loss

def forward(self, inputs, target, weight=None, softmax=False, ignore=None):
if softmax:
inputs = torch.softmax(inputs, dim=1)
target = self._one_hot_encoder(target)
if weight is None:
weight = [1] * self.n_classes
assert inputs.size() == target.size(), 'predict & target shape do not match'
class_wise_dice = []
loss = 0.0
for i in range(0, self.n_classes):
dice = self._dice_loss(inputs[:, i], target[:, i], ignore)
class_wise_dice.append(1.0 - dice.item())
loss += dice * weight[i]
return loss / self.n_classes


class UniMatchModel(nn.Module):
def __init__(self):
super(UniMatchModel, self).__init__()

encoders_conv_kws = [
{"inp_channel": 1, "out_channel": 32, "n_conv_blocks": 2},
{"inp_channel": 33, "out_channel": 64, "n_conv_blocks": 2},
{"inp_channel": 65, "out_channel": 128, "n_conv_blocks": 3},
{"inp_channel": 129, "out_channel": 256, "n_conv_blocks": 3},
]

bridge_conv_kws = {"inp_channel": 257, "out_channel": 512}

decoders_conv_kws = [
{"inp_channel": 512, "enc_inp_channel": 256, "out_channel": 256},
{"inp_channel": 256, "enc_inp_channel": 128, "out_channel": 64},
{"inp_channel": 64, "enc_inp_channel": 64, "out_channel": 32},
{"inp_channel": 32, "enc_inp_channel": 32, "out_channel": 16},
]

self.base_model = ResunetSegmentor(
seg_num_labels=2,
concat_original_image=True,
encoders_conv_kws=encoders_conv_kws,
decoders_conv_kws=decoders_conv_kws,
bridge_conv_kws=bridge_conv_kws,
)

self.criterion_dice = DiceLoss(n_classes=2)

def forward(self, x):
weak_img = self._get_weak_augs(x)
strong_img1, strong_img2 = self._get_strong_augs(weak_img)

pred_u_w, pred_fp = self.base_model(weak_img.to(x.device), True)
pred_u_s1, pred_u_s2 = self.base_model(torch.cat((strong_img1.to(x.device), strong_img2.to(x.device)))).chunk(2)
mask_u_w = pred_u_w.argmax(dim=1)

loss_u_s1 = self.criterion_dice(pred_u_s1.softmax(dim=1),
mask_u_w.unsqueeze(1).float())
loss_u_s2 = self.criterion_dice(pred_u_s2.softmax(dim=1),
mask_u_w.unsqueeze(1).float())
loss_u_w_fp = self.criterion_dice(pred_fp.softmax(dim=1), mask_u_w.unsqueeze(1).float())
loss = loss_u_s1 * 0.25 + loss_u_s2 * 0.25 + loss_u_w_fp * 0.5

return self.base_model(x).softmax(dim=1), loss

def _get_strong_augs(self, img):
img_s1, img_s2 = deepcopy(img), deepcopy(img)

if random.random() < 0.8:
img_s1 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s1)
if random.random() < 0.5:
sigma = np.random.uniform(0.1, 2.0)
img_s1 = transforms.functional.gaussian_blur(img_s1, kernel_size=(5, 5), sigma=sigma)

if random.random() < 0.8:
img_s2 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s2)
if random.random() < 0.5:
sigma = np.random.uniform(0.1, 2.0)
img_s2 = transforms.functional.gaussian_blur(img_s2, kernel_size=(5, 5), sigma=sigma)

return img_s1, img_s2

def _get_weak_augs(self, img):
if random.random() > 0.5:
img = random_rot_flip(img)
elif random.random() > 0.5:
img = random_rotate(img)

if isinstance(img, torch.Tensor):
return img

return torch.from_numpy(img).float()


class UniMatch(Model):
"""
UniMatch architecture.
"""

def __init__(self):
super().__init__()
self.uni_match = UniMatchModel()

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
loss_type = mammo_loss_and_gt[:, 0, ...]
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type
out_prob_map, unsup_loss = self.uni_match(mammo_x)

main_shape = ground_truth.shape
out_prob_map_shape = out_prob_map.shape

loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2])
ground_truth = ground_truth.unsqueeze(dim = 1)
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = out_prob_map.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels
dummy_y = mask_channels * loss_channels

dice = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4)

output = {
'pixel_probs': out_prob_map,
'org_pixel_labels': mammo_loss_and_gt,
'loss': (focal_loss + unsup_loss) / 2,
'focal_loss': focal_loss,
'ce_loss': ce_loss,
'dice_loss': dice,
}

return output

+ 194
- 0
uasmt/UASMT.py View File

from copy import deepcopy
import math
from random import uniform
from typing import List
from torch.nn import functional as F
from torch import nn
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.generalized_dice import dice_loss
from ....utils.losses.ent_losses import entropy_loss_normalized
from ....utils.losses.tverskyLoss import tversky_loss
import time
from mlassistant.core import Model, ModelIO
import torch
from torch import nn
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.generalized_dice import dice_loss

from ....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from ....utils.dynamic_num_generator.wrapper import dynamic_num_generator
import numpy as np

def random_dropout(min, max):
def assign(module: nn.Dropout,_: torch.Tensor) -> None:
module.p = uniform(min, max)
return assign

def static_dropout(prob: float):
def assign(module: nn.Dropout, _: torch.Tensor) -> None:
module.p = prob
return assign

class UncertaintyAwareStudentMeanTeacherNet(Model):
""" adapted from the paper 'Uncertainty-aware Self-ensembling Model for Semi-supervised
3D Left Atrium Segmentation'. link: https://arxiv.org/pdf/1907.07034.pdf

Args:
alpha (float): exponential moving average decay
total_num_iteration (int): total number of iteration this run will have
base_model (Model): base networks for both student and teacher
x_arg (str): argument name of `forward` indicating images, shape: B ...
x_roi_arg (str): argument name of `forward` indicating ground-truth ROIs, shape: B ...
x_annotated_arg (str): argument name of `forward` indicating being labeled or not for inputs, shape: B
uncertainty_ratio_thd (float): the initial ratio threshold for uncertainty pruning
std (float): standard deviation of gaussian noise to be applied to the input
mean (float): mean of gaussian noise to be applied to the input
T (int): number of perturbations for teacher net
lambda_coef (float): lambda coefficient for calculating unsupervised loss
use_supervised_in_consistency_checking (bool): if true then use supervised samples in unsupervised task (consistency checking)
"""
def train(self,mode : bool=True):
super().train(mode)
self.epochNUM += 1
# with torch.no_grad():
# if self.epochNUM % 2 == 0:
# self.update_teacher_params()
self.mode = mode
print(mode)
def __init__(self,
alpha: float,
base_model: Model,
std: float = math.sqrt(0.0001),
mean: float = 0.0,
T: int=8,
gamma = 2, smooth=1, alphat=0.7, beta=0.3) -> None:
super(UncertaintyAwareStudentMeanTeacherNet, self).__init__()

self.epochNUM = 1
assert alpha >= 0 and alpha <= 1, f'alpha should be in range [0, 1]; got {alpha} instead.'
assert T > 0, f'number of perturbations to be applied for teacher net should be at least one; got {T} instead.'

self._T: int = T
self._alpha: float = alpha
self._mean: float = mean
self._std: float = std
self.smooth = smooth
self.alphatt = alphat
self.beta = beta

self._student: Model = base_model
self._teacher: Model = deepcopy(self._student)
self.gamma = gamma
self.uncercoef = lambda t : torch.tensor(0.1*np.exp(-5 * (1-t/200)**2))

self.H = np.log(2)
self.threshold = lambda t : torch.tensor((self.H/4)*np.exp(-5 * (1-t/200)**2) + self.H * 3/4)
self.flag = True
self.mode = True

def _update_teacher_params(self,
_: Model,
__: torch.Tensor) -> None:
"""a backward hook to update teacher network parameters, using exponential moving
average method."""

for s_p, t_p in zip(self._student.parameters(), self._teacher.parameters()):
t_p.data = self._alpha * t_p.data + (1 - self._alpha) * s_p.data
def update_teacher_params(self) -> None:
"""a backward hook to update teacher network parameters, using exponential moving
average method."""

for s_p, t_p in zip(self._student.parameters(), self._teacher.parameters()):
t_p.data = self._alpha * t_p.data + (1 - self._alpha) * s_p.data

def _apply_gaussian_noise(self,
input ) -> torch.Tensor:
inp = input
noise = torch.randn(inp.size()) * self._std + self._mean
noise = noise.to(inp.device)
inp = inp + (noise * inp)
input = inp
return input

def forward(self,
mammo_x: torch.Tensor,
mammo_loss_and_gt: torch.Tensor) -> ModelIO:
"""first forwards supervised samples through student net, then all samples through both nets. """
with torch.no_grad():
if self.mode and self.epochNUM > 1:
self.update_teacher_params()
output = dict(loss=0.0)
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)"
loss_type = mammo_loss_and_gt[:,0,:,:]
mask = mammo_loss_and_gt[:,1,:,:]
network_output = self._student.forward(mammo_x).softmax(dim=1) # logsoftmax # B
network_output_soft = network_output #torch.exp(network_output).to(mammo_x.device) # softmax output of model
output['pixel_probs'] = network_output_soft


main_shape = mask.shape
network_output_shape = network_output.shape



loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2])
mask = mask.unsqueeze(dim = 1)
mask_channels = torch.cat((1-mask, mask), 1)

dummy_mask = mask_channels.clone().to(mammo_x.device)
dummy_net_out = network_output_soft.clone().to(mammo_x.device)
dummy_mask[loss_channels == 0] = 0.0
dummy_shape = dummy_mask.shape
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels
dummy_y = mask_channels * loss_channels
focal_loss = balanced_focal_cross_entropy_loss_semi(
dummy_net_out,
dummy_mask,
focal_gamma=self.gamma)
dice = dice_loss(dummy_net_out,dummy_mask)
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alphatt, beta= self.beta)

with torch.no_grad():
teacher_outs: List[torch.Tensor] = list()
for _ in range(self._T):
teacher_outs.append(self._teacher.forward(
self._apply_gaussian_noise(mammo_x)).softmax(dim=1))
out_t_u = (torch.stack(teacher_outs, dim=2)).mean(dim=2) # B C H' W'
uncertainty = -1 * torch.sum(out_t_u * torch.log(out_t_u), 1) # B H' W'
mse = torch.sum((network_output_soft - out_t_u) ** 2, dim=1)
# consider only most certain pixels for mse loss
mse_ = torch.zeros_like(mse).to(mammo_x.device)
UThreshhold = self.threshold(int(self.epochNUM/2))
mask_mse = uncertainty < UThreshhold
mse_[mask_mse] += mse[mask_mse]
active_pxs = torch.sum(mask_mse.long(), dim=[1, 2]) + 1 # B
mse_loss = torch.sum(mse_ / active_pxs[:, None, None], dim=[1, 2])
mse_loss = torch.mean(mse_loss)
mse_coef = self.uncercoef(int(self.epochNUM/2)).to(mammo_x.device)
output['ce_loss'] = ce_loss
output['loss'] = focal_loss + mse_coef * mse_loss
#output['diceloss'] = dice
#output['tverskyloss'] = tversky
output['focalloss'] = focal_loss
output['suploss'] = focal_loss
output['org_pixel_labels'] = mammo_loss_and_gt
output['mse_loss'] = mse_loss
output['mse_coef_loss'] = torch.tensor(mse_coef)
output['Uthreshhold_loss'] = torch.tensor(UThreshhold)
output['totol_loss'] = output['loss']
return output

Loading…
Cancel
Save