@@ -0,0 +1,170 @@ | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ....utils.generalized_dice import dice_loss | |||
from ....utils.losses.tverskyLoss import tversky_loss | |||
from ..base_line.Resunet.Heavy.Models.resunet import Encoder | |||
from ..base_line.Resunet.Heavy.Models.resunet import Decoder | |||
class BIAS_MLP(Model):
    """Small 1x1-convolution MLP head operating on the 2-channel class logits."""
    def __init__(self) -> None:
        super(BIAS_MLP, self).__init__()
        self.MLP = nn.Sequential(
            nn.Conv2d(2, 16, 1),
            nn.ReLU(),
            nn.Conv2d(16, 8, 1),
            nn.ReLU(),
            nn.Conv2d(8, 2, 1),
        )
    def forward(self, x):
        return self.MLP(x)
class RESUNET(Model): | |||
    def __init__(self,
                 gamma=4, smooth=1, alpha=0.7, beta=0.3, params=None) -> None:
        super(RESUNET, self).__init__()
        self.gamma = gamma
        self.smooth = smooth
        self.alpha = alpha
        self.beta = beta
        print("gamma:", self.gamma)
        self.MLP = BIAS_MLP()
        self.encoder = Encoder(params)
        self.decoder = Decoder(params)
        self.epochNUM = 0
        self.forward_call = 0
        self.mode = True
    def train(self, mode: bool = True):
super().train(mode) | |||
self.epochNUM += 1 | |||
self.mode = mode | |||
if mode: | |||
self.forward_call = 0 | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
        assert mammo_loss_and_gt.shape[1] == 2, "loss_and_gt is not in the expected shape (B, 2, size, size)"
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
        if self.mode:
            # periodic debug print to verify the encoder weights are updating
            if self.forward_call % 50 == 0 and self.forward_call > 10:
                print("changing")
                print(list(self.encoder.parameters())[0][0])
network_output = self.decoder(self.encoder(mammo_x)) | |||
            network_output_soft = network_output.softmax(dim=1)  # softmax output of the model
bias_output = self.MLP(network_output) + network_output | |||
bias_output = bias_output.softmax(dim=1) | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask_l = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out_l = bias_output.clone().to(mammo_x.device) | |||
dummy_mask_l[loss_channels == 0] = 0.0 | |||
dummy_shape_l = dummy_mask_l.shape | |||
            # NOTE: dummy_shape_l[2] is used for both spatial dims, so square inputs are assumed
            dummy_mask_l = torch.cat((dummy_mask_l, torch.ones((dummy_shape_l[0], 1, dummy_shape_l[2], dummy_shape_l[2])).to(mammo_x.device)), 1)
            dummy_net_out_l = torch.cat((dummy_net_out_l, torch.ones((dummy_shape_l[0], 1, dummy_shape_l[2], dummy_shape_l[2])).to(mammo_x.device)), 1)
focal_loss_l = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out_l, | |||
dummy_mask_l, | |||
focal_gamma=self.gamma) | |||
with torch.no_grad(): | |||
                dummy_mask_u = (bias_output >= 0.5).float().to(mammo_x.device)  # pseudo-labels from the thresholded bias output
dummy_net_out_u = bias_output.clone().to(mammo_x.device) | |||
dummy_mask_u[loss_channels == 1.0] = 0.0 | |||
dummy_shape_u = dummy_mask_u.shape | |||
dummy_mask_u = torch.cat((dummy_mask_u,torch.ones((dummy_shape_u[0],1,dummy_shape_u[2],dummy_shape_u[2])).to(mammo_x.device)), 1) | |||
dummy_net_out_u = torch.cat((dummy_net_out_u,torch.ones((dummy_shape_u[0],1,dummy_shape_u[2],dummy_shape_u[2])).to(mammo_x.device)), 1) | |||
focal_loss_u = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out_u, | |||
dummy_mask_u, | |||
focal_gamma=self.gamma) | |||
            # the two losses of Eq. 3, focal_loss_u and focal_loss_l, are combined below
output['loss'] = focal_loss_u + focal_loss_l | |||
output['MYloss'] = output["loss"] | |||
output['focalloss_l'] = focal_loss_l | |||
output['focalloss_u'] = focal_loss_u | |||
output['suploss'] = focal_loss_l | |||
output['unsuploss'] = focal_loss_u | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
        else:
            with torch.no_grad():  # evaluation: no gradients needed
                features = self.encoder(mammo_x)
                network_output = self.decoder(features)
network_output_soft = network_output.softmax(dim=1) | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask_l = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out_l = network_output_soft.clone().to(mammo_x.device) | |||
balance_mask = loss_channels.clone().to(mammo_x.device) | |||
            # balance_mask[loss_channels == 0] = 2  # disabled: sometimes every pixel is supervised, which would mask the whole image
main_shape = balance_mask.shape | |||
number_1 = torch.sum(balance_mask == 1).to(mammo_x.device) | |||
number_0 = torch.sum(balance_mask == 0).to(mammo_x.device) | |||
random_matrix = torch.rand(main_shape).to(mammo_x.device) | |||
#print(number_0, number_1) | |||
            if number_1 > number_0:  # randomly drop majority-value pixels so both values are equally represented
balance_mask[(balance_mask == 1) & (random_matrix >= number_0/number_1)] = 2 | |||
else: | |||
balance_mask[(balance_mask == 0) & (random_matrix >= number_1/number_0)] = 2 | |||
dummy_mask_l[balance_mask == 2] = 0.0 | |||
dummy_shape_l = dummy_mask_l.shape | |||
dummy_mask_l = torch.cat((dummy_mask_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1) | |||
dummy_net_out_l = torch.cat((dummy_net_out_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1) | |||
focal_loss_l = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out_l, | |||
dummy_mask_l, | |||
focal_gamma=self.gamma) | |||
output['loss'] = focal_loss_l | |||
output['MYloss'] = output["loss"] | |||
output['focalloss_l'] = focal_loss_l | |||
output['focalloss_u'] = focal_loss_l * 0 | |||
output['suploss'] = focal_loss_l | |||
output['unsuploss'] = focal_loss_l * 0 | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
self.forward_call += 1 | |||
return output |
@@ -0,0 +1,180 @@ | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ....utils.generalized_dice import dice_loss | |||
from ....utils.losses.tverskyLoss import tversky_loss | |||
from ....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import Encoder | |||
from ....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import Decoder | |||
class BIAS_MLP(Model):
    """Small 1x1-convolution MLP head operating on the 2-channel class logits."""
    def __init__(self) -> None:
        super(BIAS_MLP, self).__init__()
        self.MLP = nn.Sequential(
            nn.Conv2d(2, 16, 1),
            nn.ReLU(),
            nn.Conv2d(16, 8, 1),
            nn.ReLU(),
            nn.Conv2d(8, 2, 1),
        )
    def forward(self, x):
        return self.MLP(x)
class RESUNET(Model): | |||
    def __init__(self,
                 gamma=4, smooth=1, alpha=0.7, beta=0.3, params=None) -> None:
        super(RESUNET, self).__init__()
        self.gamma = gamma
        self.smooth = smooth
        self.alpha = alpha
        self.beta = beta
        print("gamma:", self.gamma)
        self.MLP = BIAS_MLP()
        self.encoder = Encoder(params)
        self.decoder = Decoder(params)
        self.epochNUM = 0
        self.forward_call = 0
        self.mode = True
    def train(self, mode: bool = True):
super().train(mode) | |||
self.epochNUM += 1 | |||
self.mode = mode | |||
if mode: | |||
self.forward_call = 0 | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
        assert mammo_loss_and_gt.shape[1] == 2, "loss_and_gt is not in the expected shape (B, 2, size, size)"
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
        if self.mode and self.forward_call % 2 == 0:
network_output = self.decoder(self.encoder(mammo_x)) | |||
            network_output_soft = network_output.softmax(dim=1)  # softmax output of the model
bias_output = self.MLP(network_output) + network_output | |||
bias_output = bias_output.softmax(dim=1) | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask_l = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out_l = bias_output.clone().to(mammo_x.device) | |||
dummy_mask_l[loss_channels == 0] = 0.0 | |||
dummy_shape_l = dummy_mask_l.shape | |||
dummy_mask_l = torch.cat((dummy_mask_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1) | |||
dummy_net_out_l = torch.cat((dummy_net_out_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1) | |||
focal_loss_l = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out_l, | |||
dummy_mask_l, | |||
focal_gamma=self.gamma) | |||
with torch.no_grad(): | |||
                dummy_mask_u = (bias_output >= 0.5).float().to(mammo_x.device)  # pseudo-labels from the thresholded bias output
dummy_net_out_u = bias_output.clone().to(mammo_x.device) | |||
dummy_mask_u[loss_channels == 1.0] = 0.0 | |||
dummy_shape_u = dummy_mask_u.shape | |||
dummy_mask_u = torch.cat((dummy_mask_u,torch.ones((dummy_shape_u[0],1,dummy_shape_u[2],dummy_shape_u[2])).to(mammo_x.device)), 1) | |||
dummy_net_out_u = torch.cat((dummy_net_out_u,torch.ones((dummy_shape_u[0],1,dummy_shape_u[2],dummy_shape_u[2])).to(mammo_x.device)), 1) | |||
focal_loss_u = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out_u, | |||
dummy_mask_u, | |||
focal_gamma=self.gamma) | |||
            # the two losses of Eq. 3, focal_loss_u and focal_loss_l, are combined below
output['loss'] = focal_loss_u + focal_loss_l | |||
output['MYloss'] = output["loss"] | |||
output['focalloss_l'] = focal_loss_l | |||
output['focalloss_u'] = focal_loss_u | |||
output['suploss'] = focal_loss_l | |||
output['unsuploss'] = focal_loss_u | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
        elif self.mode:
            with torch.no_grad():  # TODO check sanity: this branch computes its loss without gradients
                features = self.encoder(mammo_x)
                network_output = self.decoder(features)
network_output_soft = network_output.softmax(dim=1) | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask_l = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out_l = network_output_soft.clone().to(mammo_x.device) | |||
balance_mask = loss_channels.clone().to(mammo_x.device) | |||
            # balance_mask[loss_channels == 0] = 2  # disabled: sometimes every pixel is supervised, which would mask the whole image
main_shape = balance_mask.shape | |||
number_1 = torch.sum(balance_mask == 1).to(mammo_x.device) | |||
number_0 = torch.sum(balance_mask == 0).to(mammo_x.device) | |||
random_matrix = torch.rand(main_shape).to(mammo_x.device) | |||
            if number_1 > number_0:  # randomly drop majority-value pixels so both values are equally represented
balance_mask[(balance_mask == 1) & (random_matrix >= number_0/number_1)] = 2 | |||
else: | |||
balance_mask[(balance_mask == 0) & (random_matrix >= number_1/number_0)] = 2 | |||
dummy_mask_l[balance_mask == 2] = 0.0 | |||
dummy_shape_l = dummy_mask_l.shape | |||
dummy_mask_l = torch.cat((dummy_mask_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1) | |||
dummy_net_out_l = torch.cat((dummy_net_out_l,torch.ones((dummy_shape_l[0],1,dummy_shape_l[2],dummy_shape_l[2])).to(mammo_x.device)), 1) | |||
focal_loss_l = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out_l, | |||
dummy_mask_l, | |||
focal_gamma=self.gamma) | |||
output['loss'] = focal_loss_l | |||
output['MYloss'] = output["loss"] | |||
output['focalloss_l'] = focal_loss_l | |||
output['focalloss_u'] = focal_loss_l * 0 | |||
output['suploss'] = focal_loss_l | |||
output['unsuploss'] = focal_loss_l * 0 | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
else: | |||
network_output = self.decoder(self.encoder(mammo_x)) | |||
            network_output_soft = network_output.softmax(dim=1)  # softmax output of the model
output['pixel_probs'] = network_output_soft | |||
            output['loss'] = torch.tensor(0.0)
            output['MYloss'] = torch.tensor(0.0)
            output['focalloss_l'] = torch.tensor(0.0)
            output['focalloss_u'] = torch.tensor(0.0)
            output['suploss'] = torch.tensor(0.0)
            output['unsuploss'] = torch.tensor(0.0)
output['org_pixel_labels'] = mammo_loss_and_gt | |||
self.forward_call += 1 | |||
return output |
@@ -0,0 +1,149 @@ | |||
import mlassistant
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ....utils.generalized_dice import dice_loss | |||
from ....utils.losses.tverskyLoss import tversky_loss | |||
from ....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import Encoder | |||
from ....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import Decoder | |||
from .utils import * | |||
from .decoders import * | |||
class CCT(Model): | |||
    def __init__(self,
                 gamma=4, smooth=1, alpha=0.7, beta=0.3, params=None, Epochs=200, iter=200) -> None:
        super(CCT, self).__init__()
        self.gamma = gamma
        self.smooth = smooth
        self.alpha = alpha
        self.beta = beta
        print("gamma:", self.gamma)
self.encoder = Encoder(params) | |||
self.main_decoder = Decoder(params) | |||
self.conf = Config["model"] | |||
self.conft = Config | |||
rampup_ends = int(self.conft['ramp_up'] * Epochs) | |||
cons_w_unsup = consistency_weight(final_w=self.conft['unsupervised_w'], iters_per_epoch=iter, | |||
rampup_ends=rampup_ends) | |||
self.unsup_loss_w = cons_w_unsup | |||
self.unsuper_loss = softmax_mse_loss | |||
self.sup_loss_w = self.conf['supervised_w'] | |||
self.softmax_temp = self.conf['softmax_temp'] | |||
self.sup_type = self.conf['sup_loss'] | |||
# Use weak labels | |||
self.use_weak_lables = self.conft['use_weak_lables'] | |||
self.weakly_loss_w = self.conft['weakly_loss_w'] | |||
# pair wise loss (sup mat) | |||
self.aux_constraint = self.conf['aux_constraint'] | |||
self.aux_constraint_w = self.conf['aux_constraint_w'] | |||
# confidence masking (sup mat) | |||
self.confidence_th = self.conf['confidence_th'] | |||
self.confidence_masking = self.conf['confidence_masking'] | |||
upscale = 8 | |||
num_out_ch = 512 | |||
decoder_in_ch = num_out_ch | |||
num_classes = 2 | |||
        # The auxiliary decoders
vat_decoder = [VATDecoder(upscale, decoder_in_ch, num_classes, xi=self.conf['xi'], | |||
eps=self.conf['eps']) for _ in range(self.conf['vat'])] | |||
drop_decoder = [DropOutDecoder(upscale, decoder_in_ch, num_classes, | |||
drop_rate=self.conf['drop_rate'], spatial_dropout=self.conf['spatial']) | |||
for _ in range(self.conf['drop'])] | |||
cut_decoder = [CutOutDecoder(upscale, decoder_in_ch, num_classes, erase=self.conf['erase']) | |||
for _ in range(self.conf['cutout'])] | |||
context_m_decoder = [ContextMaskingDecoder(upscale, decoder_in_ch, num_classes) | |||
for _ in range(self.conf['context_masking'])] | |||
object_masking = [ObjectMaskingDecoder(upscale, decoder_in_ch, num_classes) | |||
for _ in range(self.conf['object_masking'])] | |||
feature_drop = [FeatureDropDecoder(upscale, decoder_in_ch, num_classes) | |||
for _ in range(self.conf['feature_drop'])] | |||
feature_noise = [FeatureNoiseDecoder(upscale, decoder_in_ch, num_classes, | |||
uniform_range=self.conf['uniform_range']) | |||
for _ in range(self.conf['feature_noise'])] | |||
        # NOTE: cut_decoder is built above but not registered in the list below
        # self.aux_decoders = nn.ModuleList([*vat_decoder, *drop_decoder, *cut_decoder,
        #                                    *context_m_decoder, *object_masking, *feature_drop, *feature_noise])
        self.aux_decoders = nn.ModuleList([*drop_decoder, *vat_decoder, *feature_noise,
                                           *context_m_decoder, *object_masking, *feature_drop])
self.iter = 0 | |||
self.mode = True | |||
    def train(self, mode: bool = True):
super().train(mode) | |||
self.mode = mode | |||
self.iter = 0 | |||
print("epoch:",mlassistant.context.epoch.EpochContext.epoch) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
self.iter = 1 | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
curr_iter= self.iter | |||
epoch= mlassistant.context.epoch.EpochContext.epoch | |||
        features = self.encoder(mammo_x)
        network_output = self.main_decoder(features)
network_output_soft = network_output.softmax(dim=1) | |||
output['pixel_probs'] = network_output_soft | |||
        output_ul = F.interpolate(network_output_soft, size=(256, 256), mode='bilinear')  # NOTE: hard-coded 256x256 resolution
# Get auxiliary predictions | |||
        outputs_ul = [aux_decoder(features[-1], output_ul.detach()) for aux_decoder in self.aux_decoders]  # reuse the encoder features computed above
targets = F.softmax(output_ul.detach(), dim=1) | |||
# Compute unsupervised loss | |||
loss_unsup = sum([self.unsuper_loss(inputs=u, targets=targets, \ | |||
conf_mask=self.confidence_masking, threshold=self.confidence_th, use_softmax=False) | |||
for u in outputs_ul]) | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
loss_unsup = (loss_unsup / len(outputs_ul)) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
loss_sup = focal_loss | |||
weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter) | |||
output['loss'] = loss_sup * self.sup_loss_w + loss_unsup * weight_u | |||
output['MYloss'] = output["loss"] | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = loss_sup | |||
output['unsuploss'] = loss_unsup | |||
output['sup_loss_wloss'] = torch.tensor(self.sup_loss_w) | |||
output['lossweight_u'] = torch.tensor(weight_u) | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
self.iter += 1 | |||
return output |
@@ -0,0 +1,258 @@ | |||
import math , time | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import nn | |||
from itertools import chain | |||
import contextlib | |||
import random | |||
import numpy as np | |||
import cv2 | |||
from torch.distributions.uniform import Uniform | |||
def icnr(x, scale=2, init=nn.init.kaiming_normal_): | |||
""" | |||
Checkerboard artifact free sub-pixel convolution | |||
https://arxiv.org/abs/1707.02937 | |||
""" | |||
ni,nf,h,w = x.shape | |||
ni2 = int(ni/(scale**2)) | |||
k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1) | |||
k = k.contiguous().view(ni2, nf, -1) | |||
k = k.repeat(1, 1, scale**2) | |||
k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1) | |||
x.data.copy_(k) | |||
class PixelShuffle(nn.Module): | |||
""" | |||
Real-Time Single Image and Video Super-Resolution | |||
https://arxiv.org/abs/1609.05158 | |||
""" | |||
def __init__(self, n_channels, scale): | |||
super(PixelShuffle, self).__init__() | |||
self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1) | |||
icnr(self.conv.weight) | |||
self.shuf = nn.PixelShuffle(scale) | |||
self.relu = nn.ReLU(inplace=True) | |||
def forward(self,x): | |||
x = self.shuf(self.relu(self.conv(x))) | |||
return x | |||
def upsample(in_channels, out_channels, upscale, kernel_size=3): | |||
    # A series of x2 upsampling stages until we reach the desired upscale factor
layers = [] | |||
conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) | |||
nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu') | |||
layers.append(conv1x1) | |||
for i in range(int(math.log(upscale, 2))): | |||
layers.append(PixelShuffle(out_channels, scale=2)) | |||
return nn.Sequential(*layers) | |||
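# Sketch (illustrative, not part of the original code): upsample(512, 2, upscale=8)
# builds a 1x1 conv reducing 512 -> 2 channels, followed by log2(8) = 3 PixelShuffle
# stages, each doubling the spatial resolution.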
class MainDecoder(nn.Module): | |||
def __init__(self, upscale, conv_in_ch, num_classes): | |||
super(MainDecoder, self).__init__() | |||
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) | |||
def forward(self, x): | |||
x = self.upsample(x) | |||
return x | |||
class DropOutDecoder(nn.Module): | |||
def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True): | |||
super(DropOutDecoder, self).__init__() | |||
self.dropout = nn.Dropout2d(p=drop_rate) if spatial_dropout else nn.Dropout(drop_rate) | |||
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) | |||
def forward(self, x, _): | |||
x = self.upsample(self.dropout(x)) | |||
return x | |||
class FeatureDropDecoder(nn.Module): | |||
def __init__(self, upscale, conv_in_ch, num_classes): | |||
super(FeatureDropDecoder, self).__init__() | |||
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) | |||
def feature_dropout(self, x): | |||
attention = torch.mean(x, dim=1, keepdim=True) | |||
max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True) | |||
threshold = max_val * np.random.uniform(0.7, 0.9) | |||
threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) | |||
drop_mask = (attention < threshold).float() | |||
return x.mul(drop_mask) | |||
def forward(self, x, _): | |||
x = self.feature_dropout(x) | |||
x = self.upsample(x) | |||
return x | |||
class FeatureNoiseDecoder(nn.Module): | |||
def __init__(self, upscale, conv_in_ch, num_classes, uniform_range=0.3): | |||
super(FeatureNoiseDecoder, self).__init__() | |||
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) | |||
self.uni_dist = Uniform(-uniform_range, uniform_range) | |||
def feature_based_noise(self, x): | |||
noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0) | |||
x_noise = x.mul(noise_vector) + x | |||
return x_noise | |||
def forward(self, x, _): | |||
x = self.feature_based_noise(x) | |||
x = self.upsample(x) | |||
return x | |||
def _l2_normalize(d): | |||
# Normalizing per batch axis | |||
d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2))) | |||
d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8 | |||
return d | |||
def get_r_adv(x, decoder, it=1, xi=1e-1, eps=10.0): | |||
""" | |||
Virtual Adversarial Training | |||
https://arxiv.org/abs/1704.03976 | |||
""" | |||
x_detached = x.detach() | |||
with torch.no_grad(): | |||
pred = F.softmax(decoder(x_detached), dim=1) | |||
d = torch.rand(x.shape).sub(0.5).to(x.device) | |||
d = _l2_normalize(d) | |||
for _ in range(it): | |||
d.requires_grad_() | |||
pred_hat = decoder(x_detached + xi * d) | |||
logp_hat = F.log_softmax(pred_hat, dim=1) | |||
adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean') | |||
adv_distance.backward() | |||
d = _l2_normalize(d.grad) | |||
decoder.zero_grad() | |||
r_adv = d * eps | |||
return r_adv | |||
class VATDecoder(nn.Module): | |||
def __init__(self, upscale, conv_in_ch, num_classes, xi=1e-1, eps=10.0, iterations=1): | |||
super(VATDecoder, self).__init__() | |||
self.xi = xi | |||
self.eps = eps | |||
self.it = iterations | |||
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) | |||
def forward(self, x, _): | |||
r_adv = get_r_adv(x, self.upsample, self.it, self.xi, self.eps) | |||
x = self.upsample(x + r_adv) | |||
return x | |||
def guided_cutout(output, upscale, resize, erase=0.4, use_dropout=False): | |||
if len(output.shape) == 3: | |||
masks = (output > 0).float() | |||
else: | |||
masks = (output.argmax(1) > 0).float() | |||
if use_dropout: | |||
p_drop = random.randint(3, 6)/10 | |||
maskdroped = (F.dropout(masks, p_drop) > 0).float() | |||
maskdroped = maskdroped + (1 - masks) | |||
maskdroped.unsqueeze_(0) | |||
maskdroped = F.interpolate(maskdroped, size=resize, mode='nearest') | |||
masks_np = [] | |||
for mask in masks: | |||
mask_np = np.uint8(mask.cpu().numpy()) | |||
mask_ones = np.ones_like(mask_np) | |||
        try:  # OpenCV 3.x returns (image, contours, hierarchy)
            _, contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        except ValueError:  # OpenCV 4.x returns (contours, hierarchy)
            contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
polys = [c.reshape(c.shape[0], c.shape[-1]) for c in contours if c.shape[0] > 50] | |||
for poly in polys: | |||
min_w, max_w = poly[:, 0].min(), poly[:, 0].max() | |||
min_h, max_h = poly[:, 1].min(), poly[:, 1].max() | |||
bb_w, bb_h = max_w-min_w, max_h-min_h | |||
rnd_start_w = random.randint(0, int(bb_w*(1-erase))) | |||
rnd_start_h = random.randint(0, int(bb_h*(1-erase))) | |||
h_start, h_end = min_h+rnd_start_h, min_h+rnd_start_h+int(bb_h*erase) | |||
w_start, w_end = min_w+rnd_start_w, min_w+rnd_start_w+int(bb_w*erase) | |||
mask_ones[h_start:h_end, w_start:w_end] = 0 | |||
masks_np.append(mask_ones) | |||
masks_np = np.stack(masks_np) | |||
maskcut = torch.from_numpy(masks_np).float().unsqueeze_(1) | |||
maskcut = F.interpolate(maskcut, size=resize, mode='nearest') | |||
if use_dropout: | |||
return maskcut.to(output.device), maskdroped.to(output.device) | |||
return maskcut.to(output.device) | |||
class CutOutDecoder(nn.Module): | |||
def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True, erase=0.4): | |||
super(CutOutDecoder, self).__init__() | |||
self.erase = erase | |||
self.upscale = upscale | |||
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) | |||
def forward(self, x, pred=None): | |||
maskcut = guided_cutout(pred, upscale=self.upscale, erase=self.erase, resize=(x.size(2), x.size(3))) | |||
x = x * maskcut | |||
x = self.upsample(x) | |||
return x | |||
def guided_masking(x, output, upscale, resize, return_msk_context=True): | |||
if len(output.shape) == 3: | |||
masks_context = (output > 0).float().unsqueeze(1) | |||
else: | |||
masks_context = (output.argmax(1) > 0).float().unsqueeze(1) | |||
masks_context = F.interpolate(masks_context, size=resize, mode='nearest') | |||
x_masked_context = masks_context * x | |||
if return_msk_context: | |||
return x_masked_context | |||
masks_objects = (1 - masks_context) | |||
x_masked_objects = masks_objects * x | |||
return x_masked_objects | |||
class ContextMaskingDecoder(nn.Module): | |||
def __init__(self, upscale, conv_in_ch, num_classes): | |||
super(ContextMaskingDecoder, self).__init__() | |||
self.upscale = upscale | |||
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) | |||
def forward(self, x, pred=None): | |||
x_masked_context = guided_masking(x, pred, resize=(x.size(2), x.size(3)), | |||
upscale=self.upscale, return_msk_context=True) | |||
x_masked_context = self.upsample(x_masked_context) | |||
return x_masked_context | |||
class ObjectMaskingDecoder(nn.Module): | |||
def __init__(self, upscale, conv_in_ch, num_classes): | |||
super(ObjectMaskingDecoder, self).__init__() | |||
self.upscale = upscale | |||
self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) | |||
def forward(self, x, pred=None): | |||
x_masked_obj = guided_masking(x, pred, resize=(x.size(2), x.size(3)), | |||
upscale=self.upscale, return_msk_context=False) | |||
x_masked_obj = self.upsample(x_masked_obj) | |||
return x_masked_obj |
@@ -0,0 +1,32 @@ | |||
import numpy as np | |||
def sigmoid_rampup(current, rampup_length): | |||
if rampup_length == 0: | |||
return 1.0 | |||
current = np.clip(current, 0.0, rampup_length) | |||
phase = 1.0 - current / rampup_length | |||
return float(np.exp(-5.0 * phase * phase)) | |||
def linear_rampup(current, rampup_length): | |||
assert current >= 0 and rampup_length >= 0 | |||
if current >= rampup_length: | |||
return 1.0 | |||
return current / rampup_length | |||
def cosine_rampup(current, rampup_length): | |||
if rampup_length == 0: | |||
return 1.0 | |||
current = np.clip(current, 0.0, rampup_length) | |||
return 1 - float(.5 * (np.cos(np.pi * current / rampup_length) + 1)) | |||
def log_rampup(current, rampup_length): | |||
if rampup_length == 0: | |||
return 1.0 | |||
current = np.clip(current, 0.0, rampup_length) | |||
    return float(1 - np.exp(-5.0 * current / rampup_length))
def exp_rampup(current, rampup_length): | |||
if rampup_length == 0: | |||
return 1.0 | |||
current = np.clip(current, 0.0, rampup_length) | |||
return float(np.exp(5.0 * (current / rampup_length - 1))) |
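# Illustrative sketch (not part of the original file): every ramp maps an
# iteration count "current" in [0, rampup_length] to a weight in [0, 1].
# For example, halfway through a ramp of length 80:
#   sigmoid_rampup(40, 80) ~= 0.287   (slow start, fast finish)
#   linear_rampup(40, 80)  == 0.5
#   cosine_rampup(40, 80)  == 0.5
#   log_rampup(40, 80)     ~= 0.918   (fast start, slow finish)
#   exp_rampup(40, 80)     ~= 0.082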
@@ -0,0 +1,412 @@ | |||
Config = { | |||
"name": "CCT", | |||
"experim_name": "CCT", | |||
"n_gpu": 1, | |||
"n_labeled_examples": 1464, | |||
"diff_lrs": True, | |||
"ramp_up": 0.1, | |||
"unsupervised_w": 30, | |||
"ignore_index": 255, | |||
"lr_scheduler": "Poly", | |||
"use_weak_lables":False, | |||
"weakly_loss_w": 0.4, | |||
"pretrained": True, | |||
"model":{ | |||
"supervised": False, | |||
"semi": True, | |||
"supervised_w": 1, | |||
"sup_loss": "CE", | |||
"un_loss": "MSE", | |||
"softmax_temp": 1, | |||
"aux_constraint": False, | |||
"aux_constraint_w": 1, | |||
"confidence_masking": False, | |||
"confidence_th": 0.5, | |||
"drop": 6, | |||
"drop_rate": 0.5, | |||
"spatial": True, | |||
"cutout": 6, | |||
"erase": 0.4, | |||
"vat": 2, | |||
"xi": 1e-6, | |||
"eps": 2.0, | |||
"context_masking": 2, | |||
"object_masking": 2, | |||
"feature_drop": 6, | |||
"feature_noise": 6, | |||
"uniform_range": 0.3 | |||
}, | |||
"optimizer": { | |||
"type": "SGD", | |||
"args":{ | |||
"lr": 1e-2, | |||
"weight_decay": 1e-4, | |||
"momentum": 0.9 | |||
} | |||
}, | |||
# "train_supervised": { | |||
# "data_dir": "VOCtrainval_11-May-2012", | |||
# "batch_size": 10, | |||
# "crop_size": 320, | |||
# "shuffle": True, | |||
# "base_size": 400, | |||
# "scale": True, | |||
# "augment": True, | |||
# "flip": True, | |||
# "rotate": False, | |||
# "blur": False, | |||
# "split": "train_supervised", | |||
# "num_workers": 8 | |||
# }, | |||
# "train_unsupervised": { | |||
# "data_dir": "VOCtrainval_11-May-2012", | |||
# "weak_labels_output": "pseudo_labels/result/pseudo_labels", | |||
# "batch_size": 10, | |||
# "crop_size": 320, | |||
# "shuffle": True, | |||
# "base_size": 400, | |||
# "scale": True, | |||
# "augment": True, | |||
# "flip": True, | |||
# "rotate": False, | |||
# "blur": False, | |||
# "split": "train_unsupervised", | |||
# "num_workers": 8 | |||
# }, | |||
# "val_loader": { | |||
# "data_dir": "VOCtrainval_11-May-2012", | |||
# "batch_size": 1, | |||
# "val": True, | |||
# "split": "val", | |||
# "shuffle": False, | |||
# "num_workers": 4 | |||
# }, | |||
# "trainer": { | |||
# "epochs": 80, | |||
# "save_dir": "saved/", | |||
# "save_period": 5, | |||
# "monitor": "max Mean_IoU", | |||
# "early_stop": 10, | |||
# "tensorboardX": True, | |||
# "log_dir": "saved/", | |||
# "log_per_iter": 20, | |||
# "val": True, | |||
# "val_per_epochs": 5 | |||
# } | |||
} | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
import torch.nn as nn | |||
def sigmoid_rampup(current, rampup_length): | |||
if rampup_length == 0: | |||
return 1.0 | |||
current = np.clip(current, 0.0, rampup_length) | |||
phase = 1.0 - current / rampup_length | |||
return float(np.exp(-5.0 * phase * phase)) | |||
class consistency_weight(object): | |||
""" | |||
ramp_types = ['sigmoid_rampup', 'linear_rampup', 'cosine_rampup', 'log_rampup', 'exp_rampup'] | |||
""" | |||
def __init__(self, final_w, iters_per_epoch, rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup'): | |||
self.final_w = final_w | |||
self.iters_per_epoch = iters_per_epoch | |||
self.rampup_starts = rampup_starts * iters_per_epoch | |||
self.rampup_ends = rampup_ends * iters_per_epoch | |||
self.rampup_length = (self.rampup_ends - self.rampup_starts) | |||
self.rampup_func = sigmoid_rampup | |||
self.current_rampup = 0 | |||
def __call__(self, epoch, curr_iter): | |||
cur_total_iter = self.iters_per_epoch * epoch + curr_iter | |||
if cur_total_iter < self.rampup_starts: | |||
return 0 | |||
self.current_rampup = self.rampup_func(cur_total_iter - self.rampup_starts, self.rampup_length) | |||
return self.final_w * self.current_rampup | |||
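# Illustrative sketch (values assumed, not from the original code): with
# final_w=30, iters_per_epoch=200 and rampup_ends=20, the unsupervised weight
# ramps from ~0 toward 30 over the first 20 epochs:
#   cons_w = consistency_weight(final_w=30, iters_per_epoch=200, rampup_ends=20)
#   cons_w(epoch=10, curr_iter=0)  # ~= 30 * sigmoid_rampup(2000, 4000) ~= 8.6
#   cons_w(epoch=20, curr_iter=0)  # == 30.0 (ramp finished)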
def CE_loss(input_logits, target_targets, ignore_index, temperature=1): | |||
return F.cross_entropy(input_logits/temperature, target_targets, ignore_index=ignore_index) | |||
# for FocalLoss | |||
def softmax_helper(x): | |||
# copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py | |||
rpt = [1 for _ in range(len(x.size()))] | |||
rpt[1] = x.size(1) | |||
x_max = x.max(1, keepdim=True)[0].repeat(*rpt) | |||
e_x = torch.exp(x - x_max) | |||
return e_x / e_x.sum(1, keepdim=True).repeat(*rpt) | |||
def get_alpha(supervised_loader): | |||
# get number of classes | |||
num_labels = 0 | |||
for image_batch, label_batch in supervised_loader: | |||
label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background | |||
l_unique = torch.unique(label_batch.data) | |||
list_unique = [element.item() for element in l_unique.flatten()] | |||
num_labels = max(max(list_unique),num_labels) | |||
num_classes = num_labels + 1 | |||
# count class occurrences | |||
alpha = [0 for i in range(num_classes)] | |||
for image_batch, label_batch in supervised_loader: | |||
label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background | |||
l_unique = torch.unique(label_batch.data) | |||
list_unique = [element.item() for element in l_unique.flatten()] | |||
l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480]) | |||
list_count = [count.item() for count in l_unique_count.flatten()] | |||
for index in list_unique: | |||
alpha[index] += list_count[list_unique.index(index)] | |||
return alpha | |||
class FocalLoss(nn.Module): | |||
""" | |||
copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py | |||
    This is an implementation of Focal Loss with smooth-label cross entropy support, as proposed in
'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' | |||
Focal_Loss= -1*alpha*(1-pt)*log(pt) | |||
:param num_class: | |||
:param alpha: (tensor) 3D or 4D the scalar factor for this criterion | |||
    :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5), putting more
                                 focus on hard, misclassified examples
:param smooth: (float,double) smooth value when cross entropy | |||
:param balance_index: (int) balance class index, should be specific when alpha is float | |||
:param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. | |||
""" | |||
def __init__(self, apply_nonlin=None, ignore_index = None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): | |||
super(FocalLoss, self).__init__() | |||
        self.apply_nonlin = apply_nonlin
        self.ignore_index = ignore_index
self.alpha = alpha | |||
self.gamma = gamma | |||
self.balance_index = balance_index | |||
self.smooth = smooth | |||
self.size_average = size_average | |||
if self.smooth is not None: | |||
if self.smooth < 0 or self.smooth > 1.0: | |||
raise ValueError('smooth value should be in [0,1]') | |||
def forward(self, logit, target): | |||
if self.apply_nonlin is not None: | |||
logit = self.apply_nonlin(logit) | |||
num_class = logit.shape[1] | |||
if logit.dim() > 2: | |||
# N,C,d1,d2 -> N,C,m (m=d1*d2*...) | |||
logit = logit.view(logit.size(0), logit.size(1), -1) | |||
logit = logit.permute(0, 2, 1).contiguous() | |||
logit = logit.view(-1, logit.size(-1)) | |||
target = torch.squeeze(target, 1) | |||
target = target.view(-1, 1) | |||
valid_mask = None | |||
if self.ignore_index is not None: | |||
valid_mask = target != self.ignore_index | |||
target = target * valid_mask | |||
alpha = self.alpha | |||
if alpha is None: | |||
alpha = torch.ones(num_class, 1) | |||
elif isinstance(alpha, (list, np.ndarray)): | |||
assert len(alpha) == num_class | |||
alpha = torch.FloatTensor(alpha).view(num_class, 1) | |||
alpha = alpha / alpha.sum() | |||
alpha = 1/alpha # inverse of class frequency | |||
elif isinstance(alpha, float): | |||
alpha = torch.ones(num_class, 1) | |||
alpha = alpha * (1 - self.alpha) | |||
alpha[self.balance_index] = self.alpha | |||
else: | |||
raise TypeError('Not support alpha type') | |||
if alpha.device != logit.device: | |||
alpha = alpha.to(logit.device) | |||
idx = target.cpu().long() | |||
one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() | |||
        # map ignore-index pixels (255) to class 0 so scatter_ stays in range
        idx[idx == 255] = 0
one_hot_key = one_hot_key.scatter_(1, idx, 1) | |||
if one_hot_key.device != logit.device: | |||
one_hot_key = one_hot_key.to(logit.device) | |||
if self.smooth: | |||
one_hot_key = torch.clamp( | |||
one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth) | |||
pt = (one_hot_key * logit).sum(1) + self.smooth | |||
logpt = pt.log() | |||
gamma = self.gamma | |||
alpha = alpha[idx] | |||
alpha = torch.squeeze(alpha) | |||
loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt | |||
if valid_mask is not None: | |||
loss = loss * valid_mask.squeeze() | |||
if self.size_average: | |||
loss = loss.mean() | |||
else: | |||
loss = loss.sum() | |||
return loss | |||
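# Illustrative usage sketch (assumed call, not from the original code):
#   criterion = FocalLoss(apply_nonlin=softmax_helper, ignore_index=255, gamma=2)
#   loss = criterion(logits, target)  # logits: (B, C, H, W), target: (B, H, W) long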
class abCE_loss(nn.Module): | |||
""" | |||
Annealed-Bootstrapped cross-entropy loss | |||
""" | |||
def __init__(self, iters_per_epoch, epochs, num_classes, weight=None, | |||
reduction='mean', thresh=0.7, min_kept=1, ramp_type='log_rampup'): | |||
super(abCE_loss, self).__init__() | |||
self.weight = torch.FloatTensor(weight) if weight is not None else weight | |||
self.reduction = reduction | |||
self.thresh = thresh | |||
self.min_kept = min_kept | |||
self.ramp_type = ramp_type | |||
        if ramp_type is not None:
            # assumes a sibling `ramps` module providing the rampup functions (e.g. log_rampup)
            self.rampup_func = getattr(ramps, ramp_type)
self.iters_per_epoch = iters_per_epoch | |||
self.num_classes = num_classes | |||
self.start = 1/num_classes | |||
self.end = 0.9 | |||
self.total_num_iters = (epochs - (0.6 * epochs)) * iters_per_epoch | |||
def threshold(self, curr_iter, epoch): | |||
cur_total_iter = self.iters_per_epoch * epoch + curr_iter | |||
current_rampup = self.rampup_func(cur_total_iter, self.total_num_iters) | |||
return current_rampup * (self.end - self.start) + self.start | |||
def forward(self, predict, target, ignore_index, curr_iter, epoch): | |||
batch_kept = self.min_kept * target.size(0) | |||
prob_out = F.softmax(predict, dim=1) | |||
tmp_target = target.clone() | |||
tmp_target[tmp_target == ignore_index] = 0 | |||
prob = prob_out.gather(1, tmp_target.unsqueeze(1)) | |||
mask = target.contiguous().view(-1, ) != ignore_index | |||
sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort() | |||
if self.ramp_type is not None: | |||
thresh = self.threshold(curr_iter=curr_iter, epoch=epoch) | |||
else: | |||
thresh = self.thresh | |||
min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0 | |||
threshold = max(min_threshold, thresh) | |||
loss_matrix = F.cross_entropy(predict, target, | |||
weight=self.weight.to(predict.device) if self.weight is not None else None, | |||
ignore_index=ignore_index, reduction='none') | |||
        loss_matrix = loss_matrix.contiguous().view(-1, )
        sort_loss_matrix = loss_matrix[mask][sort_indices]
        select_loss_matrix = sort_loss_matrix[sort_prob < threshold]
if self.reduction == 'sum' or select_loss_matrix.numel() == 0: | |||
return select_loss_matrix.sum() | |||
elif self.reduction == 'mean': | |||
return select_loss_matrix.mean() | |||
else: | |||
raise NotImplementedError('Reduction Error!') | |||
def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): | |||
assert inputs.requires_grad == True and targets.requires_grad == False | |||
assert inputs.size() == targets.size() # (batch_size * num_classes * H * W) | |||
inputs = F.softmax(inputs, dim=1) | |||
if use_softmax: | |||
targets = F.softmax(targets, dim=1) | |||
if conf_mask: | |||
loss_mat = F.mse_loss(inputs, targets, reduction='none') | |||
mask = (targets.max(1)[0] > threshold) | |||
loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)] | |||
if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device) | |||
return loss_mat.mean() | |||
else: | |||
return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size | |||
def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): | |||
assert inputs.requires_grad == True and targets.requires_grad == False | |||
assert inputs.size() == targets.size() | |||
input_log_softmax = F.log_softmax(inputs, dim=1) | |||
if use_softmax: | |||
targets = F.softmax(targets, dim=1) | |||
if conf_mask: | |||
loss_mat = F.kl_div(input_log_softmax, targets, reduction='none') | |||
mask = (targets.max(1)[0] > threshold) | |||
loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)] | |||
if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device) | |||
return loss_mat.sum() / mask.shape.numel() | |||
else: | |||
return F.kl_div(input_log_softmax, targets, reduction='mean') | |||
def softmax_js_loss(inputs, targets, **_): | |||
assert inputs.requires_grad == True and targets.requires_grad == False | |||
assert inputs.size() == targets.size() | |||
epsilon = 1e-5 | |||
M = (F.softmax(inputs, dim=1) + targets) * 0.5 | |||
kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean') | |||
kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean') | |||
return (kl1 + kl2) * 0.5 | |||
def pair_wise_loss(unsup_outputs, size_average=True, nbr_of_pairs=8): | |||
""" | |||
Pair-wise loss in the sup. mat. | |||
""" | |||
if isinstance(unsup_outputs, list): | |||
unsup_outputs = torch.stack(unsup_outputs) | |||
# Only for a subset of the aux outputs to reduce computation and memory | |||
unsup_outputs = unsup_outputs[torch.randperm(unsup_outputs.size(0))] | |||
unsup_outputs = unsup_outputs[:nbr_of_pairs] | |||
temp = torch.zeros_like(unsup_outputs) # For grad purposes | |||
for i, u in enumerate(unsup_outputs): | |||
temp[i] = F.softmax(u, dim=1) | |||
mean_prediction = temp.mean(0).unsqueeze(0) # Mean over the auxiliary outputs | |||
pw_loss = ((temp - mean_prediction)**2).mean(0) # Variance | |||
pw_loss = pw_loss.sum(1) # Sum over classes | |||
if size_average: | |||
return pw_loss.mean() | |||
return pw_loss.sum() |
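# Sketch (illustrative): for K auxiliary outputs, pair_wise_loss is the per-pixel
# variance of their softmax predictions around the mean prediction, summed over
# classes; minimizing it pushes the auxiliary decoders to agree with each other.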
@@ -1,2 +1,20 @@ | |||
# Enhancing_Injury_Segmentation_in_Breast_Mammograms_Through_Semi-Supervised_Learning | |||
To reproduce the baseline tables in the thesis, go to the `base_line` folder; each method has its own subfolder.
The ResUnet folder contains implementations with several different loss functions. You can use these models to reproduce the reported results or adapt them to your own use case, and the other SSL methods are used the same way.
To train a model you only need its forward pass. All of the methods share the function prototype below, which is explained next.
```python | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) | |||
``` | |||
`mammo_x` is the input X-ray image tensor of shape `[Batch, size, size]`.
`mammo_loss_and_gt` has shape `[Batch, 2, size, size]`: `mammo_loss_and_gt[:, 1, :, :]` is the segmentation mask of the X-ray (the GT in the thesis), and `mammo_loss_and_gt[:, 0, :, :]` is the DCM mask described there (1 marks supervised pixels, 0 unsupervised).
By providing these two inputs to `forward`, you can easily use all of the models in your own projects.
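As a minimal sketch of a training step (random tensors stand in for real data, and the import path and constructor arguments are placeholders that depend on the method you pick):

```python
import torch
# hypothetical import; substitute the concrete model you want to train
# from <method_package> import RESUNET
# model = RESUNET(gamma=4)

batch, size = 2, 256
mammo_x = torch.rand(batch, size, size)                 # input X-ray images
dcm_mask = torch.randint(0, 2, (batch, 1, size, size))  # 1 = supervised pixel, 0 = unsupervised
gt_mask = torch.randint(0, 2, (batch, 1, size, size))   # 1 = lesion, 0 = background
mammo_loss_and_gt = torch.cat([dcm_mask, gt_mask], dim=1).float()

# output = model(mammo_x, mammo_loss_and_gt)
# output['loss'].backward()      # total training loss
# probs = output['pixel_probs']  # per-pixel class probabilities
```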
Good Luck! |
@@ -0,0 +1,15 @@ | |||
To reproduce the baseline tables in the thesis, go to the base_line folder; each method has its own subfolder.
The ResUnet folder contains implementations with several different loss functions. You can use these models to reproduce the reported results or adapt them to your own use case, and the other SSL methods are used the same way.
To train a model you only need its forward pass. All of the methods share the function prototype below, which is explained next.
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) | |||
mammo_x is the input X-ray image tensor of shape [Batch, size, size].
mammo_loss_and_gt has shape [Batch, 2, size, size]: mammo_loss_and_gt[:, 1, :, :] is the segmentation mask of the X-ray (the GT in the thesis), and mammo_loss_and_gt[:, 0, :, :] is the DCM mask described there (1 marks supervised pixels, 0 unsupervised).
By providing these two inputs to forward, you can easily use all of the models in your own projects.
Good Luck! |
@@ -0,0 +1,50 @@ | |||
from mlassistant.core import Model | |||
from torch import nn | |||
from mlassistant.gan.models import GAN, BaseGenerator, BaseDiscriminator | |||
from .entropy_minimization import EntropyMinimization | |||
class ConvBlock(nn.Module): | |||
def __init__(self, num_classes, ndf) -> None: | |||
super().__init__() | |||
        self.conv2d = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
self.activation = nn.LeakyReLU(negative_slope=0.2, inplace=True) | |||
def forward(self, x): | |||
x = self.conv2d(x) | |||
return self.activation(x) | |||
class Discriminator(BaseDiscriminator): # the discriminator | |||
def __init__(self, num_classes, ndf=64) -> None: | |||
super().__init__() | |||
self.net = nn.Sequential( | |||
ConvBlock(num_classes, ndf), | |||
ConvBlock(ndf, 2 * ndf), | |||
ConvBlock(2 * ndf, 4 * ndf), | |||
ConvBlock(4 * ndf, 8 * ndf), | |||
ConvBlock(8 * ndf , 1), | |||
) | |||
    def forward(self, x):
        return self.net(x)
class Generator(BaseGenerator): # segnet | |||
def __init__(self, conf): | |||
super().__init__(conf) | |||
self.net = EntropyMinimization() | |||
def forward(self, x): | |||
segmentation_net_output = self.net(x) | |||
return segmentation_net_output | |||
class AdventGan(GAN):  # the GAN wrapper
    def __init__(self, d_name: str):
        # NOTE: Generator expects a conf argument and Discriminator a num_classes argument
        super().__init__(Generator(), {d_name: Discriminator()})
@@ -0,0 +1,246 @@ | |||
import torch | |||
import torch as tf | |||
from torch.nn import functional as F | |||
import torch.nn as nn | |||
from .deeplabv2 import DeepLabV2 | |||
from .losses import * | |||
from mlassistant.core import Model | |||
import sys | |||
from ...full_segmentation.utils import sum_roi_4ch_to_1ch | |||
from torchvision.utils import draw_segmentation_masks | |||
from ....utils.mask_processes.resize_binarize_roi import resize_binarize_roi_torch | |||
from ....enums import * | |||
from .msc import MSC | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ....utils.dice import dice_loss | |||
images_path = './../../images/' | |||
import time | |||
class EntropyMinimization(Model): | |||
def train(self,mode : bool=True): | |||
super().train(mode) | |||
self.epochNUM += 1 | |||
print(self.epochNUM) | |||
def __init__(self, num_classes=2, _lambda=1.0, input_size=64 , gamma=2.0,ent_factor = 0.001, **kwargs): | |||
super().__init__() | |||
self.base_network = DeepLabV2S_ResNet101_MSC(n_classes=num_classes) | |||
self.num_classes = num_classes | |||
self._lambda = _lambda | |||
self.batch_count = 0 | |||
self.input_size = input_size | |||
self.gamma = gamma | |||
# self.activation = nn.ReLU() | |||
self.softmax = nn.Softmax(dim=1) | |||
self.epochNUM = 0 | |||
self.ent_factor = ent_factor | |||
self.time_c = time.time() | |||
def interpolate(self,shapes,arr): | |||
segmentation_label = arr | |||
output = dict() | |||
if len(segmentation_label.shape) == 2: | |||
segmentation_label = segmentation_label.unsqueeze(0) | |||
segmentation_labels = [] | |||
segmentation_dict = dict() | |||
for i, shape in enumerate(shapes): | |||
start_dim = int(len(shape) - 3) | |||
key = '_'.join([str(i) for i in shape]) | |||
if key not in segmentation_dict: | |||
segmentation = F.interpolate( | |||
segmentation_label.unsqueeze(1), | |||
size=(shape[-2], shape[-1]), | |||
mode="bilinear", | |||
align_corners=False).squeeze() | |||
if len(segmentation.shape) == 2: | |||
segmentation = segmentation.unsqueeze(0) | |||
output['org_pixel_labels_' + str(i)] = segmentation.clone() | |||
segmentation = torch.flatten(segmentation, start_dim=start_dim) | |||
segmentation_dict[key] = segmentation | |||
else: # already interpolated! | |||
segmentation = segmentation_dict[key] | |||
segmentation_labels.append(segmentation) | |||
final_segmentation_label = torch.cat(segmentation_labels, dim=start_dim) | |||
return final_segmentation_label, output | |||
    def forward(self,
                mammo_x: torch.Tensor,
                mammo_loss_and_gt: torch.Tensor):
        """Supervised focal loss on labeled pixels plus an entropy-minimization loss on unlabeled pixels."""
        output = dict()
if mammo_loss_and_gt is not None: | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type_lable = mammo_loss_and_gt[:,0,:,:] # 1 means supervised and zero means unsupervised | |||
mask = mammo_loss_and_gt[:,1,:,:] # 1 means cancer and zero means background | |||
            logits, network_output, shapes = self.base_network(mammo_x)
output['pixel_probs'] = self.softmax(logits) | |||
network_output_soft = self.softmax(network_output) | |||
loss_type, o1 = self.interpolate(shapes, loss_type_lable) | |||
mask, o2 = self.interpolate(shapes, mask) | |||
main_shape = loss_type.shape | |||
network_output_shape = network_output_soft.shape | |||
dup_loss_type = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1]) | |||
mask = mask.unsqueeze(dim = 1) | |||
dup_mask = torch.cat((1-mask, mask), 1) | |||
output = {**output, **o1, **o2} | |||
supervised_loss = torch.tensor(0.0).to(mammo_x.device) | |||
yhat = (network_output_soft * dup_loss_type).unsqueeze(-1) | |||
y = (dup_mask * dup_loss_type).unsqueeze(-1) | |||
focal_mask = dup_mask.clone().to(mammo_x.device) | |||
            focal_net_out = network_output_soft.clone().to(mammo_x.device)
focal_mask[dup_loss_type == 0] = 0.0 | |||
#focal_net_out[dup_loss_type == 0] = 0.0 | |||
focal_shape = focal_mask.shape | |||
focal_mask = torch.cat((focal_mask,torch.ones((focal_shape[0],1,focal_shape[2])).to(mammo_x.device)), 1) | |||
focal_net_out = torch.cat((focal_net_out,torch.ones((focal_shape[0],1,focal_shape[2])).to(mammo_x.device)), 1) | |||
supervised_loss += balanced_focal_cross_entropy_loss_semi( | |||
focal_net_out.unsqueeze(-1), | |||
focal_mask.unsqueeze(-1), | |||
focal_gamma=self.gamma) | |||
entropy_loss_val = entropy_loss_normalized(network_output_soft)* (1-loss_type) | |||
entropy_loss_val = torch.mean(entropy_loss_val.flatten()[((1-loss_type) == 1).flatten()]) | |||
net_out_sup = yhat.flatten()[(dup_loss_type == 1).flatten()] | |||
mask_sup = mask.flatten()[(loss_type == 1).flatten()] | |||
dice_loss_val = dice_loss((net_out_sup.reshape(2,-1).unsqueeze(dim = 0).unsqueeze(dim = -1)),mask_sup.unsqueeze(dim = 0).unsqueeze(dim = -1).to(torch.int64)) | |||
output['suploss'] = supervised_loss | |||
output['focalloss'] = supervised_loss | |||
output['dice_loss'] = dice_loss_val | |||
output["loss"] = supervised_loss + self.ent_factor * entropy_loss_val | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
output['ent_loss'] = entropy_loss_val | |||
output['ce_loss'] = F.binary_cross_entropy(yhat,y) | |||
output['ce_0_loss'] = F.binary_cross_entropy(yhat * (1-y) + y,y) # y = 1 -> y / y = 0 -> yhat | |||
output['ce_1_loss'] = F.binary_cross_entropy(yhat * y,y) # y = 0 -> y / y = 1 -> yhat | |||
#print(output) | |||
else: | |||
output['loss'] = torch.tensor(0.0) | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
output['ent_loss'] = torch.tensor(0.0) | |||
output['ce_loss'] = torch.tensor(0.0) | |||
output['ce_0_loss'] = torch.tensor(0.0) | |||
output['ce_1_loss'] = torch.tensor(0.0) | |||
output['dice_loss'] = torch.tensor(0.0) | |||
output['suploss'] = torch.tensor(0.0) | |||
output['focalloss'] = torch.tensor(0.0) | |||
#print("whole forward took =", time.time() - start, flush = True) | |||
#self.time_c = time.time() | |||
return output | |||
def DeepLabV2S_ResNet101_MSC(n_classes): | |||
return MSC( | |||
base=DeepLabV2(n_classes=n_classes, n_blocks=[3, 8, 20, 3], atrous_rates=[3, 6, 9, 12]), | |||
scales=[0.5, 0.75]) | |||
def balanced_focal_cross_entropy_localization_loss( | |||
probs:torch.Tensor, gt: torch.Tensor, focal_gamma: float = 1) -> torch.Tensor: | |||
""" | |||
Calculates class balanced cross entropy loss weighted with the style of Focal loss. | |||
mean_over_classes(1/mean(focal_weights) * (focal_weights * cross_entropy_loss)) | |||
focal_weight: (1 - p_of_the_related_class)^gamma | |||
Args: | |||
probs (torch.Tensor): Probabilities assigned by the model of shape B C ... | |||
gt (torch.Tensor): The ground truth containing the number of the class, whether of shape B ... or B C ..., make sure the ... matches in both probs and gt | |||
focal_gamma (float): The power factor used in the weight of focal loss. Default is one which means the typical balanced cross entropy. | |||
    Returns:
        torch.Tensor: The calculated loss.
""" | |||
localization_mask = (gt == 0 ).to(torch.int64) | |||
probs = probs * localization_mask # just compute loss on outside regions | |||
gt = torch.zeros(gt.shape).to(torch.int64).to(probs.device) # make int zero | |||
# if channel is one in the pixel probs, convert it to the binary mode! | |||
if probs.shape[1] == 1: | |||
probs = torch.cat([1 - probs, probs], dim=1) | |||
assert gt.max() <= probs.shape[1] - 1, f'Expected {probs.shape[1]} classes according to pixel probs\' shape but found {gt.max()} in GT mask' | |||
if len(gt.shape) == len(probs.shape): | |||
assert (gt.shape[1] == 1) or (gt.shape[1] == probs.shape[1]), f'Expected the channel dim to have either one channel or the same as probs while it is of shape {gt.shape} and probs is of shape {probs.shape}' | |||
if gt.shape[1] == 1: | |||
gt = gt[:, 0, ...] | |||
else: | |||
gt = torch.argmax(gt, dim=1) # transferring one-hot to count data | |||
assert len(gt.shape) == (len(probs.shape) - 1), f'Expected GT labels to be of shape B ... and pixel probs of shape B C ..., but received {gt.shape} and {probs.shape} instead.' | |||
# if probabilities and ground truth have different shapes, resize the ground truth | |||
if probs.shape[-2:] != gt.shape[-2:]: | |||
        # convert gt to one-hot before interpolation to prevent mistakes
gt = F.one_hot(gt.long(), probs.shape[1]) # B H W C | |||
gt = torch.permute(gt, [0, -1, 1, 2]) | |||
# binarize | |||
gt = resize_binarize_roi_torch(gt.float(), probs.shape[-2:]) | |||
# convert one-hot to numbers with argmax | |||
gt = torch.argmax(gt, dim=1) # Eliminating the channel | |||
gt = gt.long() # B H W | |||
    # flattening and bringing the class channel to index 0
probs = probs.transpose(0, 1).flatten(1) # C N | |||
gt = gt.flatten(0) # N | |||
c = probs.shape[0] | |||
gt_related_prob = probs[gt, torch.arange(gt.shape[0])] | |||
losses = [] | |||
for ci in range(c): | |||
c_probs = gt_related_prob[gt == ci] | |||
if torch.numel(c_probs) > 0: | |||
if focal_gamma == 1: | |||
losses.append(F.binary_cross_entropy( | |||
c_probs, | |||
torch.ones_like(c_probs))) | |||
else: | |||
w = torch.pow(torch.abs(1 - c_probs.detach()), focal_gamma) | |||
losses.append( | |||
(1.0 / (torch.mean(w) + 1e-4)) * | |||
F.binary_cross_entropy( | |||
c_probs, | |||
torch.ones_like(c_probs), | |||
weight=w)) | |||
return torch.mean(torch.stack(losses)) | |||
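if __name__ == "__main__":
    # Editor's sketch (not in the original file; run via `python -m` so the
    # package-relative imports resolve): toy call of the localization loss.
    # Pixels with gt > 0 are zeroed out of `probs` before the penalty; here
    # everything is background, so the loss pushes class-0 probability to 1.
    torch.manual_seed(0)
    probs = torch.softmax(torch.randn(2, 2, 8, 8), dim=1)
    gt = torch.zeros(2, 1, 8, 8, dtype=torch.int64)
    loss = balanced_focal_cross_entropy_localization_loss(probs, gt, focal_gamma=2)
    print("localization loss:", loss.item())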
@@ -0,0 +1,77 @@ | |||
#!/usr/bin/env python | |||
# coding: utf-8 | |||
# | |||
# Author: Kazuto Nakashima | |||
# URL: http://kazuto1011.github.io | |||
# Created: 2017-11-19 | |||
from __future__ import absolute_import, print_function | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .resnet import ConvBlock, PoolBlock, ResnetBlock,ResnetLayer , ResNet | |||
class AtrousSpatialPyramidPooling(nn.Module): | |||
""" | |||
    Atrous Spatial Pyramid Pooling (ASPP): parallel 3x3 dilated convolutions whose outputs are concatenated channel-wise.
""" | |||
def __init__(self, in_channels, out_channels, rates): | |||
super(AtrousSpatialPyramidPooling, self).__init__() | |||
self.conv_layers = nn.ModuleList() | |||
for i, rate in enumerate(rates): | |||
conv_layer = nn.Conv2d(in_channels, out_channels, 3, 1, padding=rate, dilation=rate, bias=True) | |||
torch.nn.init.normal_(conv_layer.weight , 0 , 0.01) | |||
torch.nn.init.normal_(conv_layer.bias , 0 , 0.01) | |||
self.conv_layers.append(conv_layer) | |||
def forward(self, x): | |||
out = torch.cat([layer(x) for layer in self.conv_layers],dim=1) | |||
return out | |||
class DeepLabV2(nn.Module): | |||
def __init__(self, n_classes, n_blocks,atrous_rates): | |||
super(DeepLabV2, self).__init__() | |||
ch = [32 * 2 ** p for p in range(6)] | |||
self.layer1= PoolBlock(ch[0]) | |||
self.layer2= ResnetLayer(n_blocks[0], ch[0], ch[2], 1, 1) | |||
self.layer3= ResnetLayer(n_blocks[1], ch[2], ch[3], 2, 1) | |||
self.layer4= ResnetLayer(n_blocks[2], ch[3], ch[4], 2, 1) | |||
self.atrous = AtrousSpatialPyramidPooling(ch[4], n_classes, atrous_rates) | |||
self.up_sample = nn.Upsample(scale_factor=2) | |||
self.conv2d_transpose1 = nn.ConvTranspose2d(in_channels=8,out_channels=4,kernel_size=3,stride=2) | |||
self.conv2d_transpose2 = nn.ConvTranspose2d(in_channels=4,out_channels=4,kernel_size=3,stride=2) | |||
self.conv2d = nn.Conv2d(in_channels=4,out_channels=2,kernel_size=3,stride=1) | |||
def forward(self, x): | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.atrous(x) | |||
x = self.up_sample(x) | |||
x = self.conv2d_transpose1(x) | |||
x = self.conv2d_transpose2(x) | |||
x = self.conv2d(x) | |||
return x | |||
def freeze_bn(self): | |||
for m in self.modules(): | |||
if isinstance(m, ConvBlock.BATCH_NORM): | |||
m.eval() | |||
# if __name__ == "__main__": | |||
# model = DeepLabV2( | |||
# n_classes=21, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24] | |||
# ) | |||
# model.eval() | |||
# image = torch.randn(1, 3, 513, 513) | |||
# print(model) | |||
# print("input:", image.shape) | |||
# print("output:", model(image).shape) |
@@ -0,0 +1,76 @@ | |||
#!/usr/bin/env python | |||
# coding: utf-8 | |||
# | |||
# Author: Kazuto Nakashima | |||
# URL: http://kazuto1011.github.io | |||
# Created: 2017-11-19 | |||
from __future__ import absolute_import, print_function | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .resnet import ConvBlock, PoolBlock, ResnetBlock,ResnetLayer , ResNet | |||
class AtrousSpatialPyramidPooling(nn.Module): | |||
""" | |||
    Atrous Spatial Pyramid Pooling (ASPP): parallel 3x3 dilated convolutions whose outputs are concatenated channel-wise.
""" | |||
def __init__(self, in_channels, out_channels, rates): | |||
super(AtrousSpatialPyramidPooling, self).__init__() | |||
self.conv_layers = nn.ModuleList() | |||
for i, rate in enumerate(rates): | |||
conv_layer = nn.Conv2d(in_channels, out_channels, 3, 1, padding=rate, dilation=rate, bias=True) | |||
torch.nn.init.normal_(conv_layer.weight , 0 , 0.01) | |||
torch.nn.init.normal_(conv_layer.bias , 0 , 0.01) | |||
self.conv_layers.append(conv_layer) | |||
def forward(self, x): | |||
out = torch.cat([layer(x) for layer in self.conv_layers],dim=1) | |||
return out | |||
class DeepLabV2(nn.Module): | |||
def __init__(self, n_classes, n_blocks,atrous_rates): | |||
super(DeepLabV2, self).__init__() | |||
ch = [32 * 2 ** p for p in range(6)] | |||
self.layer1= PoolBlock(ch[0]) | |||
self.layer2= ResnetLayer(n_blocks[0], ch[0], ch[2], 1, 1) | |||
self.layer3= ResnetLayer(n_blocks[1], ch[2], ch[3], 2, 1) | |||
self.layer4= ResnetLayer(n_blocks[2], ch[3], ch[4], 2, 1) | |||
self.atrous = AtrousSpatialPyramidPooling(ch[4], n_classes, atrous_rates) | |||
self.up_sample = nn.Upsample(scale_factor=2) | |||
        self.conv2d_transpose1 = nn.ConvTranspose2d(in_channels=8,out_channels=4,kernel_size=3,stride=2)
        self.conv2d_transpose2 = nn.ConvTranspose2d(in_channels=4,out_channels=4,kernel_size=3,stride=2)
self.conv2d = nn.Conv2d(in_channels=4,out_channels=2,kernel_size=3,stride=1) | |||
def forward(self, x): | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.atrous(x) | |||
x = self.up_sample(x) | |||
        x = self.conv2d_transpose1(x)
        x = self.conv2d_transpose2(x)
x = self.conv2d(x) | |||
return x | |||
def freeze_bn(self): | |||
for m in self.modules(): | |||
if isinstance(m, ConvBlock.BATCH_NORM): | |||
m.eval() | |||
# if __name__ == "__main__": | |||
# model = DeepLabV2( | |||
# n_classes=21, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24] | |||
# ) | |||
# model.eval() | |||
# image = torch.randn(1, 3, 513, 513) | |||
# print(model) | |||
# print("input:", image.shape) | |||
# print("output:", model(image).shape) |
@@ -0,0 +1,82 @@ | |||
#!/usr/bin/env python | |||
# coding: utf-8 | |||
# | |||
# Author: Kazuto Nakashima | |||
# URL: http://kazuto1011.github.io | |||
# Created: 2017-11-19 | |||
from __future__ import absolute_import, print_function | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch.nn import Dropout | |||
from .resnet import ConvBlock, PoolBlock, ResnetBlock,ResnetLayer , ResNet | |||
class AtrousSpatialPyramidPooling(nn.Module): | |||
""" | |||
    Atrous Spatial Pyramid Pooling (ASPP): parallel 3x3 dilated convolutions whose outputs are concatenated channel-wise.
""" | |||
def __init__(self, in_channels, out_channels, rates): | |||
super(AtrousSpatialPyramidPooling, self).__init__() | |||
self.conv_layers = nn.ModuleList() | |||
for i, rate in enumerate(rates): | |||
conv_layer = nn.Conv2d(in_channels, out_channels, 3, 1, padding=rate, dilation=rate, bias=True) | |||
torch.nn.init.normal_(conv_layer.weight , 0 , 0.01) | |||
torch.nn.init.normal_(conv_layer.bias , 0 , 0.01) | |||
self.conv_layers.append(conv_layer) | |||
def forward(self, x): | |||
out = torch.cat([layer(x) for layer in self.conv_layers],dim=1) | |||
return out | |||
class DeepLabV2(nn.Module): | |||
def __init__(self, n_classes, n_blocks,atrous_rates): | |||
super(DeepLabV2, self).__init__() | |||
ch = [32 * 2 ** p for p in range(6)] | |||
self.dropout = Dropout(0.1) | |||
self.layer1= PoolBlock(ch[0]) | |||
self.layer2= ResnetLayer(n_blocks[0], ch[0], ch[2], 1, 1) | |||
self.layer3= ResnetLayer(n_blocks[1], ch[2], ch[3], 2, 1) | |||
self.layer4= ResnetLayer(n_blocks[2], ch[3], ch[4], 2, 1) | |||
self.atrous = AtrousSpatialPyramidPooling(ch[4], n_classes, atrous_rates) | |||
self.up_sample = nn.Upsample(scale_factor=2) | |||
self.conv2d_transpose1 = nn.ConvTranspose2d(in_channels=8,out_channels=4,kernel_size=3,stride=2) | |||
self.conv2d_transpose2 = nn.ConvTranspose2d(in_channels=4,out_channels=4,kernel_size=3,stride=2) | |||
self.conv2d = nn.Conv2d(in_channels=4,out_channels=2,kernel_size=3,stride=1) | |||
def forward(self, x): | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.dropout(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.dropout(x) | |||
x = self.atrous(x) | |||
x = self.up_sample(x) | |||
x = self.dropout(x) | |||
x = self.conv2d_transpose1(x) | |||
x = self.conv2d_transpose2(x) | |||
x = self.dropout(x) | |||
x = self.conv2d(x) | |||
return x | |||
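    # Note: nn.Dropout holds no state besides p, so reusing the same instance at
    # several depths is equivalent to one module per call site, and it is
    # disabled automatically in eval() mode.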
def freeze_bn(self): | |||
for m in self.modules(): | |||
if isinstance(m, ConvBlock.BATCH_NORM): | |||
m.eval() | |||
# if __name__ == "__main__": | |||
# model = DeepLabV2( | |||
# n_classes=21, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24] | |||
# ) | |||
# model.eval() | |||
# image = torch.randn(1, 3, 513, 513) | |||
# print(model) | |||
# print("input:", image.shape) | |||
# print("output:", model(image).shape) |
@@ -0,0 +1,99 @@ | |||
import torch | |||
import torch.nn as nn | |||
from .deeplabv2 import DeepLabV2 | |||
from .losses import * | |||
from mlassistant.core import Model | |||
from ...full_segmentation.utils import sum_roi_4ch_to_1ch | |||
from torchvision.utils import draw_segmentation_masks | |||
from ....enums import * | |||
from typing import List | |||
from .msc import MSC | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss | |||
images_path = './../../images/' | |||
class EntropyMinimization(Model): | |||
def __init__(self, | |||
num_classes=2, | |||
_lambda=1.0, | |||
input_size=64, | |||
gamma=1.0, | |||
multi_head: torch.nn.Module = MSC, | |||
main_model: torch.nn.Module = DeepLabV2, | |||
atrous_rates: List[int] = [3, 6, 9, 12], | |||
**kwargs): | |||
super().__init__() | |||
self.multi_head = multi_head | |||
self.main_model = main_model | |||
self.atrous_rates = atrous_rates | |||
self.base_network = self.DeepLabV2S_ResNet101_MSC(n_classes=num_classes) | |||
self.num_classes = num_classes | |||
self._lambda = _lambda | |||
self.batch_count = 0 | |||
self.input_size = input_size | |||
self.gamma = gamma | |||
# self.activation = nn.ReLU() | |||
self.softmax = nn.Softmax(dim=1) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_roi_x_exist: torch.Tensor, | |||
mammo_roi_x: torch.Tensor = None): | |||
output = dict() | |||
logits, network_output, shapes = self.base_network(mammo_x) | |||
output['pixel_probs'] = self.softmax(logits) | |||
if mammo_roi_x is not None: | |||
mammo_roi_x_exist = mammo_roi_x_exist.bool() | |||
segmentation_label = sum_roi_4ch_to_1ch(mammo_roi_x, | |||
ROIAggregation.mass_vs_others).squeeze() | |||
print(segmentation_label.shape) | |||
if len(segmentation_label.shape) == 2: | |||
segmentation_label = segmentation_label.unsqueeze(0) | |||
segmentation_labels = [] | |||
segmentation_dict = dict() | |||
for i, shape in enumerate(shapes): | |||
start_dim = int(len(shape) - 3) | |||
key = '_'.join([str(i) for i in shape]) | |||
if key not in segmentation_dict: | |||
segmentation = F.interpolate( | |||
segmentation_label.unsqueeze(1), | |||
size=(shape[-2], shape[-1]), | |||
mode="bilinear", | |||
align_corners=False).squeeze() | |||
if len(segmentation.shape) == 2: | |||
segmentation = segmentation.unsqueeze(0) | |||
output['org_pixel_labels_' + str(i)] = segmentation.clone() | |||
segmentation = torch.flatten(segmentation, start_dim=start_dim) | |||
segmentation_dict[key] = segmentation | |||
else: # already interpolated! | |||
segmentation = segmentation_dict[key] | |||
segmentation_labels.append(segmentation) | |||
final_segmentation_label = torch.cat(segmentation_labels, dim=start_dim) | |||
# ent_loss = torch.tensor(0.0).to(mammo_x.device) | |||
segmentation_loss = torch.tensor(0.0).to(mammo_x.device) | |||
network_output = self.softmax(network_output) | |||
# unsupervised_outputs = network_output[~mammo_roi_x_exist] | |||
# if len(unsupervised_outputs) > 0: | |||
# ent_loss = ent_loss + torch.mean(entropy_loss(unsupervised_outputs)) | |||
supervised_outputs, rois = network_output[mammo_roi_x_exist], final_segmentation_label[ | |||
mammo_roi_x_exist] | |||
if len(supervised_outputs) > 0: | |||
segmentation_loss = segmentation_loss + balanced_focal_cross_entropy_loss( | |||
supervised_outputs.unsqueeze(-1), | |||
rois.to(torch.int64).unsqueeze(-1), | |||
focal_gamma=self.gamma) | |||
output['org_pixel_labels'] = output['org_pixel_labels_0'] | |||
loss = segmentation_loss # + self._lambda * torch.mean(ent_loss) | |||
output['loss'] = loss | |||
else: | |||
output['loss'] = torch.tensor(0.0) | |||
output['org_pixel_labels'] = output['pixel_probs'] | |||
return output | |||
def DeepLabV2S_ResNet101_MSC(self, n_classes): | |||
return self.multi_head( | |||
base=self.main_model( | |||
n_classes=n_classes, n_blocks=[3, 4, 23, 3], atrous_rates=self.atrous_rates), | |||
scales=[0.5, 0.75]) |
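if __name__ == "__main__":
    # Editor's sketch (run via `python -m` inside the package; assumes the local
    # DeepLabV2 variant takes single-channel input): with mammo_roi_x=None the
    # forward degenerates to pure inference, returning softmax pixel
    # probabilities and a zero loss.
    model = EntropyMinimization(num_classes=2)
    x = torch.randn(2, 1, 64, 64)
    roi_exist = torch.zeros(2)
    out = model(x, roi_exist)
    print(out['pixel_probs'].shape, out['loss'])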
@@ -0,0 +1,93 @@ | |||
import torch | |||
import torch.nn as nn | |||
from .deeplabv2 import DeepLabV2 | |||
from .losses import * | |||
from mlassistant.core import Model | |||
from ...full_segmentation.utils import sum_roi_4ch_to_1ch | |||
from torchvision.utils import draw_segmentation_masks | |||
from ....enums import * | |||
from .msc import MSC | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss | |||
images_path = './../../images/' | |||
class EntropyMinimization(Model): | |||
def __init__(self, num_classes=2, _lambda=1.0, input_size=64, **kwargs): | |||
super().__init__() | |||
self.base_network = DeepLabV2S_ResNet101_MSC(n_classes=num_classes) | |||
self.num_classes = num_classes | |||
self._lambda = _lambda | |||
self.batch_count = 0 | |||
self.input_size = input_size | |||
# self.activation = nn.ReLU() | |||
self.softmax = nn.Softmax(dim=1) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_roi_x_exist: torch.Tensor, | |||
mammo_roi_x: torch.Tensor = None): | |||
output = dict() | |||
mammo_roi_x_exist = mammo_roi_x_exist.bool() | |||
mammo_x = mammo_x.repeat(1, 3, 1, 1) | |||
segmentation_label = sum_roi_4ch_to_1ch(mammo_roi_x, ROIAggregation.mass_vs_others).squeeze() | |||
logits , network_output, shapes = self.base_network(mammo_x) | |||
# print(shapes) | |||
segmentation_labels = [] | |||
segmentation_dict = dict() | |||
for i,shape in enumerate(shapes): | |||
key = '_'.join([str(i) for i in shape]) | |||
if key not in segmentation_dict: | |||
segmentation = F.interpolate( | |||
segmentation_label.unsqueeze(1), | |||
size=(shape[-2], shape[-1]), | |||
mode="bilinear", | |||
align_corners=False).squeeze() | |||
output['org_pixel_labels_' + str(i)] = segmentation.clone() | |||
segmentation = torch.flatten(segmentation,start_dim=1) | |||
segmentation_dict[key] = segmentation | |||
else: # already interpolated! | |||
segmentation = segmentation_dict[key] | |||
segmentation_labels.append(segmentation) | |||
final_segmentation_label = torch.cat(segmentation_labels, dim=1) | |||
ent_loss = torch.tensor(0.0).to(mammo_x.device) | |||
segmentation_loss = torch.tensor(0.0).to(mammo_x.device) | |||
network_output = self.softmax(network_output) | |||
unsupervised_outputs = network_output[~mammo_roi_x_exist] | |||
if len(unsupervised_outputs) > 0: | |||
ent_loss = ent_loss + torch.mean(entropy_loss(unsupervised_outputs)) | |||
output['pixel_probs'] = self.softmax(logits) | |||
supervised_outputs, rois = network_output[mammo_roi_x_exist], final_segmentation_label[mammo_roi_x_exist] | |||
if len(supervised_outputs) > 0: | |||
segmentation_loss = segmentation_loss + balanced_focal_cross_entropy_loss( | |||
supervised_outputs.unsqueeze(-1), rois.to(torch.int64).unsqueeze(-1), focal_gamma=1) | |||
# for i in range(len(mammo_x)): | |||
# if mammo_roi_x_exist[i]: | |||
# for threshold in range(20): | |||
# binary_threshold = (threshold + 1) * 0.05 | |||
# image = (mammo_x[i] * 255).to(torch.uint8).cpu() | |||
# if len(image.shape)< 3: | |||
# image = image.unsqueeze(0) | |||
# # image | |||
# new_label = (segmentation_label[i]>0.5).bool().unsqueeze(0).cpu() | |||
# # new label | |||
# new_prediction = (output['pixel_probs'][i][1].clone().detach()>binary_threshold).bool().unsqueeze(0).cpu() | |||
# # new_pred | |||
# masks = torch.cat([new_label,new_prediction]).cpu() | |||
# colors = ["green","red"] | |||
# alpha = 0.7 | |||
# final_image = draw_segmentation_masks(image, masks, alpha, colors).transpose(0,1).transpose(1,2) | |||
# Image.fromarray(final_image.cpu().numpy()).save( | |||
# os.path.join(images_path, | |||
# str(self.batch_count) + '_' + str(i) + '_' + str(binary_threshold) + '_image_mask.png')) | |||
output['org_pixel_labels'] = output['org_pixel_labels_2'] | |||
loss = segmentation_loss + self._lambda * torch.mean(ent_loss) | |||
output['loss'] = loss | |||
return output | |||
def DeepLabV2S_ResNet101_MSC(n_classes): | |||
return MSC( | |||
base=DeepLabV2(n_classes=n_classes, n_blocks=[3, 8, 5, 3], atrous_rates=[3, 6, 9, 12]), | |||
scales=[0.5, 8.0]) |
@@ -0,0 +1,161 @@ | |||
import torch | |||
import torch.nn as nn | |||
from .deeplabv2 import DeepLabV2 | |||
from .losses import * | |||
from mlassistant.core import Model | |||
from ...full_segmentation.utils import sum_roi_4ch_to_1ch | |||
from torchvision.utils import draw_segmentation_masks | |||
from ....utils.mask_processes.resize_binarize_roi import resize_binarize_roi_torch | |||
from ....enums import * | |||
from .msc import MSC | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss | |||
images_path = './../../images/' | |||
class EntropyMinimization(Model): | |||
def __init__(self, num_classes=2, _lambda=1.0, input_size=64 , gamma=1.0, **kwargs): | |||
super().__init__() | |||
self.base_network = DeepLabV2S_ResNet101_MSC(n_classes=num_classes) | |||
self.num_classes = num_classes | |||
self._lambda = _lambda | |||
self.batch_count = 0 | |||
self.input_size = input_size | |||
self.gamma = gamma | |||
# self.activation = nn.ReLU() | |||
self.softmax = nn.Softmax(dim=1) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_roi_x_exist: torch.Tensor, | |||
mammo_roi_x: torch.Tensor = None): | |||
output = dict() | |||
logits , network_output, shapes = self.base_network(mammo_x) | |||
output['pixel_probs'] = self.softmax(logits) | |||
if mammo_roi_x is not None: | |||
mammo_roi_x_exist = mammo_roi_x_exist.bool() | |||
segmentation_label = sum_roi_4ch_to_1ch(mammo_roi_x, ROIAggregation.mass_vs_others).squeeze() | |||
if len(segmentation_label.shape) == 2: | |||
segmentation_label = segmentation_label.unsqueeze(0) | |||
segmentation_labels = [] | |||
segmentation_dict = dict() | |||
for i,shape in enumerate(shapes): | |||
start_dim = int(len(shape) - 3) | |||
key = '_'.join([str(i) for i in shape]) | |||
if key not in segmentation_dict: | |||
segmentation = F.interpolate( | |||
segmentation_label.unsqueeze(1), | |||
size=(shape[-2], shape[-1]), | |||
mode="bilinear", | |||
align_corners=False).squeeze() | |||
if len(segmentation.shape) == 2: | |||
segmentation = segmentation.unsqueeze(0) | |||
output['org_pixel_labels_' + str(i)] = segmentation.clone() | |||
segmentation = torch.flatten(segmentation,start_dim=start_dim) | |||
segmentation_dict[key] = segmentation | |||
else: # already interpolated! | |||
segmentation = segmentation_dict[key] | |||
segmentation_labels.append(segmentation) | |||
final_segmentation_label = torch.cat(segmentation_labels, dim=start_dim) | |||
# ent_loss = torch.tensor(0.0).to(mammo_x.device) | |||
segmentation_loss = torch.tensor(0.0).to(mammo_x.device) | |||
network_output = self.softmax(network_output) | |||
unsupervised_outputs = network_output[~mammo_roi_x_exist] | |||
if len(unsupervised_outputs) > 0: | |||
ent_loss = ent_loss + torch.mean(entropy_loss(unsupervised_outputs)) | |||
# apply localization losses | |||
supervised_outputs, localizations = network_output[mammo_roi_x_exist], final_segmentation_label[mammo_roi_x_exist] | |||
if len(supervised_outputs) > 0: | |||
segmentation_loss = segmentation_loss + balanced_focal_cross_entropy_loss( | |||
supervised_outputs.unsqueeze(-1), localizations.to(torch.int64).unsqueeze(-1), focal_gamma=self.gamma) | |||
output['org_pixel_labels'] = output['org_pixel_labels_0'] | |||
loss = segmentation_loss # + self._lambda * torch.mean(ent_loss) | |||
output['loss'] = loss | |||
else: | |||
output['loss'] = torch.tensor(0.0) | |||
output['org_pixel_labels'] = output['pixel_probs'] | |||
return output | |||
def DeepLabV2S_ResNet101_MSC(n_classes): | |||
return MSC( | |||
base=DeepLabV2(n_classes=n_classes, n_blocks=[3, 8, 20, 3], atrous_rates=[3, 6, 9, 12]), | |||
scales=[0.5, 0.75]) | |||
def balanced_focal_cross_entropy_localization_loss( | |||
probs:torch.Tensor, gt: torch.Tensor, focal_gamma: float = 1) -> torch.Tensor: | |||
""" | |||
Calculates class balanced cross entropy loss weighted with the style of Focal loss. | |||
mean_over_classes(1/mean(focal_weights) * (focal_weights * cross_entropy_loss)) | |||
focal_weight: (1 - p_of_the_related_class)^gamma | |||
Args: | |||
probs (torch.Tensor): Probabilities assigned by the model of shape B C ... | |||
        gt (torch.Tensor): The ground truth containing the number of the class, either of shape B ... or B C ...; make sure the trailing dims match between probs and gt
        focal_gamma (float): The power factor used in the weight of focal loss. Default is one, which means the typical balanced cross entropy.
    Returns:
        torch.Tensor: The calculated loss.
""" | |||
    localization_mask = (gt == 0).to(torch.int64)
    probs = probs * localization_mask  # only penalize regions outside the ROI
    gt = torch.zeros(gt.shape).to(torch.int64).to(probs.device)  # every remaining pixel is treated as background
# if channel is one in the pixel probs, convert it to the binary mode! | |||
if probs.shape[1] == 1: | |||
probs = torch.cat([1 - probs, probs], dim=1) | |||
assert gt.max() <= probs.shape[1] - 1, f'Expected {probs.shape[1]} classes according to pixel probs\' shape but found {gt.max()} in GT mask' | |||
if len(gt.shape) == len(probs.shape): | |||
assert (gt.shape[1] == 1) or (gt.shape[1] == probs.shape[1]), f'Expected the channel dim to have either one channel or the same as probs while it is of shape {gt.shape} and probs is of shape {probs.shape}' | |||
if gt.shape[1] == 1: | |||
gt = gt[:, 0, ...] | |||
else: | |||
gt = torch.argmax(gt, dim=1) # transferring one-hot to count data | |||
assert len(gt.shape) == (len(probs.shape) - 1), f'Expected GT labels to be of shape B ... and pixel probs of shape B C ..., but received {gt.shape} and {probs.shape} instead.' | |||
# if probabilities and ground truth have different shapes, resize the ground truth | |||
if probs.shape[-2:] != gt.shape[-2:]: | |||
        # convert gt to one-hot before interpolation to prevent mistakes
gt = F.one_hot(gt.long(), probs.shape[1]) # B H W C | |||
gt = torch.permute(gt, [0, -1, 1, 2]) | |||
# binarize | |||
gt = resize_binarize_roi_torch(gt.float(), probs.shape[-2:]) | |||
# convert one-hot to numbers with argmax | |||
gt = torch.argmax(gt, dim=1) # Eliminating the channel | |||
gt = gt.long() # B H W | |||
    # flattening and bringing the class channel to index 0
probs = probs.transpose(0, 1).flatten(1) # C N | |||
gt = gt.flatten(0) # N | |||
c = probs.shape[0] | |||
gt_related_prob = probs[gt, torch.arange(gt.shape[0])] | |||
losses = [] | |||
for ci in range(c): | |||
c_probs = gt_related_prob[gt == ci] | |||
if torch.numel(c_probs) > 0: | |||
if focal_gamma == 1: | |||
losses.append(F.binary_cross_entropy( | |||
c_probs, | |||
torch.ones_like(c_probs))) | |||
else: | |||
w = torch.pow(torch.abs(1 - c_probs.detach()), focal_gamma) | |||
losses.append( | |||
(1.0 / (torch.mean(w) + 1e-4)) * | |||
F.binary_cross_entropy( | |||
c_probs, | |||
torch.ones_like(c_probs), | |||
weight=w)) | |||
return torch.mean(torch.stack(losses)) | |||
@@ -0,0 +1,108 @@ | |||
import torch | |||
import torch.nn as nn | |||
from .deeplabv2 import DeepLabV2 | |||
from .losses import * | |||
from mlassistant.core import Model | |||
from ...full_segmentation.utils import sum_roi_4ch_to_1ch | |||
from torchvision.utils import draw_segmentation_masks | |||
from ....enums import * | |||
from typing import List | |||
from .msc import MSC | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss | |||
from ....augmentations.random_tumor_augmentor import RandomTumorAugmentor | |||
from ....augmentations.random_tumor_coloring import RandomTumorColoring | |||
images_path = './../../images/' | |||
class EntropyMinimization(Model): | |||
def __init__(self, | |||
num_classes=2, | |||
_lambda=1.0, | |||
input_size=64, | |||
gamma=1.0, | |||
multi_head: torch.nn.Module = MSC, | |||
main_model: torch.nn.Module = DeepLabV2, | |||
atrous_rates: List[int] = [3, 6, 9, 12], | |||
**kwargs): | |||
super().__init__() | |||
self.multi_head = multi_head | |||
self.main_model = main_model | |||
self.augmentor = RandomTumorAugmentor(probability=0.5) | |||
self.tumor_coloring = RandomTumorColoring(probability=0.5) | |||
self.atrous_rates = atrous_rates | |||
self.base_network = self.DeepLabV2S_ResNet101_MSC(n_classes=num_classes) | |||
self.num_classes = num_classes | |||
self._lambda = _lambda | |||
self.batch_count = 0 | |||
self.input_size = input_size | |||
self.gamma = gamma | |||
# self.activation = nn.ReLU() | |||
self.softmax = nn.Softmax(dim=1) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_roi_x_exist: torch.Tensor, | |||
mammo_y_breast_type:torch.Tensor, | |||
mammo_roi_x: torch.Tensor): | |||
output = dict() | |||
if mammo_roi_x is not None: | |||
segmentation_label = sum_roi_4ch_to_1ch(mammo_roi_x, | |||
ROIAggregation.tumor_vs_norm).squeeze() | |||
segmentation_label[segmentation_label>0] = 1 | |||
if len(segmentation_label.shape) == 2: | |||
segmentation_label = segmentation_label.unsqueeze(0) | |||
mammo_x ,segmentation_label = self.tumor_coloring.forward(mammo_x,segmentation_label) | |||
for i in range(len(mammo_x)): | |||
mammo_x[i] ,segmentation_label[i] = self.augmentor.forward(mammo_x[i], segmentation_label[i], int(mammo_y_breast_type[i])) | |||
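            # Coloring is applied batch-wise first, then geometric tumor
            # augmentation per sample; both return (image, mask) pairs so the
            # labels stay aligned with the augmented pixels.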
logits, network_output, shapes = self.base_network(mammo_x) | |||
output['pixel_probs'] = self.softmax(logits) | |||
if mammo_roi_x is not None: | |||
mammo_roi_x_exist = mammo_roi_x_exist.bool() | |||
segmentation_labels = [] | |||
segmentation_dict = dict() | |||
for i, shape in enumerate(shapes): | |||
start_dim = int(len(shape) - 3) | |||
key = '_'.join([str(i) for i in shape]) | |||
if key not in segmentation_dict: | |||
segmentation = F.interpolate( | |||
segmentation_label.unsqueeze(1), | |||
size=(shape[-2], shape[-1]), | |||
mode="bilinear", | |||
align_corners=False).squeeze() | |||
if len(segmentation.shape) == 2: | |||
segmentation = segmentation.unsqueeze(0) | |||
output['org_pixel_labels_' + str(i)] = segmentation.clone() | |||
segmentation = torch.flatten(segmentation, start_dim=start_dim) | |||
segmentation_dict[key] = segmentation | |||
else: # already interpolated! | |||
segmentation = segmentation_dict[key] | |||
segmentation_labels.append(segmentation) | |||
final_segmentation_label = torch.cat(segmentation_labels, dim=start_dim) | |||
# ent_loss = torch.tensor(0.0).to(mammo_x.device) | |||
segmentation_loss = torch.tensor(0.0).to(mammo_x.device) | |||
network_output = self.softmax(network_output) | |||
# unsupervised_outputs = network_output[~mammo_roi_x_exist] | |||
# if len(unsupervised_outputs) > 0: | |||
# ent_loss = ent_loss + torch.mean(entropy_loss(unsupervised_outputs)) | |||
supervised_outputs, rois = network_output[mammo_roi_x_exist], final_segmentation_label[ | |||
mammo_roi_x_exist] | |||
if len(supervised_outputs) > 0: | |||
segmentation_loss = segmentation_loss + balanced_focal_cross_entropy_loss( | |||
supervised_outputs.unsqueeze(-1), | |||
rois.to(torch.int64).unsqueeze(-1), | |||
focal_gamma=self.gamma) | |||
output['org_pixel_labels'] = output['org_pixel_labels_0'] | |||
loss = segmentation_loss # + self._lambda * torch.mean(ent_loss) | |||
output['loss'] = loss | |||
else: | |||
output['loss'] = torch.tensor(0.0) | |||
output['org_pixel_labels'] = output['pixel_probs'] | |||
return output | |||
def DeepLabV2S_ResNet101_MSC(self, n_classes): | |||
return self.multi_head( | |||
base=self.main_model( | |||
n_classes=n_classes, n_blocks=[3, 4, 23, 3], atrous_rates=self.atrous_rates), | |||
scales=[0.5, 0.75]) |
@@ -0,0 +1,52 @@ | |||
# this model is based on this article: ADVENT: Adversarial Entropy Minimization for Domain | |||
# Adaptation in Semantic Segmentation | |||
# https://arxiv.org/abs/1811.12833 | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
from torch.autograd import Variable | |||
def entropy_loss(v): | |||
""" | |||
Entropy loss for probabilistic prediction vectors | |||
input: batch_size x channels x hw | |||
    output: batch_size x hw (channel-summed entropy map, globally normalized by n * hw * log2(c))
""" | |||
assert v.dim() == 3 | |||
n, c, hw = v.size() | |||
denominator = n * hw * np.log2(c) | |||
return -torch.sum(torch.mul(v, torch.log2(v+1e-30)),dim=1) / denominator | |||
def entropy_loss_normalized(v): | |||
""" | |||
Entropy loss for probabilistic prediction vectors | |||
input: batch_size x channels x hw | |||
    output: batch_size x hw (channel-summed entropy map, normalized to [0, 1] by log2(c))
""" | |||
assert v.dim() == 3 | |||
n, c, hw = v.size() | |||
return -torch.sum(torch.mul(v, torch.log2(v+1e-30)),dim=1) / np.log2(c) | |||
def cross_entropy_2d(predict, target): | |||
""" | |||
Args: | |||
predict:(n, c, h, w) | |||
target:(n, h, w) | |||
""" | |||
assert not target.requires_grad | |||
assert predict.dim() == 4 | |||
assert target.dim() == 3 | |||
assert predict.size(0) == target.size(0), f"{predict.size(0)} vs {target.size(0)}" | |||
assert predict.size(2) == target.size(1), f"{predict.size(2)} vs {target.size(1)}" | |||
    assert predict.size(3) == target.size(2), f"{predict.size(3)} vs {target.size(2)}"
n, c, h, w = predict.size() | |||
target_mask = (target >= 0) * (target != 255) | |||
target = target[target_mask] | |||
if not target.data.dim(): | |||
return Variable(torch.zeros(1)) | |||
predict = predict.transpose(1, 2).transpose(2, 3).contiguous() | |||
predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) | |||
    loss = F.cross_entropy(predict, target, reduction='mean')  # size_average is deprecated
return loss |
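if __name__ == "__main__":
    # Editor's sketch: the entropy losses expect flattened probability maps of
    # shape (batch, channels, h*w); cross_entropy_2d expects raw logits and
    # integer labels, with 255 reserved as an ignore value.
    v = torch.softmax(torch.randn(2, 3, 16), dim=1)
    print(entropy_loss(v).shape)             # (2, 16), globally normalized
    print(entropy_loss_normalized(v).shape)  # (2, 16), per-pixel in [0, 1]
    predict = torch.randn(2, 3, 8, 8)
    target = torch.randint(0, 3, (2, 8, 8))
    print(cross_entropy_2d(predict, target))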
@@ -0,0 +1,47 @@ | |||
#!/usr/bin/env python | |||
# coding: utf-8 | |||
# | |||
# Author: Kazuto Nakashima | |||
# URL: http://kazuto1011.github.io | |||
# Created: 2018-03-26 | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
class MSC(nn.Module): | |||
""" | |||
Multi-scale inputs | |||
""" | |||
def __init__(self, base, scales=None): | |||
super(MSC, self).__init__() | |||
self.base = base | |||
if scales: | |||
self.scales = scales | |||
else: | |||
self.scales = [0.5, 0.75] | |||
def forward(self, x): | |||
# Original | |||
logits = self.base(x) | |||
_, _, H, W = logits.shape | |||
interp = lambda l: F.interpolate( | |||
l, size=(H, W), mode="bilinear", align_corners=False | |||
) | |||
# Scaled | |||
logits_pyramid = [] | |||
for scale in self.scales: | |||
h = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False) | |||
logits_pyramid.append(self.base(h)) | |||
# Pixel-wise max | |||
logits_all = [logits] + [interp(l) for l in logits_pyramid] | |||
logits_max = torch.max(torch.stack(logits_all), dim=0)[0] | |||
shapes = [logits.shape] + [out.shape for out in logits_pyramid] + [logits_max.shape] | |||
outputs = [logits] + logits_pyramid + [logits_max] | |||
out = torch.cat([torch.flatten(output,start_dim=2) for output in outputs],dim=2) | |||
return logits , out , shapes |
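if __name__ == "__main__":
    # Editor's sketch: MSC only assumes `base` maps B C H W -> B C' H' W', so a
    # bare conv layer is enough to demonstrate the max-fused multi-scale pyramid.
    base = nn.Conv2d(1, 2, 3, padding=1)
    msc = MSC(base=base, scales=[0.5, 0.75])
    x = torch.randn(2, 1, 32, 32)
    logits, flat, shapes = msc(x)
    print(logits.shape, flat.shape, [tuple(s) for s in shapes])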
@@ -0,0 +1,47 @@ | |||
#!/usr/bin/env python | |||
# coding: utf-8 | |||
# | |||
# Author: Kazuto Nakashima | |||
# URL: http://kazuto1011.github.io | |||
# Created: 2018-03-26 | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
class MSC(nn.Module): | |||
""" | |||
Multi-scale inputs | |||
""" | |||
def __init__(self, base, scales=None): | |||
super(MSC, self).__init__() | |||
self.base = base | |||
if scales: | |||
self.scales = scales | |||
else: | |||
self.scales = [0.5, 0.75] | |||
def forward(self, x): | |||
# Original | |||
logits = self.base(x) | |||
_, _, H, W = logits.shape | |||
interp = lambda l: F.interpolate( | |||
l, size=(H, W), mode="bilinear", align_corners=False | |||
) | |||
# Scaled | |||
logits_pyramid = [] | |||
for scale in self.scales: | |||
h = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False) | |||
logits_pyramid.append(self.base(h)) | |||
# Pixel-wise max | |||
logits_all = [logits] + [interp(l) for l in logits_pyramid] | |||
logits_max = torch.max(torch.stack(logits_all), dim=0)[0] | |||
logits_confident = torch.min(torch.stack(logits_all), dim=0)[0] | |||
shapes = [logits.shape] + [out.shape for out in logits_pyramid] + [logits_max.shape] | |||
outputs = [logits] + logits_pyramid + [logits_confident] | |||
out = torch.cat([torch.flatten(output,start_dim=2) for output in outputs],dim=2) | |||
return logits , out , shapes |
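# Design note: unlike the max-fused MSC variant, the concatenated outputs here
# use the pixel-wise *minimum* over scales ("logits_confident") in place of the
# maximum, a conservative fusion that only keeps activations supported at every
# scale.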
@@ -0,0 +1,148 @@ | |||
#!/usr/bin/env python | |||
# coding: utf-8 | |||
# | |||
# Author: Kazuto Nakashima | |||
# URL: http://kazuto1011.github.io | |||
# Created: 2017-11-19 | |||
from __future__ import absolute_import, print_function | |||
from collections import OrderedDict | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from mlassistant.core import Model | |||
try: | |||
from encoding.nn import SyncBatchNorm | |||
_BATCH_NORM = SyncBatchNorm | |||
except: | |||
_BATCH_NORM = nn.BatchNorm2d | |||
_BOTTLENECK_EXPANSION = 4 | |||
class ConvBlock(Model): | |||
""" | |||
Cascade of 2D convolution, batch norm, and ReLU. | |||
""" | |||
BATCH_NORM = _BATCH_NORM | |||
def __init__(self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True): | |||
super(ConvBlock, self).__init__() | |||
self.conv = nn.Conv2d( | |||
in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False) | |||
self.bn = _BATCH_NORM(out_ch, eps=1e-5, momentum=1 - 0.999) | |||
self.relu_activation = nn.ReLU() | |||
self.relu = relu | |||
def forward(self, x): | |||
x = self.conv(x) | |||
x = self.bn(x) | |||
if self.relu: | |||
x = self.relu_activation(x) | |||
return x | |||
class ResnetBlock(nn.Module): | |||
""" | |||
Bottleneck block of MSRA ResNet. | |||
""" | |||
def __init__(self, in_ch, out_ch, stride, dilation, downsample): | |||
super(ResnetBlock, self).__init__() | |||
mid_ch = out_ch // _BOTTLENECK_EXPANSION | |||
self.reduce = ConvBlock(in_ch, mid_ch, 1, stride, 0, 1, True) | |||
self.conv3x3 = ConvBlock(mid_ch, mid_ch, 3, 1, dilation, dilation, True) | |||
self.increase = ConvBlock(mid_ch, out_ch, 1, 1, 0, 1, False) | |||
self.shortcut = ( | |||
ConvBlock(in_ch, out_ch, 1, stride, 0, 1, False) if downsample else nn.Identity()) | |||
def forward(self, x): | |||
h = self.reduce(x) | |||
h = self.conv3x3(h) | |||
h = self.increase(h) | |||
h = h + self.shortcut(x) | |||
return F.relu(h) | |||
class ResnetLayer(nn.Module): | |||
""" | |||
Residual layer with multi grids | |||
""" | |||
def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None): | |||
super(ResnetLayer, self).__init__() | |||
self.layers = nn.ModuleList() | |||
if multi_grids is None: | |||
multi_grids = [1 for _ in range(n_layers)] | |||
else: | |||
assert n_layers == len(multi_grids) | |||
for i in range(n_layers): | |||
self.layers.append( | |||
ResnetBlock( | |||
in_ch=(in_ch if i == 0 else out_ch), | |||
out_ch=out_ch, | |||
stride=(stride if i == 0 else 1), | |||
dilation=dilation * multi_grids[i], | |||
downsample=(True if i == 0 else False))) | |||
def forward(self, x): | |||
for layer in self.layers: | |||
x = layer(x) | |||
return x | |||
class PoolBlock(nn.Module): | |||
""" | |||
The 1st conv layer. | |||
Note that the max pooling is different from both MSRA and FAIR ResNet. | |||
""" | |||
def __init__(self, out_ch): | |||
super(PoolBlock, self).__init__() | |||
self.conv1 = ConvBlock(1, out_ch, 7, 2, 3, 1) | |||
self.pool = nn.MaxPool2d(3, 2, 1, ceil_mode=True) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.pool(x) | |||
return x | |||
class ResNet(nn.Module): | |||
def __init__(self, n_classes, n_blocks): | |||
super(ResNet, self).__init__() | |||
ch = [64 * 2**p for p in range(6)] | |||
self.layer1 = PoolBlock(ch[0]) | |||
self.layer2 = ResnetLayer(n_blocks[0], ch[0], ch[2], 1, 1) | |||
self.layer3 = ResnetLayer(n_blocks[1], ch[2], ch[3], 2, 1) | |||
self.layer4 = ResnetLayer(n_blocks[2], ch[3], ch[4], 2, 1) | |||
self.pool5 = nn.AdaptiveAvgPool2d(1) | |||
self.flatten = nn.Flatten() | |||
self.fc = nn.Linear(ch[5], n_classes) | |||
def forward(self, x): | |||
x = self.layer1(x) | |||
x = self.layer2(x) | |||
x = self.layer3(x) | |||
x = self.layer4(x) | |||
x = self.pool5(x) | |||
x = self.flatten(x) | |||
x = self.fc(x) | |||
return x | |||
if __name__ == "__main__": | |||
model = ResNet(n_classes=1000, n_blocks=[3, 4, 23, 3]) | |||
model.eval() | |||
    image = torch.randn(1, 1, 224, 224)  # PoolBlock's stem conv takes a single channel
print(model) | |||
print("input:", image.shape) | |||
print("output:", model(image).shape) |
@@ -0,0 +1,79 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class ConvUnextv1(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(ConvUnextv1, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base(mammo_x) | |||
network_output = nn.LogSoftmax(dim=1)(network_output).to(mammo_x.device) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = focal_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
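if __name__ == "__main__":
    # Editor's sketch of the masking trick above, in plain torch (run via
    # `python -m` so the package imports resolve): loss_type flags supervised
    # pixels; both the one-hot mask and the softmax output get an extra
    # all-ones channel, plausibly so the semi-supervised loss can tell
    # masked-out pixels (all real channels zero) from genuine background.
    B, S = 2, 4
    loss_type = torch.tensor([1.0, 0.0]).view(B, 1, 1).expand(B, S, S)
    mask = torch.randint(0, 2, (B, 1, S, S)).float()
    probs = torch.softmax(torch.randn(B, 2, S, S), dim=1)
    mask_channels = torch.cat((1 - mask, mask), 1)
    loss_channels = loss_type.unsqueeze(1).expand(B, 2, S, S)
    dummy_mask = mask_channels.clone()
    dummy_mask[loss_channels == 0] = 0.0
    dummy_mask = torch.cat((dummy_mask, torch.ones(B, 1, S, S)), 1)
    dummy_net_out = torch.cat((probs, torch.ones(B, 1, S, S)), 1)
    print(dummy_mask.shape, dummy_net_out.shape)  # torch.Size([2, 3, 4, 4]) twice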
@@ -0,0 +1,84 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class Convnext(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(Convnext, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
print("before model") | |||
network_output = self.base._forward_features(mammo_x) # logsoftmax # B | |||
print(network_output[0].shape) | |||
print(network_output[1].shape) | |||
print(network_output[2].shape) | |||
print(network_output[3].shape) | |||
print(network_output.shape) | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = focal_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
@@ -0,0 +1,80 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class DEEPVLAB3(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(DEEPVLAB3, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
        network_output = self.base(mammo_x.expand(mammo_x.shape[0],3,mammo_x.shape[2],mammo_x.shape[3]).to(mammo_x.device))["out"]  # raw logits from the torchvision-style head
network_output = nn.LogSoftmax(dim=1)(network_output).to(mammo_x.device) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = focal_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
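if __name__ == "__main__":
    # Editor's sketch (assumptions: torchvision >= 0.13, and `python -m` so the
    # package-relative loss imports resolve; deeplabv3_resnet50 is just one
    # torchvision head returning a dict with an "out" logit map, not
    # necessarily the base used originally).
    from torchvision.models.segmentation import deeplabv3_resnet50
    base = deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=2)
    model = DEEPVLAB3(base, gamma=4)
    x = torch.randn(2, 1, 64, 64)
    loss_and_gt = torch.stack(
        (torch.ones(2, 64, 64),                      # loss_type: all supervised
         torch.randint(0, 2, (2, 64, 64)).float()),  # binary ground-truth mask
        dim=1)
    out = model(x, loss_and_gt)
    print(out['loss'].item(), out['pixel_probs'].shape)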
@@ -0,0 +1,180 @@ | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from torch.distributions.uniform import Uniform | |||
def kaiming_normal_init_weight(model): | |||
for m in model.modules(): | |||
if isinstance(m, nn.Conv3d): | |||
torch.nn.init.kaiming_normal_(m.weight) | |||
elif isinstance(m, nn.BatchNorm3d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
return model | |||
def sparse_init_weight(model): | |||
for m in model.modules(): | |||
if isinstance(m, nn.Conv3d): | |||
torch.nn.init.sparse_(m.weight, sparsity=0.1) | |||
elif isinstance(m, nn.BatchNorm3d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
return model | |||
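# Note: these initializers match nn.Conv3d / nn.BatchNorm3d only, so they are a
# no-op on the 2D blocks below; presumably carried over from a 3D codebase. Swap
# in nn.Conv2d / nn.BatchNorm2d if 2D (re)initialization is intended.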
class ConvBlock(nn.Module): | |||
"""two convolution layers with batch norm and leaky relu""" | |||
def __init__(self, in_channels, out_channels, dropout_p): | |||
super(ConvBlock, self).__init__() | |||
self.conv_conv = nn.Sequential( | |||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(out_channels), | |||
nn.LeakyReLU(), | |||
nn.Dropout(dropout_p), | |||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(out_channels), | |||
nn.LeakyReLU() | |||
) | |||
def forward(self, x): | |||
return self.conv_conv(x) | |||
class DownBlock(nn.Module): | |||
"""Downsampling followed by ConvBlock""" | |||
def __init__(self, in_channels, out_channels, dropout_p): | |||
super(DownBlock, self).__init__() | |||
self._conv = ConvBlock(in_channels, out_channels, dropout_p) | |||
self.maxpool = nn.MaxPool2d(2) | |||
def forward(self, x, resx , res = True): | |||
if res: | |||
return self.maxpool(self._conv(x) + resx) | |||
return self.maxpool(self._conv(x)) | |||
class UpBlock(nn.Module): | |||
"""Upssampling followed by ConvBlock""" | |||
def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, | |||
bilinear=True): | |||
super(UpBlock, self).__init__() | |||
self.bilinear = bilinear | |||
if bilinear: | |||
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) | |||
self.up = nn.Upsample( | |||
scale_factor=2, mode='bilinear', align_corners=True) | |||
else: | |||
self.up = nn.ConvTranspose2d( | |||
in_channels1, in_channels2, kernel_size=2, stride=2) | |||
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) | |||
self.res = nn.Conv2d(in_channels2 * 2, out_channels,1) | |||
def forward(self, x1, x2): | |||
if self.bilinear: | |||
x1 = self.conv1x1(x1) | |||
x1 = self.up(x1) | |||
x = torch.cat([x2, x1], dim=1) | |||
        return self.conv(x) + self.res(x)  # for residual blocks
class Encoder(nn.Module): | |||
def __init__(self, params): | |||
super(Encoder, self).__init__() | |||
self.params = params | |||
self.in_chns = self.params['in_chns'] | |||
self.ft_chns = self.params['feature_chns'] | |||
self.n_class = self.params['class_num'] | |||
self.bilinear = self.params['bilinear'] | |||
self.dropout = self.params['dropout'] | |||
assert (len(self.ft_chns) == 5) | |||
self.in_conv = ConvBlock( | |||
self.in_chns, self.ft_chns[0], self.dropout[0]) | |||
self.down1 = DownBlock( | |||
self.ft_chns[0], self.ft_chns[1], self.dropout[1]) | |||
self.down2 = DownBlock( | |||
self.ft_chns[1], self.ft_chns[2], self.dropout[2]) | |||
self.down3 = DownBlock( | |||
self.ft_chns[2], self.ft_chns[3], self.dropout[3]) | |||
self.down4 = DownBlock( | |||
self.ft_chns[3], self.ft_chns[4], self.dropout[4]) | |||
self.down_res = nn.ModuleList([nn.Conv2d(self.in_chns,self.ft_chns[0],1), | |||
nn.Conv2d(self.ft_chns[0],self.ft_chns[1],1), | |||
nn.Conv2d(self.ft_chns[1],self.ft_chns[2],1), | |||
nn.Conv2d(self.ft_chns[2],self.ft_chns[3],1)]) | |||
def forward(self, x): | |||
x0 = self.in_conv(x) + self.down_res[0](x) | |||
x1 = self.down1(x0,self.down_res[1](x0)) | |||
x2 = self.down2(x1,self.down_res[2](x1)) | |||
x3 = self.down3(x2,self.down_res[3](x2)) | |||
x4 = self.down4(x3,None,False) | |||
return [x0, x1, x2, x3, x4] | |||
class Decoder(nn.Module): | |||
def __init__(self, params): | |||
super(Decoder, self).__init__() | |||
self.params = params | |||
self.in_chns = self.params['in_chns'] | |||
self.ft_chns = self.params['feature_chns'] | |||
self.n_class = self.params['class_num'] | |||
self.bilinear = self.params['bilinear'] | |||
assert (len(self.ft_chns) == 5) | |||
self.up1 = UpBlock( | |||
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) | |||
self.up2 = UpBlock( | |||
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) | |||
self.up3 = UpBlock( | |||
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) | |||
self.up4 = UpBlock( | |||
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) | |||
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, | |||
kernel_size=3, padding=1) | |||
def forward(self, feature): | |||
x0 = feature[0] | |||
x1 = feature[1] | |||
x2 = feature[2] | |||
x3 = feature[3] | |||
x4 = feature[4] | |||
xup0 = self.up1(x4, x3) | |||
xup1 = self.up2(xup0, x2) | |||
xup2 = self.up3(xup1, x1) | |||
xup3 = self.up4(xup2, x0) | |||
output = self.out_conv(xup3) | |||
return output | |||
class RESUNet(nn.Module): | |||
def __init__(self, in_chns, class_num): | |||
super(RESUNet, self).__init__() | |||
params = {'in_chns': in_chns, | |||
'feature_chns': [64, 128, 128, 256, 512], | |||
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], | |||
'class_num': class_num, | |||
'bilinear': False, | |||
'acti_func': 'relu'} | |||
self.encoder = Encoder(params) | |||
self.decoder = Decoder(params) | |||
def forward(self, x, need_fp=False): | |||
feature = self.encoder(x) | |||
if need_fp: | |||
outs = self.decoder([torch.cat((feat, nn.Dropout2d(0.5)(feat))) for feat in feature]) | |||
return outs.chunk(2) | |||
output = self.decoder(feature) | |||
return output |
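if __name__ == "__main__":
    # Editor's sketch: RESUNet is torch-only, so a smoke test runs directly.
    # need_fp=True batches Dropout2d-perturbed copies of the encoder features
    # through the decoder and returns (clean, perturbed) logits, a feature-
    # perturbation trick used in consistency training.
    model = RESUNet(in_chns=1, class_num=2)
    x = torch.randn(2, 1, 64, 64)
    print(model(x).shape)           # torch.Size([2, 2, 64, 64])
    out, out_fp = model(x, need_fp=True)
    print(out.shape, out_fp.shape)  # torch.Size([2, 2, 64, 64]) twice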
@@ -0,0 +1,50 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss | |||
from ......utils.generalized_dice import dice_loss | |||
from ......utils.losses.tverskyLoss import tversky_loss | |||
from ......utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from ......utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_roi_x: torch.Tensor) -> ModelIO: | |||
output = dict(loss=0.0) | |||
network_output = self.base.forward(mammo_x).softmax(dim=1) | |||
network_output_soft = network_output | |||
output['pixel_probs'] = network_output_soft | |||
focal_loss = balanced_focal_cross_entropy_loss( | |||
network_output_soft, | |||
mammo_roi_x, | |||
focal_gamma=self.gamma) | |||
output['loss'] = focal_loss | |||
output['org_pixel_labels'] = mammo_roi_x | |||
return output |
@@ -0,0 +1,79 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ......utils.generalized_dice import dice_loss | |||
from ......utils.losses.tverskyLoss import tversky_loss | |||
from ......utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from ......utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
        network_output = self.base.forward(mammo_x).softmax(dim=1)  # per-pixel class probabilities, shape (B, C, H, W)
        network_output_soft = network_output
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = focal_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
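# --- illustrative sketch (not part of the original file) ---
# A minimal, self-contained demo of the loss-masking pattern used above:
# channel 0 of `mammo_loss_and_gt` marks which pixels are labeled, channel 1 is
# the ground truth. Labeled pixels keep their one-hot mask; unlabeled pixels are
# zeroed in both channels, and an all-ones channel is appended, presumably so the
# *_semi losses can tell "unlabeled" apart from "labeled background".
# Shapes and values here are toy assumptions, not project data.
import torch

B, H, W = 2, 4, 4
loss_type = torch.randint(0, 2, (B, H, W)).float()   # 1 = supervised pixel
mask = torch.randint(0, 2, (B, H, W)).float()        # 1 = cancer
probs = torch.softmax(torch.randn(B, 2, H, W), dim=1)

loss_channels = loss_type.unsqueeze(1).expand(B, 2, H, W)
mask_channels = torch.cat((1 - mask.unsqueeze(1), mask.unsqueeze(1)), dim=1)

dummy_mask = mask_channels.clone()
dummy_mask[loss_channels == 0] = 0.0                 # hide unlabeled pixels
ones = torch.ones(B, 1, H, W)
dummy_mask = torch.cat((dummy_mask, ones), dim=1)    # (B, 3, H, W)
dummy_net_out = torch.cat((probs.clone(), ones), dim=1)

# unlabeled pixels are exactly those whose first two mask channels are both zero
assert torch.equal((dummy_mask[:, :2].sum(1) == 0).float(), 1 - loss_type)
# --- end of sketch ---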
@@ -0,0 +1,107 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ......utils.generalized_dice import dice_loss | |||
from ......utils.losses.ent_losses import entropy_loss_normalized | |||
from ......utils.losses.tverskyLoss import tversky_loss | |||
from ......utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from ......utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 2 , smooth=1, alpha=0.7, beta=0.3 , ent_factor = lambda x : 0.1) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
self.ent_factor = ent_factor | |||
print(self.gamma) | |||
self.epochNUM = 1 | |||
print(self.epochNUM) | |||
def train(self,mode : bool=True): | |||
super().train(mode) | |||
self.epochNUM += 1 | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
        network_output = self.base.forward(mammo_x).softmax(dim=1)  # per-pixel class probabilities, shape (B, C, H, W)
        network_output_soft = network_output
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
# dummy_yhat = network_output_soft * loss_channels | |||
# dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
#dice = dice_loss(dummy_net_out,dummy_mask) | |||
#tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
#ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
entropy_loss_val = entropy_loss_normalized(network_output_soft) | |||
entropy_suploss_val = torch.mean(entropy_loss_val * loss_type) | |||
entropy_unsuploss_val = torch.mean(entropy_loss_val * (1-loss_type)) | |||
sup_ratio = torch.sum(loss_type == 1) | |||
all_cel = torch.prod(torch.tensor(loss_type.shape)).to(mammo_x.device) | |||
        entropy_sup = entropy_suploss_val * (all_cel / (sup_ratio + 1e-30))
        entropy_unsup = entropy_unsuploss_val * (all_cel / (all_cel - sup_ratio + 1e-30))
entropy_loss_val = torch.mean(entropy_loss_val) | |||
        entloss = torch.mean(torch.stack([entropy_sup, entropy_unsup])).to(mammo_x.device)  # stack, not torch.tensor: keeps the autograd graph
coef = self.ent_factor(int(self.epochNUM/2)) | |||
#output['ce_loss'] = ce_loss | |||
output['loss'] = focal_loss + coef * entropy_loss_val | |||
output['totalloss'] = output['loss'] | |||
#output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = focal_loss | |||
output['unsuploss'] = entropy_loss_val | |||
output['coef_loss'] = torch.tensor(coef).to(mammo_x.device) | |||
#output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
output["entallloss"] = entropy_loss_val | |||
output["supentloss"] = entropy_sup | |||
output["unsupentloss"] = entropy_unsup | |||
output["entloss"] = entloss | |||
return output |
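# --- illustrative sketch (not part of the original file) ---
# `ent_factor` above only needs to be a callable mapping an epoch index to a
# coefficient. A hypothetical ramp-up schedule (my assumption, not the
# project's `dynamic_num_generator`) could look like this:
import math

def make_entropy_ramp(max_coef: float = 0.1, ramp_epochs: int = 40):
    """Return a callable that ramps the entropy coefficient from ~0 to max_coef
    with the sigmoid warm-up commonly used for consistency losses."""
    def ent_factor(epoch: int) -> float:
        t = min(epoch, ramp_epochs) / ramp_epochs
        return max_coef * math.exp(-5.0 * (1.0 - t) ** 2)
    return ent_factor

# usage matching the constructor above:
# model = RESUNET(base_model, ent_factor=make_entropy_ramp())
# --- end of sketch ---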
@@ -0,0 +1,96 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ......utils.generalized_dice import dice_loss | |||
from ......utils.losses.tverskyLoss import tversky_loss | |||
from ......utils.losses.region_topk_focal_ce import calculate_topk_focal_ce_loss_per_region as topk | |||
from ......utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from ......utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3, gskernel = None) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
        self.gskernel = gskernel  # schedule callable mapping epoch -> blending coefficient in [0, 1]; must not be None when training
self.epochNUM = 1 | |||
print(self.gamma) | |||
print("gamma") | |||
def train(self,mode : bool=True): | |||
super().train(mode) | |||
self.epochNUM += 1 | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
        network_output = self.base.forward(mammo_x).softmax(dim=1)  # per-pixel class probabilities, shape (B, C, H, W)
        network_output_soft = network_output
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_yhat,dummy_y) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
topk_loss = topk(model_preds=dummy_yhat,roi_mask = dummy_y[:,0,...], fill_ratio=0.5, focal_gamma=self.gamma) | |||
coef = self.gskernel(int(self.epochNUM/2)) | |||
output['ce_loss'] = ce_loss | |||
output['topk_loss'] = topk_loss | |||
output['gscoef_loss'] = torch.tensor(coef).to(mammo_x.device) | |||
output['loss'] = coef * topk_loss + (1-coef) * focal_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
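# --- illustrative sketch (not part of the original file) ---
# `gskernel` above is called as `self.gskernel(int(self.epochNUM / 2))` and its
# result blends the two losses: coef * topk_loss + (1 - coef) * focal_loss.
# A hypothetical Gaussian-shaped schedule (my assumption; the project's actual
# kernel is defined elsewhere) that peaks at a chosen epoch:
import math

def make_gaussian_schedule(peak_epoch: int = 50, width: float = 20.0):
    """Return a callable giving a coefficient in (0, 1] that is largest near
    peak_epoch and decays smoothly on both sides."""
    def gskernel(epoch: int) -> float:
        return math.exp(-((epoch - peak_epoch) ** 2) / (2.0 * width ** 2))
    return gskernel

# usage matching the constructor above:
# model = RESUNET(base_model, gskernel=make_gaussian_schedule())
# --- end of sketch ---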
@@ -0,0 +1,80 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 2 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
        output['ce_1_loss'] = F.binary_cross_entropy(dummy_yhat * dummy_y, dummy_y)  # BCE restricted to positive targets (negatives contribute zero)
        output['ce_0_loss'] = F.binary_cross_entropy(dummy_yhat * (1 - dummy_y) + dummy_y, dummy_y)  # BCE restricted to negative targets (positives forced correct)
output['ce_loss'] = ce_loss | |||
output['loss'] = ce_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
@@ -0,0 +1,80 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 2 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_1_loss'] = F.binary_cross_entropy(dummy_yhat * dummy_y,dummy_y) | |||
output['ce_0_loss'] = F.binary_cross_entropy(dummy_yhat * (1 - dummy_y) + dummy_y, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = output['ce_1_loss'] + output['ce_0_loss'] | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
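# --- illustrative sketch (not part of the original file) ---
# Why the ce_1_loss/ce_0_loss split above works: multiplying the prediction by
# y zeroes it wherever y == 0, so those pixels contribute -log(1 - 0) = 0 and
# only positive-target pixels carry error; conversely, yhat * (1 - y) + y
# equals 1 wherever y == 1, silencing positive targets.
# A tiny numeric check (toy tensors, my construction):
import torch
import torch.nn.functional as F

y = torch.tensor([0., 0., 1., 1.])
yhat = torch.tensor([0.2, 0.9, 0.3, 0.8])

ce_1 = F.binary_cross_entropy(yhat * y, y)             # only indices with y == 1 matter
ce_0 = F.binary_cross_entropy(yhat * (1 - y) + y, y)   # only indices with y == 0 matter

manual_1 = -torch.log(yhat[y == 1]).mean() / 2         # /2: BCE averages over all 4 elements
manual_0 = -torch.log(1 - yhat[y == 0]).mean() / 2
assert torch.allclose(ce_1, manual_1) and torch.allclose(ce_0, manual_0)
# --- end of sketch ---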
@@ -0,0 +1,80 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 2 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_1_loss'] = F.binary_cross_entropy(dummy_yhat * dummy_y,dummy_y) | |||
output['ce_0_loss'] = F.binary_cross_entropy(dummy_yhat * (1 - dummy_y) + dummy_y, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = dice | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = dice | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
@@ -0,0 +1,79 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = focal_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
@@ -0,0 +1,107 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.ent_losses import entropy_loss_normalized | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 2 , smooth=1, alpha=0.7, beta=0.3 , ent_factor = lambda x : 0.1) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
self.ent_factor = ent_factor | |||
print(self.gamma) | |||
self.epochNUM = 200 | |||
print(self.epochNUM) | |||
def train(self,mode : bool=True): | |||
super().train(mode) | |||
self.epochNUM += 1 | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
entropy_loss_val = entropy_loss_normalized(network_output_soft) | |||
entropy_suploss_val = torch.mean(entropy_loss_val * loss_type) | |||
entropy_unsuploss_val = torch.mean(entropy_loss_val * (1-loss_type)) | |||
sup_ratio = torch.sum(loss_type == 1) | |||
all_cel = torch.prod(torch.tensor(loss_type.shape)).to(mammo_x.device) | |||
        entropy_sup = entropy_suploss_val * (all_cel / (sup_ratio + 1e-30))
        entropy_unsup = entropy_unsuploss_val * (all_cel / (all_cel - sup_ratio + 1e-30))
entropy_loss_val = torch.mean(entropy_loss_val) | |||
        entloss = torch.mean(torch.stack([entropy_sup, entropy_unsup])).to(mammo_x.device)  # stack, not torch.tensor: keeps the autograd graph
coef = self.ent_factor(int(self.epochNUM/2)) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = focal_loss + coef * entropy_loss_val | |||
output['totalloss'] = output['loss'] | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = focal_loss | |||
output['unsuploss'] = entropy_loss_val | |||
output['coef_loss'] = torch.tensor(coef).to(mammo_x.device) | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
output["entallloss"] = entropy_loss_val | |||
output["supentloss"] = entropy_sup | |||
output["unsupentloss"] = entropy_unsup | |||
output["entloss"] = entloss | |||
return output |
@@ -0,0 +1,94 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi_pixel_wise | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
t = -2 , smooth=1, alpha=0.7, beta=0.3 , gamma = 2) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.t = t | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.t , gamma) | |||
self.gamma = gamma | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma).to(mammo_x.device) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
ce_per,losses = balanced_focal_cross_entropy_loss_semi_pixel_wise( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
ce_per = ce_per.to(mammo_x.device) | |||
        # tilted aggregation: (1/t) * log E[exp(t * loss)]; t < 0 favors low per-pixel losses, t > 0 high ones
        tilted = 1/self.t * torch.log(torch.mean(torch.exp(self.t * ce_per))).to(mammo_x.device)
        til0 = 1/self.t * torch.log(torch.mean(torch.exp(self.t * losses[0].to(mammo_x.device)))).to(mammo_x.device) if len(losses) > 0 else torch.tensor(0.0)
        til1 = 1/self.t * torch.log(torch.mean(torch.exp(self.t * losses[1].to(mammo_x.device)))).to(mammo_x.device) if len(losses) > 1 else torch.tensor(0.0)
        til1, til0 = til1.to(mammo_x.device), til0.to(mammo_x.device)
output['ce_loss'] = ce_loss | |||
output["tilted_loss"] = tilted | |||
output['loss'] = torch.mean(torch.stack([til0,til1])).to(mammo_x.device) + focal_loss | |||
output["til0_loss"] = til0 | |||
output["til1_loss"] = til1 | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
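# --- illustrative sketch (not part of the original file) ---
# Behavior of the tilted aggregation used above, (1/t) * log(mean(exp(t * x))):
# as t -> 0 it approaches the plain mean; t > 0 leans toward the worst (highest)
# per-pixel losses, t < 0 (e.g. the default t = -2) toward the best ones.
# Toy numbers, my construction:
import torch

def tilted_mean(x: torch.Tensor, t: float) -> torch.Tensor:
    return (1.0 / t) * torch.log(torch.mean(torch.exp(t * x)))

losses = torch.tensor([0.1, 0.1, 0.1, 2.0])   # one hard pixel
print(losses.mean().item())                    # 0.575
print(tilted_mean(losses, t=+2.0).item())      # > mean: the hard pixel dominates
print(tilted_mean(losses, t=-2.0).item())      # < mean: the easy pixels dominate
# --- end of sketch ---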
@@ -0,0 +1,85 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.losses.region_topk_focal_ce import calculate_topk_focal_ce_loss_per_region as topk | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_yhat,dummy_y) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
topk_loss = topk(model_preds=dummy_yhat,roi_mask = dummy_y[:,0,...], fill_ratio=0.5, focal_gamma=self.gamma) | |||
output['ce_loss'] = ce_loss | |||
output['topk_loss'] = topk_loss | |||
output['loss'] = topk_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
@@ -0,0 +1,87 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.losses.region_topk_focal_ce import calculate_topk_focal_ce_loss_per_region as topk | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_yhat,dummy_y) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
topk_loss = topk(model_preds=dummy_yhat,roi_mask = dummy_y[:,0,...], fill_ratio=0.5, focal_gamma=self.gamma) | |||
output['ce_loss'] = ce_loss | |||
output['topk_loss'] = topk_loss | |||
output['loss'] = topk_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
@@ -0,0 +1,85 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.losses.focal_tversky import focal_tversky | |||
from .....utils.losses.region_topk_focal_ce import calculate_topk_focal_ce_loss_per_region as topk | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
# focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
# dummy_net_out, | |||
# dummy_mask, | |||
# focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_yhat,dummy_y) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
#topk_loss = topk(model_preds=dummy_yhat,roi_mask = dummy_y[:,0,...], fill_ratio=0.5, focal_gamma=self.gamma) | |||
tvfocal_loss = focal_tversky(dummy_y[:,1,...].flatten(),dummy_yhat[:,1,...].flatten()) | |||
output['ce_loss'] = ce_loss | |||
output['tvfocal_loss'] = tvfocal_loss | |||
output['loss'] = tvfocal_loss | |||
output['diceloss'] = dice | |||
#output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
@@ -0,0 +1,174 @@ | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from torch.distributions.uniform import Uniform | |||
def kaiming_normal_init_weight(model): | |||
for m in model.modules(): | |||
if isinstance(m, nn.Conv3d): | |||
torch.nn.init.kaiming_normal_(m.weight) | |||
elif isinstance(m, nn.BatchNorm3d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
return model | |||
def sparse_init_weight(model): | |||
for m in model.modules(): | |||
if isinstance(m, nn.Conv3d): | |||
torch.nn.init.sparse_(m.weight, sparsity=0.1) | |||
elif isinstance(m, nn.BatchNorm3d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
return model | |||
class ConvBlock(nn.Module): | |||
"""two convolution layers with batch norm and leaky relu""" | |||
def __init__(self, in_channels, out_channels, dropout_p): | |||
super(ConvBlock, self).__init__() | |||
self.conv_conv = nn.Sequential( | |||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(out_channels), | |||
nn.LeakyReLU(), | |||
nn.Dropout(dropout_p), | |||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(out_channels), | |||
nn.LeakyReLU() | |||
) | |||
def forward(self, x): | |||
return self.conv_conv(x) | |||
class DownBlock(nn.Module): | |||
"""Downsampling followed by ConvBlock""" | |||
def __init__(self, in_channels, out_channels, dropout_p): | |||
super(DownBlock, self).__init__() | |||
self.maxpool_conv = nn.Sequential( | |||
nn.MaxPool2d(2), | |||
ConvBlock(in_channels, out_channels, dropout_p) | |||
) | |||
def forward(self, x): | |||
return self.maxpool_conv(x) | |||
class UpBlock(nn.Module): | |||
"""Upssampling followed by ConvBlock""" | |||
def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, | |||
bilinear=True): | |||
super(UpBlock, self).__init__() | |||
self.bilinear = bilinear | |||
if bilinear: | |||
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) | |||
self.up = nn.Upsample( | |||
scale_factor=2, mode='bilinear', align_corners=True) | |||
else: | |||
self.up = nn.ConvTranspose2d( | |||
in_channels1, in_channels2, kernel_size=2, stride=2) | |||
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) | |||
def forward(self, x1, x2): | |||
if self.bilinear: | |||
x1 = self.conv1x1(x1) | |||
x1 = self.up(x1) | |||
x = torch.cat([x2, x1], dim=1) | |||
return self.conv(x) | |||
class Encoder(nn.Module): | |||
def __init__(self, params): | |||
super(Encoder, self).__init__() | |||
self.params = params | |||
self.in_chns = self.params['in_chns'] | |||
self.ft_chns = self.params['feature_chns'] | |||
self.n_class = self.params['class_num'] | |||
self.bilinear = self.params['bilinear'] | |||
self.dropout = self.params['dropout'] | |||
assert (len(self.ft_chns) == 5) | |||
self.in_conv = ConvBlock( | |||
self.in_chns, self.ft_chns[0], self.dropout[0]) | |||
self.down1 = DownBlock( | |||
self.ft_chns[0], self.ft_chns[1], self.dropout[1]) | |||
self.down2 = DownBlock( | |||
self.ft_chns[1], self.ft_chns[2], self.dropout[2]) | |||
self.down3 = DownBlock( | |||
self.ft_chns[2], self.ft_chns[3], self.dropout[3]) | |||
self.down4 = DownBlock( | |||
self.ft_chns[3], self.ft_chns[4], self.dropout[4]) | |||
def forward(self, x): | |||
x0 = self.in_conv(x) | |||
x1 = self.down1(x0) | |||
x2 = self.down2(x1) | |||
x3 = self.down3(x2) | |||
x4 = self.down4(x3) | |||
return [x0, x1, x2, x3, x4] | |||
class Decoder(nn.Module): | |||
def __init__(self, params): | |||
super(Decoder, self).__init__() | |||
self.params = params | |||
self.in_chns = self.params['in_chns'] | |||
self.ft_chns = self.params['feature_chns'] | |||
self.n_class = self.params['class_num'] | |||
self.bilinear = self.params['bilinear'] | |||
assert (len(self.ft_chns) == 5) | |||
self.up1 = UpBlock( | |||
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) | |||
self.up2 = UpBlock( | |||
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) | |||
self.up3 = UpBlock( | |||
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) | |||
self.up4 = UpBlock( | |||
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) | |||
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, | |||
kernel_size=3, padding=1) | |||
def forward(self, feature): | |||
x0 = feature[0] | |||
x1 = feature[1] | |||
x2 = feature[2] | |||
x3 = feature[3] | |||
x4 = feature[4] | |||
x = self.up1(x4, x3) | |||
x = self.up2(x, x2) | |||
x = self.up3(x, x1) | |||
x = self.up4(x, x0) | |||
output = self.out_conv(x) | |||
return output | |||
class UNet(nn.Module): | |||
def __init__(self, in_chns, class_num): | |||
super(UNet, self).__init__() | |||
params = {'in_chns': in_chns, | |||
'feature_chns': [64, 128, 128, 256, 512], | |||
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], | |||
'class_num': class_num, | |||
'bilinear': False, | |||
'acti_func': 'relu'} | |||
self.encoder = Encoder(params) | |||
self.decoder = Decoder(params) | |||
def forward(self, x, need_fp=False): | |||
feature = self.encoder(x) | |||
        if need_fp:
            # feature-perturbation pass: duplicate the batch along dim 0 with channel dropout, then split the two views
            outs = self.decoder([torch.cat((feat, nn.Dropout2d(0.5)(feat))) for feat in feature])
            return outs.chunk(2)
output = self.decoder(feature) | |||
return output |
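# --- illustrative sketch (not part of the original file) ---
# Expected tensor shapes for the UNet defined above (toy input; a minimal
# smoke test, assuming single-channel images and binary segmentation):
import torch

net = UNet(in_chns=1, class_num=2)
x = torch.randn(2, 1, 64, 64)          # (B, C, H, W); H and W divisible by 16
logits = net(x)                        # (2, 2, 64, 64): one map per class
out, out_fp = net(x, need_fp=True)     # clean and feature-perturbed views
assert logits.shape == out.shape == out_fp.shape == (2, 2, 64, 64)
# --- end of sketch ---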
@@ -0,0 +1,79 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
from .....utils.losses.tverskyLoss import tversky_loss | |||
from .....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from .....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 4 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
print(self.gamma) | |||
print("gamma") | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
network_output = self.base._forward(mammo_x) # logsoftmax # B | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_net_out = torch.cat((dummy_net_out,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
output['loss'] = focal_loss | |||
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = output['loss'] | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
@@ -0,0 +1,156 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch import nn | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ....utils.generalized_dice import dice_loss | |||
from ....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from ....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class UncertaintyAwareStudentMeanTeacherNet(Model): | |||
""" adapted from the paper 'Uncertainty-aware Self-ensembling Model for Semi-supervised | |||
3D Left Atrium Segmentation'. link: https://arxiv.org/pdf/1907.07034.pdf | |||
Args: | |||
alpha (float): exponential moving average decay | |||
total_num_iteration (int): total number of iteration this run will have | |||
base_model (Model): base networks for both student and teacher | |||
x_arg (str): argument name of `forward` indicating images, shape: B ... | |||
x_roi_arg (str): argument name of `forward` indicating ground-truth ROIs, shape: B ... | |||
x_annotated_arg (str): argument name of `forward` indicating being labeled or not for inputs, shape: B | |||
uncertainty_ratio_thd (float): the initial ratio threshold for uncertainty pruning | |||
std (float): standard deviation of gaussian noise to be applied to the input | |||
mean (float): mean of gaussian noise to be applied to the input | |||
T (int): number of perturbations for teacher net | |||
lambda_coef (float): lambda coefficient for calculating unsupervised loss | |||
use_supervised_in_consistency_checking (bool): if true then use supervised samples in unsupervised task (consistency checking) | |||
""" | |||
def __init__(self, | |||
alpha: float, | |||
base_model: Model, | |||
std: float = math.sqrt(0.0001), | |||
mean: float = 0.0, | |||
T: int=1, | |||
use_supervised_in_consistency_checking: bool = True, | |||
uncertainty_thd_calculator=dynamic_num_generator(num_iters=200, func=init_num_generator2()), | |||
unsup_coef_generator=dynamic_num_generator(num_iters=200, func=init_num_generator1()), | |||
gamma = 2) -> None: | |||
super(UncertaintyAwareStudentMeanTeacherNet, self).__init__() | |||
assert alpha >= 0 and alpha <= 1, f'alpha should be in range [0, 1]; got {alpha} instead.' | |||
assert T > 0, f'number of perturbations to be applied for teacher net should be at least one; got {T} instead.' | |||
self._T: int = T | |||
self._alpha: float = alpha | |||
self._mean: float = mean | |||
self._std: float = std | |||
self._use_supervised_in_consistency_checking: bool = use_supervised_in_consistency_checking | |||
self._uncertainty_thd_calculator = uncertainty_thd_calculator | |||
self._unsup_coef_generator = unsup_coef_generator | |||
self._student: Model = base_model | |||
self._teacher: Model = deepcopy(self._student) | |||
self.gamma = gamma | |||
def _update_teacher_params(self, | |||
_: Model, | |||
__: torch.Tensor) -> None: | |||
"""a backward hook to update teacher network parameters, using exponential moving | |||
average method.""" | |||
for s_p, t_p in zip(self._student.parameters(), self._teacher.parameters()): | |||
t_p.data = self._alpha * t_p.data + (1 - self._alpha) * s_p.data | |||
def _apply_gaussian_noise(self, | |||
input ) -> torch.Tensor: | |||
inp = input | |||
noise = torch.randn(inp.size()) * self._std + self._mean | |||
noise = noise.to(inp.device) | |||
inp = inp + (noise * inp) | |||
input = inp | |||
return input | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
if mammo_loss_and_gt is None: | |||
            output['loss'] = torch.tensor(0.0)
output['mse_loss'] = torch.tensor(0.0) | |||
output['suploss'] = torch.tensor(0.0) | |||
output['focalloss'] = torch.tensor(0.0) | |||
#output['org_pixel_labels'] = output['pixel_probs'] | |||
# output['ce_loss'] = torch.tensor(0.0) | |||
# output['ce_0_loss'] = torch.tensor(0.0) | |||
# output['ce_1_loss'] = torch.tensor(0.0) | |||
# output['dice_loss'] = torch.tensor(0.0) | |||
return output | |||
        assert mammo_loss_and_gt.shape[1] == 2, "loss-type and ground-truth tensor is not in expected shape (B, 2, H, W)"
        loss_type_lable = mammo_loss_and_gt[:,0,:,:] # 1 means supervised and zero means unsupervised
        mask = mammo_loss_and_gt[:,1,:,:] # 1 means cancer and zero means background
network_output = self._student._forward(mammo_x) # torch.nn.LogSoftmax has been applied on last layer | |||
network_output_soft = torch.exp(network_output).to(mammo_x.device) | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = loss_type_lable.shape | |||
network_output_shape = network_output.shape | |||
dup_loss_type = loss_type_lable.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
dup_mask = torch.cat((1-mask, mask), 1) | |||
supervised_loss = torch.tensor(0.0).to(mammo_x.device) | |||
focal_mask = dup_mask.clone().to(mammo_x.device) | |||
focal_net_out = network_output_soft.clone().to(mammo_x.device) | |||
focal_mask[dup_loss_type == 0] = 0.0 | |||
focal_shape = focal_mask.shape | |||
        focal_mask = torch.cat((focal_mask, torch.ones((focal_shape[0], 1, focal_shape[2], focal_shape[3])).to(mammo_x.device)), 1)
        focal_net_out = torch.cat((focal_net_out, torch.ones((focal_shape[0], 1, focal_shape[2], focal_shape[3])).to(mammo_x.device)), 1)
supervised_loss += balanced_focal_cross_entropy_loss_semi( | |||
focal_net_out, | |||
focal_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(focal_net_out,focal_mask) | |||
with torch.no_grad(): | |||
teacher_outs: List[torch.Tensor] = list() | |||
for _ in range(self._T): | |||
teacher_outs.append(self._teacher._forward( | |||
self._apply_gaussian_noise(mammo_x))) | |||
out_t_u = torch.exp(torch.stack(teacher_outs, dim=2)).mean(dim=2) # B C H' W' | |||
        uncertainty = -torch.sum(out_t_u * torch.log(out_t_u + 1e-12), 1)  # predictive entropy, B H' W'; eps guards log(0)
mse = torch.sum((network_output_soft - out_t_u) ** 2, dim=1) | |||
        # consider only the most certain pixels for the mse loss
        mse_ = torch.zeros_like(mse)
        u_threshold = self._uncertainty_thd_calculator()
        mask_mse = uncertainty < u_threshold
        mse_[mask_mse] += mse[mask_mse]
        active_pxs = torch.sum(mask_mse.long(), dim=[1, 2])  # B
        mse_loss = torch.sum(mse_ / active_pxs.clamp(min=1)[:, None, None], dim=[1, 2])  # clamp guards against a batch item with no certain pixel
mse_loss = torch.mean(mse_loss) | |||
mse_coef = self._unsup_coef_generator() | |||
output['loss'] = mse_coef * mse_loss | |||
        output['mse_loss'] = mse_loss
        output['mse_coef'] = mse_coef
        output['u_threshold'] = u_threshold
output['suploss'] = supervised_loss | |||
output['focalloss'] = supervised_loss | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
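# --- usage sketch (hypothetical, not part of the original file) ---
# A minimal smoke test of the wrapper above, assuming `base_model` exposes a
# `_forward` returning log-softmax maps of shape (B, 2, H, W), and that the default
# threshold/coefficient generators are importable. `DummyBase` is a stand-in
# invented purely for illustration.
if __name__ == "__main__":
    class DummyBase(Model):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 2, 1)

        def _forward(self, x):
            return F.log_softmax(self.conv(x), dim=1)

    net = UncertaintyAwareStudentMeanTeacherNet(alpha=0.99, base_model=DummyBase(), T=2)
    x = torch.randn(2, 1, 64, 64)
    loss_and_gt = torch.zeros(2, 2, 64, 64)
    loss_and_gt[0, 0] = 1.0                        # mark the first sample as supervised
    loss_and_gt[0, 1, 16:32, 16:32] = 1.0          # with a square lesion mask
    out = net(x, loss_and_gt)
    print(float(out['loss']), float(out['mse_loss']))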
@@ -0,0 +1,80 @@ | |||
from copy import deepcopy | |||
import math | |||
from random import uniform | |||
from typing import List | |||
import time | |||
from mlassistant.core import Model, ModelIO | |||
import torch | |||
from torch.nn import functional as F | |||
from torch import nn | |||
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ....utils.generalized_dice import dice_loss | |||
from ....utils.losses.tverskyLoss import tversky_loss | |||
from ....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2 | |||
from ....utils.dynamic_num_generator.wrapper import dynamic_num_generator | |||
class RESUNET(Model): | |||
def __init__(self, | |||
base_model: Model, | |||
gamma = 2 , smooth=1, alpha=0.7, beta=0.3) -> None: | |||
super(RESUNET, self).__init__() | |||
self.base = base_model | |||
self.gamma = gamma | |||
self.smooth = smooth | |||
self.alpha = alpha | |||
self.beta = beta | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
output = dict(loss=0.0) | |||
        assert mammo_loss_and_gt.shape[1] == 2, "loss and gt tensor is not in the expected shape (B, 2, H, W)"
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
        network_output = self.base._forward(mammo_x)  # log-softmax has been applied on the last layer
network_output_soft = torch.exp(network_output).to(mammo_x.device) # softmax output of model | |||
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alpha, beta= self.beta) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output['ce_1_loss'] = F.binary_cross_entropy(dummy_yhat * dummy_y,dummy_y) | |||
output['ce_0_loss'] = F.binary_cross_entropy(dummy_yhat * (1 - dummy_y) + dummy_y, dummy_y) | |||
output['ce_loss'] = ce_loss | |||
        output['loss'] = 0.5 * (dice + focal_loss)
output['diceloss'] = dice | |||
output['focalloss'] = focal_loss | |||
output['suploss'] = dice | |||
output['tverskyloss'] = tversky | |||
output['org_pixel_labels'] = mammo_loss_and_gt | |||
return output |
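# --- usage sketch (hypothetical, not part of the original file) ---
# Exercises the supervised-masking scheme above: unlabeled pixels are zeroed in the
# target and an all-ones channel is appended so the semi-supervised focal loss can
# tell masked-out pixels apart. `DummyBase` is a stand-in invented for illustration,
# assuming `Model` behaves like `nn.Module` and `_forward` returns log-softmax maps.
if __name__ == "__main__":
    class DummyBase(Model):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 2, 1)

        def _forward(self, x):
            return F.log_softmax(self.conv(x), dim=1)

    net = RESUNET(base_model=DummyBase())
    x = torch.randn(2, 1, 64, 64)
    loss_and_gt = torch.zeros(2, 2, 64, 64)
    loss_and_gt[0, 0] = 1.0                        # first sample supervised
    loss_and_gt[0, 1, 16:32, 16:32] = 1.0          # square lesion mask
    out = net(x, loss_and_gt)
    print(float(out['loss']), out['pixel_probs'].shape)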
@@ -0,0 +1,396 @@ | |||
import copy | |||
import math | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import numpy as np | |||
from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm | |||
from torch.nn.modules.utils import _pair | |||
from mlassistant.core import ModelIO, Model | |||
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ......utils.generalized_dice import dice_loss | |||
from ......utils.losses.tverskyLoss import tversky_loss | |||
def np2th(weights, conv=False): | |||
"""Possibly convert HWIO to OIHW.""" | |||
if conv: | |||
weights = weights.transpose([3, 2, 0, 1]) | |||
return torch.from_numpy(weights) | |||
def swish(x): | |||
return x * torch.sigmoid(x) | |||
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||
class Attention(nn.Module): | |||
def __init__(self, config, vis): | |||
super(Attention, self).__init__() | |||
self.vis = vis | |||
self.num_attention_heads = config['transformer']["num_heads"] | |||
self.attention_head_size = int(config['hidden_size'] / self.num_attention_heads) | |||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
self.query = Linear(config['hidden_size'], self.all_head_size) | |||
self.key = Linear(config['hidden_size'], self.all_head_size) | |||
self.value = Linear(config['hidden_size'], self.all_head_size) | |||
self.out = Linear(config['hidden_size'], config['hidden_size']) | |||
self.attn_dropout = Dropout(config['transformer']["attention_dropout_rate"]) | |||
self.proj_dropout = Dropout(config['transformer']["attention_dropout_rate"]) | |||
self.softmax = Softmax(dim=-1) | |||
def transpose_for_scores(self, x): | |||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||
x = x.view(*new_x_shape) | |||
return x.permute(0, 2, 1, 3) | |||
def forward(self, hidden_states): | |||
mixed_query_layer = self.query(hidden_states) | |||
mixed_key_layer = self.key(hidden_states) | |||
mixed_value_layer = self.value(hidden_states) | |||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||
key_layer = self.transpose_for_scores(mixed_key_layer) | |||
value_layer = self.transpose_for_scores(mixed_value_layer) | |||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||
attention_probs = self.softmax(attention_scores) | |||
weights = attention_probs if self.vis else None | |||
attention_probs = self.attn_dropout(attention_probs) | |||
context_layer = torch.matmul(attention_probs, value_layer) | |||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||
context_layer = context_layer.view(*new_context_layer_shape) | |||
attention_output = self.out(context_layer) | |||
attention_output = self.proj_dropout(attention_output) | |||
return attention_output, weights | |||
class Mlp(nn.Module): | |||
def __init__(self, config): | |||
super(Mlp, self).__init__() | |||
self.fc1 = Linear(config['hidden_size'], config['transformer']["mlp_dim"]) | |||
self.fc2 = Linear(config['transformer']["mlp_dim"], config['hidden_size']) | |||
self.act_fn = ACT2FN["gelu"] | |||
self.dropout = Dropout(config['transformer']["dropout_rate"]) | |||
self._init_weights() | |||
def _init_weights(self): | |||
nn.init.xavier_uniform_(self.fc1.weight) | |||
nn.init.xavier_uniform_(self.fc2.weight) | |||
nn.init.normal_(self.fc1.bias, std=1e-6) | |||
nn.init.normal_(self.fc2.bias, std=1e-6) | |||
def forward(self, x): | |||
x = self.fc1(x) | |||
x = self.act_fn(x) | |||
x = self.dropout(x) | |||
x = self.fc2(x) | |||
x = self.dropout(x) | |||
return x | |||
class Embeddings(nn.Module): | |||
"""Construct the embeddings from patch, position embeddings. | |||
""" | |||
def __init__(self, config, img_size, in_channels=3): | |||
super(Embeddings, self).__init__() | |||
self.hybrid = None | |||
self.config = config | |||
img_size = _pair(img_size) | |||
if config['patches'].get("grid") is not None: # ResNet | |||
grid_size = config['patches']["grid"] | |||
patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) | |||
patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) | |||
n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) | |||
self.hybrid = True | |||
else: | |||
patch_size = _pair(config['patches']["size"]) | |||
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) | |||
self.hybrid = False | |||
self.patch_embeddings = Conv2d(in_channels=in_channels, | |||
out_channels=config['hidden_size'], | |||
kernel_size=patch_size, | |||
stride=patch_size) | |||
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config['hidden_size'])) | |||
self.dropout = Dropout(config['transformer']["dropout_rate"]) | |||
def forward(self, x): | |||
        if self.hybrid:
            # NOTE: no hybrid backbone is defined in this file; the B/16 config used below never takes this branch.
            x, features = self.hybrid_model(x)
else: | |||
features = None | |||
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
x = x.flatten(2) | |||
x = x.transpose(-1, -2) # (B, n_patches, hidden) | |||
embeddings = x + self.position_embeddings | |||
embeddings = self.dropout(embeddings) | |||
return embeddings, features | |||
class Block(nn.Module): | |||
def __init__(self, config, vis): | |||
super(Block, self).__init__() | |||
self.hidden_size = config['hidden_size'] | |||
self.attention_norm = LayerNorm(config['hidden_size'], eps=1e-6) | |||
self.ffn_norm = LayerNorm(config['hidden_size'], eps=1e-6) | |||
self.ffn = Mlp(config) | |||
self.attn = Attention(config, vis) | |||
def forward(self, x): | |||
h = x | |||
x = self.attention_norm(x) | |||
x, weights = self.attn(x) | |||
x = x + h | |||
h = x | |||
x = self.ffn_norm(x) | |||
x = self.ffn(x) | |||
x = x + h | |||
return x, weights | |||
class Encoder(nn.Module): | |||
def __init__(self, config, vis): | |||
super(Encoder, self).__init__() | |||
self.vis = vis | |||
self.layer = nn.ModuleList() | |||
self.encoder_norm = LayerNorm(config['hidden_size'], eps=1e-6) | |||
for _ in range(config['transformer']["num_layers"]): | |||
layer = Block(config, vis) | |||
self.layer.append(copy.deepcopy(layer)) | |||
def forward(self, hidden_states): | |||
attn_weights = [] | |||
for layer_block in self.layer: | |||
hidden_states, weights = layer_block(hidden_states) | |||
if self.vis: | |||
attn_weights.append(weights) | |||
encoded = self.encoder_norm(hidden_states) | |||
return encoded, attn_weights | |||
class Transformer(nn.Module): | |||
def __init__(self, config, img_size, vis): | |||
super(Transformer, self).__init__() | |||
self.embeddings = Embeddings(config, img_size=img_size) | |||
self.encoder = Encoder(config, vis) | |||
def forward(self, input_ids): | |||
embedding_output, features = self.embeddings(input_ids) | |||
encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) | |||
return encoded, attn_weights, features | |||
class Conv2dReLU(nn.Sequential): | |||
def __init__( | |||
self, | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
padding=0, | |||
stride=1, | |||
use_batchnorm=True, | |||
): | |||
conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride=stride, | |||
padding=padding, | |||
bias=not (use_batchnorm), | |||
) | |||
relu = nn.ReLU(inplace=True) | |||
bn = nn.BatchNorm2d(out_channels) | |||
super(Conv2dReLU, self).__init__(conv, bn, relu) | |||
class DecoderBlock(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels, | |||
out_channels, | |||
skip_channels=0, | |||
use_batchnorm=True, | |||
): | |||
super().__init__() | |||
self.conv1 = Conv2dReLU( | |||
in_channels + skip_channels, | |||
out_channels, | |||
kernel_size=3, | |||
padding=1, | |||
use_batchnorm=use_batchnorm, | |||
) | |||
self.conv2 = Conv2dReLU( | |||
out_channels, | |||
out_channels, | |||
kernel_size=3, | |||
padding=1, | |||
use_batchnorm=use_batchnorm, | |||
) | |||
self.up = nn.UpsamplingBilinear2d(scale_factor=2) | |||
def forward(self, x, skip=None): | |||
x = self.up(x) | |||
if skip is not None: | |||
x = torch.cat([x, skip], dim=1) | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
return x | |||
class SegmentationHead(nn.Sequential): | |||
def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): | |||
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) | |||
upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() | |||
super().__init__(conv2d, upsampling) | |||
class DecoderCup(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.config = config | |||
head_channels = 512 | |||
self.conv_more = Conv2dReLU( | |||
config['hidden_size'], | |||
head_channels, | |||
kernel_size=3, | |||
padding=1, | |||
use_batchnorm=True, | |||
) | |||
decoder_channels = config['decoder_channels'] | |||
in_channels = [head_channels] + list(decoder_channels[:-1]) | |||
out_channels = decoder_channels | |||
skip_channels=[0,0,0,0] | |||
blocks = [ | |||
DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) | |||
] | |||
self.blocks = nn.ModuleList(blocks) | |||
def forward(self, hidden_states, features=None): | |||
B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) | |||
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) | |||
x = hidden_states.permute(0, 2, 1) | |||
x = x.contiguous().view(B, hidden, h, w) | |||
x = self.conv_more(x) | |||
for i, decoder_block in enumerate(self.blocks): | |||
if features is not None: | |||
                skip = features[i] if (i < self.config.get('n_skip', 0)) else None  # config is a dict; attribute access would fail here
else: | |||
skip = None | |||
x = decoder_block(x, skip=skip) | |||
return x | |||
class VisionTransformer(nn.Module): | |||
def __init__(self, img_size=512, num_classes=2, zero_head=False, vis=False): | |||
super(VisionTransformer, self).__init__() | |||
config = self._get_b16_config() | |||
self.num_classes = num_classes | |||
self.zero_head = zero_head | |||
self.transformer = Transformer(config, img_size, vis) | |||
self.decoder = DecoderCup(config) | |||
self.segmentation_head = SegmentationHead( | |||
in_channels=config['decoder_channels'][-1], | |||
out_channels=config['n_classes'], | |||
kernel_size=3, | |||
) | |||
self.config = config | |||
def forward(self, x): | |||
if x.size()[1] == 1: | |||
x = x.repeat(1,3,1,1) | |||
x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) | |||
x = self.decoder(x, features) | |||
logits = self.segmentation_head(x) | |||
return F.softmax(logits, dim=1) | |||
def _get_b16_config(self): | |||
"""Returns the ViT-B/16 configuration.""" | |||
config = {} | |||
config['patches'] = {'size': (16, 16)} | |||
config['hidden_size'] = 768 | |||
config['transformer'] = {} | |||
config['transformer']['mlp_dim'] = 3072 | |||
config['transformer']['num_heads'] = 12 | |||
config['transformer']['num_layers'] = 12 | |||
config['transformer']['attention_dropout_rate'] = 0.0 | |||
config['transformer']['dropout_rate'] = 0.1 | |||
config['decoder_channels'] = (512, 256, 128, 64) | |||
config['n_classes'] = 2 | |||
return config | |||
class TransUNet(Model): | |||
""" | |||
TransUNet architecture for baseline evaluation. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.trans_unet = VisionTransformer() | |||
self.trans_unet = torch.compile(self.trans_unet) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
out_prob_map = self.trans_unet(mammo_x) | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out_prob_map.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = out_prob_map.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_yhat = out_prob_map * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
tversky = tversky_loss(inputs = dummy_yhat[:, 1, ...], targets = dummy_y[:, 1, ...], smooth=1, alpha = 0.7, beta=0.3) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
output = { | |||
'pixel_probs': out_prob_map, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2), | |||
'ce': ce_loss, | |||
'dice_loss': dice, | |||
'tversky': tversky | |||
} | |||
return output |
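# --- usage sketch (hypothetical, not part of the original file) ---
# Shape check for the ViT-B/16 segmentor above: with img_size=512 and 16x16 patches
# the transformer sees 32*32 = 1024 tokens, and the DecoderCup upsamples 32 -> 512
# through four bilinear blocks, so the head emits a (B, 2, 512, 512) probability map.
if __name__ == "__main__":
    model = VisionTransformer(img_size=512)
    x = torch.randn(1, 1, 512, 512)   # single-channel input gets repeated to 3 channels
    probs = model(x)
    print(probs.shape)                # torch.Size([1, 2, 512, 512])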
@@ -0,0 +1,401 @@ | |||
import copy | |||
import math | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import numpy as np | |||
from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm | |||
from torch.nn.modules.utils import _pair | |||
from mlassistant.core import ModelIO, Model | |||
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ......utils.generalized_dice import dice_loss | |||
from ......utils.losses.breast_topk_focal_ce import calculate_breast_topk_focal_ce_loss | |||
def np2th(weights, conv=False): | |||
"""Possibly convert HWIO to OIHW.""" | |||
if conv: | |||
weights = weights.transpose([3, 2, 0, 1]) | |||
return torch.from_numpy(weights) | |||
def swish(x): | |||
return x * torch.sigmoid(x) | |||
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||
class Attention(nn.Module): | |||
def __init__(self, config, vis): | |||
super(Attention, self).__init__() | |||
self.vis = vis | |||
self.num_attention_heads = config['transformer']["num_heads"] | |||
self.attention_head_size = int(config['hidden_size'] / self.num_attention_heads) | |||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
self.query = Linear(config['hidden_size'], self.all_head_size) | |||
self.key = Linear(config['hidden_size'], self.all_head_size) | |||
self.value = Linear(config['hidden_size'], self.all_head_size) | |||
self.out = Linear(config['hidden_size'], config['hidden_size']) | |||
self.attn_dropout = Dropout(config['transformer']["attention_dropout_rate"]) | |||
self.proj_dropout = Dropout(config['transformer']["attention_dropout_rate"]) | |||
self.softmax = Softmax(dim=-1) | |||
def transpose_for_scores(self, x): | |||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||
x = x.view(*new_x_shape) | |||
return x.permute(0, 2, 1, 3) | |||
def forward(self, hidden_states): | |||
mixed_query_layer = self.query(hidden_states) | |||
mixed_key_layer = self.key(hidden_states) | |||
mixed_value_layer = self.value(hidden_states) | |||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||
key_layer = self.transpose_for_scores(mixed_key_layer) | |||
value_layer = self.transpose_for_scores(mixed_value_layer) | |||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||
attention_probs = self.softmax(attention_scores) | |||
weights = attention_probs if self.vis else None | |||
attention_probs = self.attn_dropout(attention_probs) | |||
context_layer = torch.matmul(attention_probs, value_layer) | |||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||
context_layer = context_layer.view(*new_context_layer_shape) | |||
attention_output = self.out(context_layer) | |||
attention_output = self.proj_dropout(attention_output) | |||
return attention_output, weights | |||
class Mlp(nn.Module): | |||
def __init__(self, config): | |||
super(Mlp, self).__init__() | |||
self.fc1 = Linear(config['hidden_size'], config['transformer']["mlp_dim"]) | |||
self.fc2 = Linear(config['transformer']["mlp_dim"], config['hidden_size']) | |||
self.act_fn = ACT2FN["gelu"] | |||
self.dropout = Dropout(config['transformer']["dropout_rate"]) | |||
self._init_weights() | |||
def _init_weights(self): | |||
nn.init.xavier_uniform_(self.fc1.weight) | |||
nn.init.xavier_uniform_(self.fc2.weight) | |||
nn.init.normal_(self.fc1.bias, std=1e-6) | |||
nn.init.normal_(self.fc2.bias, std=1e-6) | |||
def forward(self, x): | |||
x = self.fc1(x) | |||
x = self.act_fn(x) | |||
x = self.dropout(x) | |||
x = self.fc2(x) | |||
x = self.dropout(x) | |||
return x | |||
class Embeddings(nn.Module): | |||
"""Construct the embeddings from patch, position embeddings. | |||
""" | |||
def __init__(self, config, img_size, in_channels=3): | |||
super(Embeddings, self).__init__() | |||
self.hybrid = None | |||
self.config = config | |||
img_size = _pair(img_size) | |||
if config['patches'].get("grid") is not None: # ResNet | |||
grid_size = config['patches']["grid"] | |||
patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) | |||
patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) | |||
n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) | |||
self.hybrid = True | |||
else: | |||
patch_size = _pair(config['patches']["size"]) | |||
n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) | |||
self.hybrid = False | |||
self.patch_embeddings = Conv2d(in_channels=in_channels, | |||
out_channels=config['hidden_size'], | |||
kernel_size=patch_size, | |||
stride=patch_size) | |||
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config['hidden_size'])) | |||
self.dropout = Dropout(config['transformer']["dropout_rate"]) | |||
def forward(self, x): | |||
        if self.hybrid:
            # NOTE: no hybrid backbone is defined in this file; the B/16 config used below never takes this branch.
            x, features = self.hybrid_model(x)
else: | |||
features = None | |||
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
x = x.flatten(2) | |||
x = x.transpose(-1, -2) # (B, n_patches, hidden) | |||
embeddings = x + self.position_embeddings | |||
embeddings = self.dropout(embeddings) | |||
return embeddings, features | |||
class Block(nn.Module): | |||
def __init__(self, config, vis): | |||
super(Block, self).__init__() | |||
self.hidden_size = config['hidden_size'] | |||
self.attention_norm = LayerNorm(config['hidden_size'], eps=1e-6) | |||
self.ffn_norm = LayerNorm(config['hidden_size'], eps=1e-6) | |||
self.ffn = Mlp(config) | |||
self.attn = Attention(config, vis) | |||
def forward(self, x): | |||
h = x | |||
x = self.attention_norm(x) | |||
x, weights = self.attn(x) | |||
x = x + h | |||
h = x | |||
x = self.ffn_norm(x) | |||
x = self.ffn(x) | |||
x = x + h | |||
return x, weights | |||
class Encoder(nn.Module): | |||
def __init__(self, config, vis): | |||
super(Encoder, self).__init__() | |||
self.vis = vis | |||
self.layer = nn.ModuleList() | |||
self.encoder_norm = LayerNorm(config['hidden_size'], eps=1e-6) | |||
for _ in range(config['transformer']["num_layers"]): | |||
layer = Block(config, vis) | |||
self.layer.append(copy.deepcopy(layer)) | |||
def forward(self, hidden_states): | |||
attn_weights = [] | |||
for layer_block in self.layer: | |||
hidden_states, weights = layer_block(hidden_states) | |||
if self.vis: | |||
attn_weights.append(weights) | |||
encoded = self.encoder_norm(hidden_states) | |||
return encoded, attn_weights | |||
class Transformer(nn.Module): | |||
def __init__(self, config, img_size, vis): | |||
super(Transformer, self).__init__() | |||
self.embeddings = Embeddings(config, img_size=img_size) | |||
self.encoder = Encoder(config, vis) | |||
def forward(self, input_ids): | |||
embedding_output, features = self.embeddings(input_ids) | |||
encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) | |||
return encoded, attn_weights, features | |||
class Conv2dReLU(nn.Sequential): | |||
def __init__( | |||
self, | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
padding=0, | |||
stride=1, | |||
use_batchnorm=True, | |||
): | |||
conv = nn.Conv2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride=stride, | |||
padding=padding, | |||
bias=not (use_batchnorm), | |||
) | |||
relu = nn.ReLU(inplace=True) | |||
bn = nn.BatchNorm2d(out_channels) | |||
super(Conv2dReLU, self).__init__(conv, bn, relu) | |||
class DecoderBlock(nn.Module): | |||
def __init__( | |||
self, | |||
in_channels, | |||
out_channels, | |||
skip_channels=0, | |||
use_batchnorm=True, | |||
): | |||
super().__init__() | |||
self.conv1 = Conv2dReLU( | |||
in_channels + skip_channels, | |||
out_channels, | |||
kernel_size=3, | |||
padding=1, | |||
use_batchnorm=use_batchnorm, | |||
) | |||
self.conv2 = Conv2dReLU( | |||
out_channels, | |||
out_channels, | |||
kernel_size=3, | |||
padding=1, | |||
use_batchnorm=use_batchnorm, | |||
) | |||
self.up = nn.UpsamplingBilinear2d(scale_factor=2) | |||
def forward(self, x, skip=None): | |||
x = self.up(x) | |||
if skip is not None: | |||
x = torch.cat([x, skip], dim=1) | |||
x = self.conv1(x) | |||
x = self.conv2(x) | |||
return x | |||
class SegmentationHead(nn.Sequential): | |||
def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): | |||
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) | |||
upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() | |||
super().__init__(conv2d, upsampling) | |||
class DecoderCup(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.config = config | |||
head_channels = 512 | |||
self.conv_more = Conv2dReLU( | |||
config['hidden_size'], | |||
head_channels, | |||
kernel_size=3, | |||
padding=1, | |||
use_batchnorm=True, | |||
) | |||
decoder_channels = config['decoder_channels'] | |||
in_channels = [head_channels] + list(decoder_channels[:-1]) | |||
out_channels = decoder_channels | |||
skip_channels=[0,0,0,0] | |||
blocks = [ | |||
DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) | |||
] | |||
self.blocks = nn.ModuleList(blocks) | |||
def forward(self, hidden_states, features=None): | |||
B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) | |||
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) | |||
x = hidden_states.permute(0, 2, 1) | |||
x = x.contiguous().view(B, hidden, h, w) | |||
x = self.conv_more(x) | |||
for i, decoder_block in enumerate(self.blocks): | |||
if features is not None: | |||
                skip = features[i] if (i < self.config.get('n_skip', 0)) else None  # config is a dict; attribute access would fail here
else: | |||
skip = None | |||
x = decoder_block(x, skip=skip) | |||
return x | |||
class VisionTransformer(nn.Module): | |||
def __init__(self, img_size=512, num_classes=2, zero_head=False, vis=False): | |||
super(VisionTransformer, self).__init__() | |||
config = self._get_b16_config() | |||
self.num_classes = num_classes | |||
self.zero_head = zero_head | |||
self.transformer = Transformer(config, img_size, vis) | |||
self.decoder = DecoderCup(config) | |||
self.segmentation_head = SegmentationHead( | |||
in_channels=config['decoder_channels'][-1], | |||
out_channels=config['n_classes'], | |||
kernel_size=3, | |||
) | |||
self.config = config | |||
def forward(self, x): | |||
if x.size()[1] == 1: | |||
x = x.repeat(1,3,1,1) | |||
x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) | |||
x = self.decoder(x, features) | |||
logits = self.segmentation_head(x) | |||
return F.softmax(logits, dim=1) | |||
def _get_b16_config(self): | |||
"""Returns the ViT-B/16 configuration.""" | |||
config = {} | |||
config['patches'] = {'size': (16, 16)} | |||
config['hidden_size'] = 768 | |||
config['transformer'] = {} | |||
config['transformer']['mlp_dim'] = 3072 | |||
config['transformer']['num_heads'] = 12 | |||
config['transformer']['num_layers'] = 12 | |||
config['transformer']['attention_dropout_rate'] = 0.0 | |||
config['transformer']['dropout_rate'] = 0.1 | |||
config['decoder_channels'] = (512, 256, 128, 64) | |||
config['n_classes'] = 2 | |||
return config | |||
class TransUNet(Model): | |||
""" | |||
TransUNet architecture for baseline evaluation. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.trans_unet = VisionTransformer() | |||
self.iter_counter = 0 | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
self.iter_counter += 1 | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
out_prob_map = self.trans_unet(mammo_x) | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out_prob_map.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = out_prob_map.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
        coeff = (self.iter_counter // 200) / 50  # grows by 0.02 per 200-iteration block; scales the top-k keep ratio below
output = { | |||
'pixel_probs': out_prob_map, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': calculate_breast_topk_focal_ce_loss(dummy_yhat, 0.85 + 0.1 * coeff, dummy_y[:, [1], ...], | |||
verbose=False, bottom_fill_ratio=0.2, thd_to_apply_neg_loss_on_pos_batch=0.1, | |||
pos_skip_mask=(1 - loss_type).bool(), neg_skip_mask=(1 - loss_type).bool()), | |||
'focal_loss': balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2), | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice, | |||
} | |||
return output |
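# --- usage sketch (hypothetical, not part of the original file) ---
# The keep ratio passed to calculate_breast_topk_focal_ce_loss starts at 0.85 and
# grows by 0.002 per 200-iteration block, reaching 0.95 after 10000 iterations.
if __name__ == "__main__":
    for it in (0, 2000, 6000, 10000):
        coeff = (it // 200) / 50
        print(it, round(0.85 + 0.1 * coeff, 3))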
@@ -0,0 +1,61 @@ | |||
import torch | |||
import torch.nn.functional as F | |||
import numpy as np | |||
from .unet_model import UNet | |||
from mlassistant.core import ModelIO, Model | |||
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from ......utils.generalized_dice import dice_loss | |||
class UNetBaseline(Model): | |||
""" | |||
UNet architecture for baseline evaluation. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.scale_channels = [64, 128, 128, 256, 512] # Filter numbers for each scale | |||
self.shape = (512, 512) # Change the probability map shape according to the mammo scans shape | |||
self.input_channels = 1 | |||
self.class_num = 2 | |||
        self.unet = UNet(self.input_channels, self.class_num)
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
out_prob_map = self.unet(mammo_x) | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out_prob_map.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim=1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = out_prob_map.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
dice = dice_loss(dummy_net_out, dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2) | |||
output = { | |||
'pixel_probs': out_prob_map, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': focal_loss, | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice, | |||
} | |||
return output |
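# --- usage sketch (hypothetical, not part of the original file) ---
# A minimal smoke test of the baseline wrapper, assuming the focal and dice helpers
# accept the (B, 3, H, W) tensors built above (one extra all-ones channel).
if __name__ == "__main__":
    model = UNetBaseline()
    x = torch.randn(1, 1, 64, 64)
    loss_and_gt = torch.zeros(1, 2, 64, 64)
    loss_and_gt[:, 0] = 1.0            # treat the sample as fully supervised
    out = model(x, loss_and_gt)
    print(float(out['loss']), out['pixel_probs'].shape)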
@@ -0,0 +1,169 @@ | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from torch.distributions.uniform import Uniform | |||
def kaiming_normal_init_weight(model): | |||
for m in model.modules(): | |||
if isinstance(m, nn.Conv3d): | |||
torch.nn.init.kaiming_normal_(m.weight) | |||
elif isinstance(m, nn.BatchNorm3d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
return model | |||
def sparse_init_weight(model): | |||
for m in model.modules(): | |||
if isinstance(m, nn.Conv3d): | |||
torch.nn.init.sparse_(m.weight, sparsity=0.1) | |||
elif isinstance(m, nn.BatchNorm3d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
return model | |||
class ConvBlock(nn.Module): | |||
"""two convolution layers with batch norm and leaky relu""" | |||
def __init__(self, in_channels, out_channels, dropout_p): | |||
super(ConvBlock, self).__init__() | |||
self.conv_conv = nn.Sequential( | |||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(out_channels), | |||
nn.LeakyReLU(), | |||
nn.Dropout(dropout_p), | |||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(out_channels), | |||
nn.LeakyReLU() | |||
) | |||
def forward(self, x): | |||
return self.conv_conv(x) | |||
class DownBlock(nn.Module): | |||
"""Downsampling followed by ConvBlock""" | |||
def __init__(self, in_channels, out_channels, dropout_p): | |||
super(DownBlock, self).__init__() | |||
self.maxpool_conv = nn.Sequential( | |||
nn.MaxPool2d(2), | |||
ConvBlock(in_channels, out_channels, dropout_p) | |||
) | |||
def forward(self, x): | |||
return self.maxpool_conv(x) | |||
class UpBlock(nn.Module): | |||
"""Upssampling followed by ConvBlock""" | |||
def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, | |||
bilinear=True): | |||
super(UpBlock, self).__init__() | |||
self.bilinear = bilinear | |||
if bilinear: | |||
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) | |||
self.up = nn.Upsample( | |||
scale_factor=2, mode='bilinear', align_corners=True) | |||
else: | |||
self.up = nn.ConvTranspose2d( | |||
in_channels1, in_channels2, kernel_size=2, stride=2) | |||
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) | |||
def forward(self, x1, x2): | |||
if self.bilinear: | |||
x1 = self.conv1x1(x1) | |||
x1 = self.up(x1) | |||
x = torch.cat([x2, x1], dim=1) | |||
return self.conv(x) | |||
class Encoder(nn.Module): | |||
def __init__(self, params): | |||
super(Encoder, self).__init__() | |||
self.params = params | |||
self.in_chns = self.params['in_chns'] | |||
self.ft_chns = self.params['feature_chns'] | |||
self.n_class = self.params['class_num'] | |||
self.bilinear = self.params['bilinear'] | |||
self.dropout = self.params['dropout'] | |||
assert (len(self.ft_chns) == 5) | |||
self.in_conv = ConvBlock( | |||
self.in_chns, self.ft_chns[0], self.dropout[0]) | |||
self.down1 = DownBlock( | |||
self.ft_chns[0], self.ft_chns[1], self.dropout[1]) | |||
self.down2 = DownBlock( | |||
self.ft_chns[1], self.ft_chns[2], self.dropout[2]) | |||
self.down3 = DownBlock( | |||
self.ft_chns[2], self.ft_chns[3], self.dropout[3]) | |||
self.down4 = DownBlock( | |||
self.ft_chns[3], self.ft_chns[4], self.dropout[4]) | |||
def forward(self, x): | |||
x0 = self.in_conv(x) | |||
x1 = self.down1(x0) | |||
x2 = self.down2(x1) | |||
x3 = self.down3(x2) | |||
x4 = self.down4(x3) | |||
return [x0, x1, x2, x3, x4] | |||
class Decoder(nn.Module): | |||
def __init__(self, params): | |||
super(Decoder, self).__init__() | |||
self.params = params | |||
self.in_chns = self.params['in_chns'] | |||
self.ft_chns = self.params['feature_chns'] | |||
self.n_class = self.params['class_num'] | |||
self.bilinear = self.params['bilinear'] | |||
assert (len(self.ft_chns) == 5) | |||
        # forward the bilinear flag read from params; it was stored but never passed before
        self.up1 = UpBlock(
            self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0, bilinear=self.bilinear)
        self.up2 = UpBlock(
            self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0, bilinear=self.bilinear)
        self.up3 = UpBlock(
            self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0, bilinear=self.bilinear)
        self.up4 = UpBlock(
            self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0, bilinear=self.bilinear)
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, | |||
kernel_size=3, padding=1) | |||
def forward(self, feature): | |||
x0 = feature[0] | |||
x1 = feature[1] | |||
x2 = feature[2] | |||
x3 = feature[3] | |||
x4 = feature[4] | |||
x = self.up1(x4, x3) | |||
x = self.up2(x, x2) | |||
x = self.up3(x, x1) | |||
x = self.up4(x, x0) | |||
output = self.out_conv(x) | |||
return output | |||
class UNet(nn.Module): | |||
def __init__(self, in_chns, class_num): | |||
super(UNet, self).__init__() | |||
params = {'in_chns': in_chns, | |||
'feature_chns': [64, 128, 128, 256, 512], | |||
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], | |||
'class_num': class_num, | |||
'bilinear': False, | |||
'acti_func': 'relu'} | |||
self.encoder = Encoder(params) | |||
self.decoder = Decoder(params) | |||
def forward(self, x): | |||
feature = self.encoder(x) | |||
output = self.decoder(feature) | |||
return output.softmax(dim=1) |
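# --- usage sketch (hypothetical, not part of the original file) ---
# Standalone forward pass of the UNet above; the final softmax makes each pixel's
# channel values sum to one.
if __name__ == "__main__":
    net = UNet(in_chns=1, class_num=2)
    y = net(torch.randn(2, 1, 128, 128))
    print(y.shape, float(y.sum(dim=1).mean()))   # (2, 2, 128, 128), mean channel-sum ~ 1.0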
@@ -0,0 +1,101 @@ | |||
import torch | |||
import torch.nn.functional as F | |||
import torch.nn as nn | |||
import numpy as np | |||
from .utils import init_weight, unsupervised_loss | |||
from .....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import RESUNet as RESUNETSEGMENTOR | |||
from mlassistant.core import ModelIO, Model | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
class Network(nn.Module): | |||
def __init__(self): | |||
super(Network, self).__init__() | |||
self.branch1 = RESUNETSEGMENTOR(1, 2) | |||
self.branch2 = RESUNETSEGMENTOR(1, 2) | |||
def forward(self, data, step=1): | |||
if step == 1: | |||
return self.branch1(data) | |||
elif step == 2: | |||
return self.branch2(data) | |||
class CPS(Model): | |||
""" | |||
    The main design for the CPS (cross pseudo-supervision) method with ResUNet as the
    base network.
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.iter_counter = 0 | |||
self.network = Network() | |||
init_weight(self.network.branch1.business_layer, nn.init.kaiming_normal_, | |||
nn.BatchNorm2d, 1e-5, 0.1, | |||
mode='fan_in', nonlinearity='relu') | |||
init_weight(self.network.branch2.business_layer, nn.init.kaiming_normal_, | |||
nn.BatchNorm2d, 1e-5, 0.1, | |||
mode='fan_in', nonlinearity='relu') | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
self.iter_counter += 1 | |||
pred1 = F.softmax(self.network(mammo_x, step=1), dim=1) | |||
pred2 = F.softmax(self.network(mammo_x, step=2), dim=1) | |||
unsup_loss = unsupervised_loss(pred1, pred2) | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = pred1.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out1 = pred1.clone().to(mammo_x.device) | |||
dummy_net_out2 = pred2.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out1 = torch.cat((dummy_net_out1, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out2 = torch.cat((dummy_net_out2, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = pred1 * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss1 = balanced_focal_cross_entropy_loss_semi(dummy_net_out1, dummy_mask, 4) | |||
focal_loss2 = balanced_focal_cross_entropy_loss_semi(dummy_net_out2, dummy_mask, 4) | |||
dice_ls = dice_loss(dummy_net_out1, dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
weight = self._get_unsup_weight() | |||
output = { | |||
'pixel_probs': pred1, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': focal_loss1 + focal_loss2 + unsup_loss, | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice_ls, | |||
} | |||
return output | |||
def _get_unsup_weight(self): | |||
if self.iter_counter // 200 >= 70: | |||
return 1.0 | |||
term = 1 - (self.iter_counter // 200) / 70 | |||
return np.exp(-5.0 * term * term) | |||
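# --- usage sketch (hypothetical, not part of the original file) ---
# _get_unsup_weight follows the standard Gaussian ramp-up exp(-5 * (1 - t)^2) with
# t = (iter // 200) / 70, saturating at 1.0 after 70 blocks of 200 iterations.
if __name__ == "__main__":
    for it in (0, 2800, 7000, 14000):
        blocks = it // 200
        w = 1.0 if blocks >= 70 else float(np.exp(-5.0 * (1 - blocks / 70) ** 2))
        print(it, round(w, 4))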
@@ -0,0 +1,36 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, | |||
**kwargs): | |||
for name, m in feature.named_modules(): | |||
if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): | |||
conv_init(m.weight, **kwargs) | |||
elif isinstance(m, norm_layer): | |||
m.eps = bn_eps | |||
m.momentum = bn_momentum | |||
nn.init.constant_(m.weight, 1) | |||
nn.init.constant_(m.bias, 0) | |||
def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, | |||
**kwargs): | |||
if isinstance(module_list, list): | |||
for feature in module_list: | |||
__init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, | |||
**kwargs) | |||
else: | |||
__init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, | |||
**kwargs) | |||
def unsupervised_loss(probability_map1, probability_map2):
    """Cross pseudo-supervision: each branch's softmax map is trained against the other
    branch's hard (argmax) prediction, converted to a one-hot float map because
    binary_cross_entropy needs targets with the input's shape and dtype."""
    num_classes = probability_map1.size(1)
    max1 = F.one_hot(torch.max(probability_map1, dim=1)[1], num_classes).permute(0, 3, 1, 2).float()
    max2 = F.one_hot(torch.max(probability_map2, dim=1)[1], num_classes).permute(0, 3, 1, 2).float()
    loss1 = F.binary_cross_entropy(probability_map1, max2)
    loss2 = F.binary_cross_entropy(probability_map2, max1)
    return loss1 + loss2
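# --- usage sketch (hypothetical, not part of the original file) ---
# Two random softmax maps fed through the cross pseudo-supervision loss above.
if __name__ == "__main__":
    p1 = torch.softmax(torch.randn(2, 2, 8, 8), dim=1)
    p2 = torch.softmax(torch.randn(2, 2, 8, 8), dim=1)
    print(float(unsupervised_loss(p1, p2)))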
@@ -0,0 +1,97 @@ | |||
from typing import Tuple | |||
import torch | |||
import torch.nn.functional as F | |||
import numpy as np | |||
from .loss import unsupervised_loss | |||
from .unet import UNet | |||
from mlassistant.core import ModelIO, Model | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
class ICT(Model): | |||
""" | |||
    The main design for the ICT (interpolation consistency training) method with UNet
    as the base network.
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.scale_channels = [64, 128, 128, 256, 512] # Filter numbers for each scale | |||
self.shape = (512, 512) # Change the probability map shape according to the mammo scans shape | |||
self.input_channels = 1 | |||
self.class_num = 2 | |||
self.iter_counter = 14400 | |||
self.main_unet = UNet(self.input_channels, self.scale_channels, self.class_num, False) | |||
self.mean_teacher_unet = UNet(self.input_channels, self.scale_channels, self.class_num, True) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
self.iter_counter += 1 | |||
self._update_mean_teacher() | |||
combined_model_out, combined_mean_out = self._mix_up(mammo_x) | |||
model_out = self.main_unet(mammo_x) | |||
unsup_loss = unsupervised_loss(combined_mean_out, combined_model_out) | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = model_out.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = model_out.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
dummy_mask = torch.cat((dummy_mask,torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0],1,dummy_shape[2],dummy_shape[2])).to(mammo_x.device)), 1) | |||
dummy_yhat = model_out * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2) | |||
dice_ls = dice_loss(dummy_net_out,dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
weight = 0.6 * ((self.iter_counter // 200) / 400) | |||
output = { | |||
'pixel_probs': model_out, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': focal_loss * (0.8 - weight) + unsup_loss * (0.2 + weight), | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice_ls, | |||
} | |||
return output | |||
def _mix_up(self, unlabeled_mammo_x: torch.Tensor) -> Tuple: | |||
lambda_coeff = np.random.beta(1., 1.) | |||
shuffled_indices = torch.from_numpy(np.random.permutation(unlabeled_mammo_x.shape[0])).to(unlabeled_mammo_x.device) | |||
mixed_mammo_x = lambda_coeff * unlabeled_mammo_x + \ | |||
(1 - lambda_coeff) * unlabeled_mammo_x[shuffled_indices] | |||
unlabeled_model_out = self.main_unet(mixed_mammo_x) | |||
with torch.no_grad(): | |||
unlabeled_mean_out = self.mean_teacher_unet(unlabeled_mammo_x) | |||
unlabeled_mean_out = lambda_coeff * unlabeled_mean_out + \ | |||
(1 - lambda_coeff) * unlabeled_mean_out[shuffled_indices] | |||
return unlabeled_model_out, unlabeled_mean_out | |||
def _update_mean_teacher(self, alpha=0.999): | |||
alpha = min(1 - 1 / (self.iter_counter + 1), alpha) | |||
for mean_param, param in zip(self.mean_teacher_unet.parameters(), self.main_unet.parameters()): | |||
            mean_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)  # modern add_ signature; the old (scalar, tensor) form is deprecated
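# --- usage sketch (hypothetical, not part of the original file) ---
# Early in training the EMA decay is capped at 1 - 1/(iter + 1), so the teacher
# tracks the student quickly before settling to the 0.999 moving average.
if __name__ == "__main__":
    for it in (1, 10, 100, 10000):
        print(it, round(min(1 - 1 / (it + 1), 0.999), 4))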
@@ -0,0 +1,132 @@ | |||
from typing import Tuple | |||
import torch | |||
import torch.nn.functional as F | |||
import numpy as np | |||
from copy import deepcopy | |||
from .loss import unsupervised_loss | |||
from .....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import RESUNet as RESUNETSEGMENTOR | |||
from mlassistant.core import ModelIO, Model | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
class ICT(Model): | |||
""" | |||
    The main design for the ICT (interpolation consistency training) method with ResUNet
    as the base network.
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.scale_channels = [64, 128, 128, 256, 512] | |||
self.shape = (512, 512) | |||
self.input_channels = 1 | |||
self.class_num = 2 | |||
self.iter_counter = 0 | |||
encoders_conv_kws = [ | |||
{"inp_channel": 1, "out_channel": 32, "n_conv_blocks": 2}, | |||
{"inp_channel": 33, "out_channel": 64, "n_conv_blocks": 2}, | |||
{"inp_channel": 65, "out_channel": 128, "n_conv_blocks": 3}, | |||
{"inp_channel": 129, "out_channel": 256, "n_conv_blocks": 3}, | |||
] | |||
bridge_conv_kws = {"inp_channel": 257, "out_channel": 512} | |||
decoders_conv_kws = [ | |||
{"inp_channel": 512, "enc_inp_channel": 256, "out_channel": 256}, | |||
{"inp_channel": 256, "enc_inp_channel": 128, "out_channel": 64}, | |||
{"inp_channel": 64, "enc_inp_channel": 64, "out_channel": 32}, | |||
{"inp_channel": 32, "enc_inp_channel": 32, "out_channel": 16}, | |||
] | |||
        self.main_resunet = RESUNETSEGMENTOR(  # matches the import alias above; `ResunetSegmentor` was undefined
seg_num_labels=self.class_num, | |||
concat_original_image=True, | |||
encoders_conv_kws=encoders_conv_kws, | |||
decoders_conv_kws=decoders_conv_kws, | |||
bridge_conv_kws=bridge_conv_kws, | |||
) | |||
self.mean_teacher_resunet = deepcopy(self.main_resunet) | |||
for param in self.mean_teacher_resunet.parameters(): | |||
param.detach_() | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
self.iter_counter += 1 | |||
self._update_mean_teacher() | |||
combined_model_out, combined_mean_out = self._mix_up(mammo_x) | |||
model_out = self.main_resunet(mammo_x, None)["pixel_probs"] | |||
unsup_loss = unsupervised_loss(combined_mean_out, combined_model_out) | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = model_out.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = model_out.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
dummy_shape = dummy_mask.shape | |||
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = model_out * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4) | |||
dice_ls = dice_loss(dummy_net_out,dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
weight = self._get_unsup_weight() | |||
output = { | |||
'pixel_probs': model_out, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': focal_loss + unsup_loss, # * weight, | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice_ls, | |||
} | |||
return output | |||
def _mix_up(self, unlabeled_mammo_x: torch.Tensor) -> Tuple: | |||
lambda_coeff = np.random.beta(1., 1.) | |||
shuffled_indices = torch.from_numpy(np.random.permutation(unlabeled_mammo_x.shape[0])).to(unlabeled_mammo_x.device) | |||
mixed_mammo_x = lambda_coeff * unlabeled_mammo_x + \ | |||
(1 - lambda_coeff) * unlabeled_mammo_x[shuffled_indices] | |||
unlabeled_model_out = self.main_resunet(mixed_mammo_x, None)["pixel_probs"] | |||
with torch.no_grad(): | |||
unlabeled_mean_out = self.mean_teacher_resunet(unlabeled_mammo_x, None)["pixel_probs"] | |||
unlabeled_mean_out = lambda_coeff * unlabeled_mean_out + \ | |||
(1 - lambda_coeff) * unlabeled_mean_out[shuffled_indices] | |||
return unlabeled_model_out, unlabeled_mean_out | |||
def _update_mean_teacher(self, alpha=0.999): | |||
alpha = min(1 - 1 / (self.iter_counter + 1), alpha) | |||
for mean_param, param in zip(self.mean_teacher_resunet.parameters(), self.main_resunet.parameters()): | |||
            mean_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)  # modern add_ signature; the old (scalar, tensor) form is deprecated
def _get_unsup_weight(self): | |||
if self.iter_counter // 200 >= 70: | |||
return 1.0 | |||
term = 1 - (self.iter_counter // 200) / 70 | |||
return np.exp(-5.0 * term * term) | |||
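# --- usage sketch (hypothetical, not part of the original file) ---
# The interpolation consistency at the heart of ICT: mix two inputs with
# lambda ~ Beta(1, 1) and train the student's prediction on the mix to match the
# same mix of the teacher's predictions; shown here on raw tensors.
if __name__ == "__main__":
    lam = np.random.beta(1., 1.)
    a, b = torch.rand(4, 1, 8, 8), torch.rand(4, 1, 8, 8)
    mixed = lam * a + (1 - lam) * b
    print(mixed.shape, round(float(lam), 3))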
@@ -0,0 +1,70 @@ | |||
import torch | |||
from torch.nn import BCELoss | |||
from .....utils.dice import dice_loss | |||
def supervised_loss(probability_map: torch.Tensor, ground_truth: torch.Tensor) -> torch.Tensor: | |||
""" | |||
Computes the total supervised loss of the batch. | |||
    The loss consists of dice loss and cross-entropy loss.
:param probability_map: probability map output of the segmentation model | |||
:param ground_truth: ground truth of the processed input | |||
:return: loss of the labeled data | |||
""" | |||
binary_entropy_loss = BCELoss() | |||
loss = (binary_entropy_loss(probability_map, ground_truth) + | |||
dice_loss(probability_map.unsqueeze(dim=0), ground_truth.long())) / 2 | |||
return loss | |||
def unsupervised_loss(mean_teacher_mixed_up: torch.Tensor, model_mixed_up_output: torch.Tensor) -> torch.Tensor: | |||
""" | |||
Computes the total unsupervised loss for the batch. | |||
:param mean_teacher_mixed_up: probability map output of the mean teacher model | |||
:param model_mixed_up_output: probability map output of the main model using the mixed up data | |||
:return: loss of the unlabeled data | |||
""" | |||
loss = torch.mean((mean_teacher_mixed_up - model_mixed_up_output) ** 2) | |||
return loss | |||
def total_loss(labeled_out: torch.Tensor, | |||
ground_truth: torch.Tensor, | |||
unlabeled_mean_out: torch.Tensor, | |||
unlabeled_model_out: torch.Tensor, | |||
epoch: int) -> torch.Tensor: | |||
""" | |||
Computes the total loss consisting of both supervised and unsupervised losses. | |||
    :param labeled_out: probability maps predicted for the labeled dataset
    :param ground_truth: ground truth belonging to the labeled dataset
    :param unlabeled_mean_out: mean-teacher probability maps for the unlabeled dataset
    :param unlabeled_model_out: main-model probability maps for the unlabeled dataset
    :param epoch: current epoch of training
:return: total loss | |||
""" | |||
sup_loss = supervised_loss(labeled_out, ground_truth) | |||
if unlabeled_mean_out is None: | |||
unsup_loss = 0 | |||
else: | |||
unsup_loss = unsupervised_loss(unlabeled_mean_out, unlabeled_model_out) | |||
weight = unsupervised_weight(epoch) | |||
loss = sup_loss * (0.8 - weight) + unsup_loss * (0.2 + weight) | |||
return loss | |||
def unsupervised_weight(epoch: int, max_epoch=200):
    """
    Computes the trade-off weight lambda for the unsupervised loss
    :param epoch: current epoch
    :param max_epoch: maximum number of epochs to run
    :return: trade-off weight lambda (a linear ramp up to 0.6)
    """
return 0.6 * (epoch / max_epoch) |
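# --- usage sketch (hypothetical, not part of the original file) ---
# With the linear ramp above, total_loss shifts its mix from 0.8 supervised /
# 0.2 unsupervised at epoch 0 toward 0.2 / 0.8 at epoch 200.
if __name__ == "__main__":
    for epoch in (0, 50, 100, 200):
        w = unsupervised_weight(epoch)
        print(epoch, round(0.8 - w, 2), round(0.2 + w, 2))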
@@ -0,0 +1,157 @@ | |||
from typing import Tuple, List | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torchvision import transforms | |||
class ConvolutionBlock(nn.Module): | |||
""" | |||
A block with two convolutional layers followed by batch norm | |||
layers and LeakyReLU as activation function. | |||
""" | |||
def __init__(self, input_channel: int, output_channel: int): | |||
super(ConvolutionBlock, self).__init__() | |||
self.conv = nn.Sequential( | |||
nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(output_channel), | |||
nn.LeakyReLU(), | |||
nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(output_channel), | |||
nn.LeakyReLU() | |||
) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
return self.conv(x) | |||
class DownSamplingBlock(nn.Module): | |||
""" | |||
A down-sampling block with a max pooling layer followed by | |||
a convolutional block. | |||
""" | |||
def __init__(self, input_channel: int, output_channel: int): | |||
super(DownSamplingBlock, self).__init__() | |||
self.down_sampler = nn.Sequential( | |||
nn.MaxPool2d(2), | |||
ConvolutionBlock(input_channel, output_channel) | |||
) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
return self.down_sampler(x) | |||
class UpSamplingBlock(nn.Module): | |||
""" | |||
An up-sampling block with a 2d transpose convolution layer | |||
for up-sampling followed by a ConvolutionBlock. | |||
""" | |||
def __init__(self, input_channel1, input_channel2, output_channel): | |||
super(UpSamplingBlock, self).__init__() | |||
self.up = nn.ConvTranspose2d(input_channel1, input_channel2, kernel_size=2, stride=2) | |||
self.conv = ConvolutionBlock(input_channel2 * 2, output_channel) | |||
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: | |||
x1 = self.up(x1) | |||
height = x1.size()[2] # Finding dimensions of the up-sampled image for cropping | |||
width = x1.size()[3] | |||
x2 = transforms.CenterCrop((height, width))(x2) # Cropping the image from the skip connection from the center | |||
x = torch.cat([x2, x1], dim=1) # Concatenating both images together | |||
return self.conv(x) | |||
class DownSampler(nn.Module): | |||
""" | |||
    The module used for the down-sampling path, consisting of an
    initial convolutional block followed by multiple down-sampling
    blocks.
""" | |||
def __init__(self, input_channel: int, scale_channels: list): | |||
super(DownSampler, self).__init__() | |||
self.input_channel = input_channel | |||
self.scale_channels = scale_channels | |||
self.conv = ConvolutionBlock(self.input_channel, self.scale_channels[0]) | |||
self.down_sampler1 = DownSamplingBlock(self.scale_channels[0], self.scale_channels[1]) | |||
self.down_sampler2 = DownSamplingBlock(self.scale_channels[1], self.scale_channels[2]) | |||
self.down_sampler3 = DownSamplingBlock(self.scale_channels[2], self.scale_channels[3]) | |||
self.down_sampler4 = DownSamplingBlock(self.scale_channels[3], self.scale_channels[4]) | |||
def forward(self, x: torch.Tensor) -> List: | |||
features = self.conv(x) | |||
down_sampled1 = self.down_sampler1(features) | |||
down_sampled2 = self.down_sampler2(down_sampled1) | |||
down_sampled3 = self.down_sampler3(down_sampled2) | |||
down_sampled4 = self.down_sampler4(down_sampled3) | |||
return [features, down_sampled1, down_sampled2, down_sampled3, down_sampled4] | |||
class UpSampler(nn.Module): | |||
""" | |||
    The module used for the up-sampling path, consisting of multiple
    up-sampling blocks followed by a final convolutional layer that
    produces the probability map.
""" | |||
def __init__(self, scale_channels: list, class_num: int): | |||
super(UpSampler, self).__init__() | |||
self.scale_channels = scale_channels | |||
self.class_num = class_num | |||
self.up_sampler4 = UpSamplingBlock(scale_channels[4], scale_channels[3], scale_channels[3]) | |||
self.up_sampler3 = UpSamplingBlock(scale_channels[3], scale_channels[2], scale_channels[2]) | |||
self.up_sampler2 = UpSamplingBlock(scale_channels[2], scale_channels[1], scale_channels[1]) | |||
self.up_sampler1 = UpSamplingBlock(scale_channels[1], scale_channels[0], scale_channels[0]) | |||
self.probability_map_conv = nn.Sequential( | |||
nn.Conv2d(self.scale_channels[0], self.class_num, kernel_size=1), | |||
nn.Softmax(dim=1) | |||
) | |||
    def forward(self, down_outputs: list) -> torch.Tensor:
        """
        NOTE: The number suffixed to each variable name indicates its scale.
        """
        down_outputs0, down_outputs1, down_outputs2, down_outputs3, down_outputs4 = down_outputs
up_sampled4 = self.up_sampler4(down_outputs4, down_outputs3) | |||
up_sampled3 = self.up_sampler3(up_sampled4, down_outputs2) | |||
up_sampled2 = self.up_sampler2(up_sampled3, down_outputs1) | |||
up_sampled1 = self.up_sampler1(up_sampled2, down_outputs0) | |||
output = self.probability_map_conv(up_sampled1) | |||
return output | |||
class UNet(nn.Module): | |||
""" | |||
Implementation of UNet architecture. | |||
""" | |||
    def __init__(self, input_channel: int, scale_channels: list, class_num: int, ema_flag: bool):
super(UNet, self).__init__() | |||
self.down_sampler = DownSampler(input_channel, scale_channels) | |||
self.up_sampler = UpSampler(scale_channels, class_num) | |||
        if ema_flag:
            # EMA (mean-teacher) copy: params are detached and updated externally as a moving average, not by backprop
            for param in self.down_sampler.parameters():
param.detach_() | |||
for param in self.up_sampler.parameters(): | |||
param.detach_() | |||
    def forward(self, x: torch.Tensor) -> torch.Tensor:
down_sampled_x = self.down_sampler(x) | |||
up_sampled_x = self.up_sampler(down_sampled_x) | |||
return up_sampled_x |
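# Smoke-test sketch (not part of the original file): the five channel widths
# below are made-up small values; any monotone list of five works, and the
# 64x64 input is just large enough to survive the four pooling stages.
if __name__ == "__main__":
    net = UNet(input_channel=1, scale_channels=[8, 16, 32, 64, 128],
               class_num=2, ema_flag=False)
    probs = net(torch.randn(1, 1, 64, 64))
    print(probs.shape)  # torch.Size([1, 2, 64, 64]); softmax over dim=1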
@@ -0,0 +1,154 @@ | |||
import torch | |||
import numpy as np | |||
from typing import Tuple | |||
from torch.nn import BCELoss, KLDivLoss | |||
from .....utils.dice import dice_loss | |||
def single_supervised_loss(probability_map: torch.Tensor, ground_truth: torch.Tensor) -> torch.Tensor: | |||
""" | |||
Computes the supervised loss of a single scale of the model. | |||
    The loss is the average of the dice loss and the binary cross-entropy loss.
:param probability_map: probability map output of the segmentation model | |||
:param ground_truth: ground truth of the processed input | |||
:return: overall loss of the supervised data | |||
""" | |||
binary_entropy_loss = BCELoss() | |||
loss = (binary_entropy_loss(probability_map, ground_truth) + | |||
dice_loss(probability_map.unsqueeze(dim=0), ground_truth.long())) / 2 | |||
return loss | |||
def supervised_loss(probability_map: list, ground_truth: torch.Tensor) -> torch.Tensor: | |||
""" | |||
Computes the total supervised loss of the batch. | |||
    The loss is the average of the dice loss and the binary cross-entropy loss.
:param probability_map: probability map outputs of the segmentation model from different scales | |||
:param ground_truth: ground truth of the processed input | |||
:return: mean loss of the supervised data over all scales | |||
""" | |||
    probability_map0, probability_map1, probability_map2, probability_map3 = probability_map
total_loss = (single_supervised_loss(probability_map0, ground_truth) + | |||
single_supervised_loss(probability_map1, ground_truth) + | |||
single_supervised_loss(probability_map2, ground_truth) + | |||
single_supervised_loss(probability_map3, ground_truth)) / 4 | |||
return total_loss | |||
def uncertainty_minimization(probability_map: list, mean_prob_map: torch.Tensor) -> Tuple[torch.Tensor, Tuple]: | |||
""" | |||
Computes the uncertainty minimization term for the unsupervised loss. | |||
:param mean_prob_map: mean probability map over all the scales | |||
:param probability_map: probability map outputs of the segmentation model from different scales | |||
:return: uncertainty minimization term and a tuple of the exponential weights | |||
""" | |||
    kld_loss = KLDivLoss(reduction='none')
    kl_losses, exp_weights = [], []
    for prob_map in probability_map:
        # pixel-wise KL divergence between each scale and the mean prediction
        kl_loss = torch.sum(kld_loss(torch.log(prob_map), mean_prob_map),
                            dim=1, keepdim=True)
        kl_losses.append(kl_loss)
        exp_weights.append(torch.exp(-kl_loss))
    UM = sum(torch.mean(kl) for kl in kl_losses) / len(kl_losses)
    return UM, tuple(exp_weights)
def uncertainty_rectification(probability_map: list, mean_prob_map: torch.Tensor, | |||
weights: tuple) -> torch.Tensor: | |||
""" | |||
Computes the uncertainty rectification term for the unsupervised loss. | |||
:param weights: weights computed by the kl divergences from the UM term | |||
:param mean_prob_map: mean probability map over all the scales | |||
:param probability_map: probability map outputs of the segmentation model from different scales | |||
:return: uncertainty rectification term | |||
""" | |||
    mse_losses = []
    for prob_map, weight in zip(probability_map, weights):
        # weighted MSE between each scale and the mean prediction,
        # normalized by the mean of the weights
        mse_losses.append(torch.mean(((prob_map - mean_prob_map) ** 2) * weight) /
                          torch.mean(weight))
    UR = sum(mse_losses) / len(mse_losses)
    return UR
def unsupervised_loss(probability_map: list) -> torch.Tensor: | |||
""" | |||
Computes the total unsupervised loss for the batch. | |||
:param probability_map: probability map outputs of the segmentation model from different scales | |||
:return: mean loss of the unsupervised data over all scales | |||
""" | |||
mean_prob_map = (probability_map[0] + probability_map[1] + probability_map[2] + probability_map[3]) / 4 | |||
UM, weights = uncertainty_minimization(probability_map, mean_prob_map) | |||
UR = uncertainty_rectification(probability_map, mean_prob_map, weights) | |||
loss = UM + UR | |||
return loss | |||
def total_loss(x: list, ground_truth: torch.Tensor, region_mask: torch.Tensor, epoch: int) -> torch.Tensor:
    """
    Computes the loss over the batch. The supervised term is currently
    disabled, so only the unsupervised consistency loss over the
    unlabeled regions is returned.
    :param x: probability map outputs of the segmentation model from different scales
    :param ground_truth: ground truth belonging to the labeled dataset
    :param region_mask: boolean mask specifying the supervised regions
    :param epoch: current epoch of training
    :return: total loss
    """
    scales_outputs_sup, scales_outputs_unsup = [], []
    for scale_output in x:
        scales_outputs_sup.append(scale_output * region_mask.to(torch.float64))
        # negate the boolean mask *before* casting, since bitwise-not is
        # undefined for floating-point tensors
        scales_outputs_unsup.append(scale_output * (~region_mask).to(torch.float64))
    unsup_loss = unsupervised_loss(scales_outputs_unsup)
    return unsup_loss
def unsupervised_weight(epoch: int, max_epoch=200): | |||
""" | |||
Computes trade-off weight lambda for the unsupervised loss | |||
:param epoch: current epoch | |||
:param max_epoch: maximum number of epochs to run | |||
:return: trade-off weight lambda | |||
""" | |||
return 0.6 * (epoch / max_epoch) |
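# Illustrative sketch (not in the original file): the pyramid-consistency loss
# on four random single-channel "probability" maps; shapes and the clamp
# bounds are demo assumptions.
if __name__ == "__main__":
    maps = [torch.rand(2, 1, 32, 32).clamp(1e-6, 1.0) for _ in range(4)]
    print(unsupervised_loss(maps).item())  # scalar: UM + UR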
@@ -0,0 +1,351 @@ | |||
import os
import numpy as np
import torch.nn
from typing import Dict, List, Any, Callable, Optional
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet18
from torchvision.utils import draw_segmentation_masks
from PIL import Image
from mlassistant.core import Model
from mlassistant.context.dataloader import DataloaderContext
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss
from ....segmentation.submodules.res_enc_block import ResEncoderBlock | |||
from .....enums.roi.roi_aggregation_type import ROIAggregation | |||
from ....full_segmentation.utils import sum_roi_4ch_to_1ch | |||
from ....segmentation.utils import cal_patch_pixel_labels_by_patch_label_and_roi, \ | |||
create_patches_label_based_on_roi, AssembleDisassemble, \ | |||
reduce_resolution, select_non_boundaries | |||
from ....segmentation.losses import cal_focal_loss, cal_weighted_focal_loss | |||
class ResDecoderBlock(torch.nn.Module): | |||
def __init__(self, inp_channel, enc_inp_channel, out_channel, n_conv_blocks=2, | |||
conv_block=torch.nn.Conv2d, conv_block_kws={'kernel_size': 3, 'padding': 1}, | |||
activation=torch.nn.ReLU): | |||
super(ResDecoderBlock, self).__init__() | |||
self.up_sample = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) | |||
self.skip_connection = conv_block(inp_channel + enc_inp_channel, out_channel, **conv_block_kws) | |||
modules = [ | |||
torch.nn.BatchNorm2d(inp_channel + enc_inp_channel), | |||
            activation(),  # pre-activation over the concatenated (decoder + skip) input
conv_block(inp_channel + enc_inp_channel, out_channel, **conv_block_kws)] | |||
for i in range(1, n_conv_blocks): | |||
modules += [ | |||
torch.nn.BatchNorm2d(out_channel), | |||
activation(), | |||
conv_block(out_channel, out_channel, **conv_block_kws)] | |||
self.seq = torch.nn.Sequential(*modules) | |||
def forward(self, x, x_enc): | |||
out = self.up_sample(x) | |||
out = torch.cat([out, x_enc], dim=1) | |||
return self.skip_connection(out) + self.seq(out) | |||
class ResUNet(Model): | |||
def __init__(self, | |||
seg_num_labels: int, | |||
concat_original_image: bool, | |||
encoders_conv_kws: Optional[List[Dict[str, Any]]], | |||
bridge_conv_kws: Optional[Dict[str, Any]], | |||
decoders_conv_kws: List[Dict[str, Any]], | |||
roi_agg_type: ROIAggregation = ROIAggregation.mass_vs_others, | |||
remove_boundaries: bool = False, | |||
pretrained_resnet18: bool = False, | |||
loss=(lambda log_pred, gt: cal_weighted_focal_loss(log_pred, gt, 4)), | |||
use_dropout: bool = False) -> None: | |||
super(ResUNet, self).__init__() | |||
assert encoders_conv_kws is not None or pretrained_resnet18, \ | |||
"for creating encoder part, either of number of channels or using pretrained version of resnet18 should be determined" | |||
assert not concat_original_image or not pretrained_resnet18, \ | |||
"concatenating is currently not supported when using pretrained version" | |||
num_encoders = 6 if pretrained_resnet18 else len(encoders_conv_kws) | |||
self.halving_depth = num_encoders - len(decoders_conv_kws) + 1 | |||
self._roi_agg_type: ROIAggregation = roi_agg_type | |||
self._remove_boundaries: bool = remove_boundaries | |||
self._loss = loss | |||
self._concat_original_image: bool = concat_original_image | |||
self.encoders: nn.ModuleList = None | |||
self.bridge: nn.Sequential = None | |||
self._create_encoder(pretrained_resnet18, encoders_conv_kws) | |||
self._create_bridge(bridge_conv_kws) | |||
self.decoders = nn.ModuleList([ResDecoderBlock(**kwargs) | |||
for kwargs in decoders_conv_kws]) | |||
last_channel = decoders_conv_kws[-1]['out_channel'] | |||
if not use_dropout: | |||
self.decider = nn.Sequential( | |||
nn.Conv2d(last_channel, seg_num_labels, kernel_size=1), | |||
nn.LogSoftmax(dim=1)) | |||
else: | |||
self.decider = nn.Sequential( | |||
nn.Dropout(0.3), | |||
nn.Conv2d(last_channel, seg_num_labels, kernel_size=1), | |||
nn.LogSoftmax(dim=1)) | |||
def _create_encoder(self, | |||
pretrained: bool, | |||
encoders_conv_kws: Optional[List[Dict[str, Any]]]) -> None: | |||
if pretrained: | |||
resnet = resnet18(pretrained=True) | |||
self._inp_channel = 3 | |||
self.encoders = nn.ModuleList([]) | |||
self.encoders.append( | |||
nn.Sequential( | |||
resnet.conv1, | |||
resnet.bn1, | |||
resnet.relu, | |||
resnet.maxpool)) | |||
for l in [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]: | |||
self.encoders.append(l) | |||
else: | |||
self._inp_channel = encoders_conv_kws[0]['inp_channel'] | |||
self.encoders = nn.ModuleList([ResEncoderBlock(**encoders_conv_kws[0])]) | |||
encoders_conv_kws.pop(0) | |||
self.encoders += nn.ModuleList([nn.Sequential( | |||
nn.AvgPool2d(kernel_size=2, stride=2), | |||
ResEncoderBlock(**kwargs) | |||
) | |||
for kwargs in encoders_conv_kws]) | |||
def _create_bridge(self, | |||
                       bridge_conv_kws: Optional[Dict[str, Any]]) -> None:
        if bridge_conv_kws is None:
            self.bridge = nn.AvgPool2d(kernel_size=2, stride=2)
else: | |||
self.bridge = nn.Sequential( | |||
nn.AvgPool2d(kernel_size=2, stride=2), | |||
ResEncoderBlock(**bridge_conv_kws)) | |||
def _forward(self, x): | |||
x = x.repeat_interleave(self._inp_channel, 1) | |||
encoders_outs = [] | |||
out = x | |||
x_ = x | |||
for enc in self.encoders: | |||
out = enc(out) | |||
encoders_outs.append(out) | |||
if self._concat_original_image: | |||
x_ = F.interpolate(x_, out.shape[2:], mode='bilinear', align_corners=True) | |||
out = torch.cat([out, x_], dim=1) | |||
out = self.bridge(out) | |||
for dec in self.decoders: | |||
out = dec(out, encoders_outs.pop(-1)) | |||
return self.decider(out) | |||
class ResunetSegmentor(ResUNet): | |||
def __init__(self, | |||
seg_num_labels: int, | |||
concat_original_image: bool, | |||
encoders_conv_kws: List[Dict[str, Any]], | |||
bridge_conv_kws: Dict[str, Any], | |||
                 decoders_conv_kws: List[Dict[str, Any]],
                 log_transform=False,
                 is_base=True,
                 base_saving_dir=None,
                 gamma=2):
        """
        Args:
            seg_num_labels (int): number of segmentation labels, e.g. 2 (background vs nipple)
            concat_original_image (bool): whether to concatenate the resized input image to each encoder output
            encoders_conv_kws (List[Dict[str, Any]]): keyword arguments for the encoder blocks
            bridge_conv_kws (Dict[str, Any]): keyword arguments for the bridge block
            decoders_conv_kws (List[Dict[str, Any]]): keyword arguments for the decoder blocks
        """
self.is_base = is_base | |||
self.batch_count = 0 | |||
super(ResunetSegmentor, self).__init__( | |||
seg_num_labels, | |||
concat_original_image, | |||
encoders_conv_kws, | |||
bridge_conv_kws, | |||
decoders_conv_kws, | |||
) | |||
self.seg_num_labels = seg_num_labels | |||
self.log_transform = log_transform | |||
self.dropout_2 = torch.nn.Dropout(0.4) | |||
self.dropout_1 = torch.nn.Dropout(0.2) | |||
last_channel = decoders_conv_kws[-1]["out_channel"] | |||
self.decider = torch.nn.Sequential( | |||
torch.nn.Conv2d(last_channel, seg_num_labels, kernel_size=1), | |||
torch.nn.LogSoftmax(dim=1), | |||
) | |||
self.gamma = gamma | |||
self.base_model = None | |||
self.base_saving_dir = base_saving_dir | |||
self.debug = self.base_saving_dir is not None | |||
if self.debug: | |||
os.makedirs(self.base_saving_dir , exist_ok=True) | |||
    def set_base_model(self, base_model: Model):
self.base_model = base_model | |||
return | |||
def _forward(self, x): | |||
encoders_outs = [] | |||
scales_outs = [] | |||
out = x | |||
x_ = x | |||
# print("encoder output shapes:") | |||
for enc in self.encoders: | |||
out = enc(out) | |||
# print(out.shape) | |||
encoders_outs.append(out) | |||
out = self.dropout_1(out) | |||
if self._concat_original_image: | |||
x_ = F.interpolate(x_, out.shape[2:], mode='bilinear', align_corners=True) | |||
out = torch.cat([out, x_], dim=1) | |||
final_encoder_out = out | |||
final_encoder_out = final_encoder_out.flatten(2) | |||
out = self.dropout_2(out) | |||
out = self.bridge(out) | |||
# print("decoder output shapes:") | |||
for dec in self.decoders: | |||
# print(out.shape) | |||
encoder_out = encoders_outs.pop(-1) | |||
out = dec(out, encoder_out) | |||
            # intermediate decoder outputs are upsampled to the fixed 512x512 training resolution
            scales_outs.append(F.interpolate(out, (512, 512)))
out = self.dropout_1(out) | |||
scales_outs[-1] = self.decider(out) | |||
return scales_outs | |||
    def draw_segmentations(self, mammo_x, pixel_probs, ground_truths, prefix=''):
        names = DataloaderContext.instance.dataloader.get_current_batch_samples_names()
        for image, prediction, ground_truth, name in zip(mammo_x, pixel_probs, ground_truths, names):
            image = (image * 255).to(torch.uint8)
            if len(image.shape) < 3:
                image = image.unsqueeze(0)
            image = torch.repeat_interleave(image, 3, dim=0).cpu()  # grayscale to RGB
            new_prediction = (prediction.clone().detach() > 0.5).bool().unsqueeze(0)
            label = (ground_truth.clone().detach() > 0.5).bool().unsqueeze(0)
            name = name.replace('/', '_')
            alpha = 0.4
            # ground-truth overlay (green)
            final_image = draw_segmentation_masks(image, label.cpu(), alpha, colors=['green']).permute(1, 2, 0)
            Image.fromarray(final_image.cpu().numpy()).save(
                os.path.join(self.base_saving_dir, f"{name}_{prefix}_label.png"))
            # prediction overlay (red)
            final_image = draw_segmentation_masks(image, new_prediction.cpu(), alpha, colors=['red']).permute(1, 2, 0)
            Image.fromarray(final_image.cpu().numpy()).save(
                os.path.join(self.base_saving_dir, f"{name}_{prefix}_pred.png"))
        return
    def draw_raw_segmentations(self, mammo_x, pixel_probs, ground_truths, prefix=""):
        names = DataloaderContext.instance.dataloader.get_current_batch_samples_names()
        for image, prediction, ground_truth, name in zip(mammo_x, pixel_probs, ground_truths, names):
            name = name.replace('/', '_')
            if len(prefix) > 0:
                name = f"{name}_{prefix}"
            # input image
            image = (image * 255).to(torch.uint8)
            Image.fromarray(image.squeeze().cpu().numpy()).save(f"{self.base_saving_dir}/{name}.png")
            # prediction
            prediction = (prediction.clone().detach().squeeze() * 255).to(torch.uint8)
            Image.fromarray(prediction.cpu().numpy()).save(f"{self.base_saving_dir}/{name}_prediction.png")
            # label
            label = (ground_truth.clone().detach().squeeze() * 255).to(torch.uint8)
            Image.fromarray(label.cpu().numpy()).save(f"{self.base_saving_dir}/{name}_label.png")
        return
    def draw_error_mask(self, mammo_x, probabilities, ground_truth):
        names = DataloaderContext.instance.dataloader.get_current_batch_samples_names()
        error_mask = torch.zeros(size=(len(mammo_x), 3, mammo_x.shape[-2], mammo_x.shape[-1]))
        prob_channels = probabilities.shape[1]
        # detach and move to CPU so the assignment into the CPU buffer works
        # regardless of the inputs' device and autograd state
        error_mask[:, 0:prob_channels] = (ground_truth - probabilities).detach().cpu()
        error_mask[error_mask < 0] = 0
        error_mask *= 255
        error_mask = error_mask.to(torch.uint8).cpu().numpy()
        for error_image, image, name in zip(error_mask, mammo_x, names):
            image = (image * 255).to(torch.uint8).squeeze()
            name = name.replace('/', '_')
            Image.fromarray(image.cpu().numpy()).save(
                os.path.join(self.base_saving_dir, f"{name}.png"))
            Image.fromarray(error_image.transpose(1, 2, 0)).save(
                os.path.join(self.base_saving_dir, f"{name}_error_mask.png"))
        return
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_breast_mask: torch.Tensor): | |||
batch_size = len(mammo_x) | |||
predicted_mask = None | |||
chosen_samples = [0 for i in range(batch_size)] | |||
        if not self.is_base:
            predicted_mask = torch.exp(self.base_model._forward(mammo_x)).to(mammo_x.device)[:, 1]
            predicted_mask[predicted_mask >= 0.5] = 1
            predicted_mask[predicted_mask < 0.5] = 0
            intersection = predicted_mask.squeeze() * mammo_breast_mask.squeeze()
            union = ((predicted_mask.squeeze() + mammo_breast_mask.squeeze()) > 0.5).to(torch.int)
            for i in range(batch_size):
                # keep samples whose base-model prediction overlaps the breast mask (IoU > 0.75)
                if (intersection[i].sum() / union[i].sum()) > 0.75:
                    chosen_samples[i] = 1
        log_pixel_probs = self._forward(mammo_x)
        # convert each scale's log-probabilities to probabilities in place
        for i in range(len(log_pixel_probs)):
            log_pixel_probs[i] = torch.exp(log_pixel_probs[i]).to(mammo_x.device)
        return log_pixel_probs
def apply_log_transform(self,pixel_probs): | |||
pixel_probs = pixel_probs/pixel_probs.max() | |||
pixel_probs = (pixel_probs * 255).to(torch.int16) | |||
for i in range(0, len(pixel_probs)): | |||
c = (255 / torch.log(1 + torch.max(pixel_probs[i][1]))).to(pixel_probs.device) | |||
pixel_probs[i][1] = c * (torch.log(pixel_probs[i][1] + 1)) | |||
pixel_probs[i][0] = 255 - pixel_probs[i][1] | |||
pixel_probs = pixel_probs.to(torch.float32) | |||
pixel_probs = pixel_probs/pixel_probs.max() | |||
return pixel_probs | |||
    def refine_segmentation_using_clustering(self, segmentation, image):
        # NOTE: relies on self.clusterer (an sklearn-style clusterer such as
        # KMeans), which must be assigned externally; it is not created in __init__.
        y, x = torch.where(segmentation >= 0.5)
        pixels_values = image[segmentation >= 0.5].tolist()
        median = np.median(np.array(pixels_values))
        std = np.std(np.array(pixels_values))
        # append two sentinel intensities so the clusterer always sees the extremes
        pixels_values += [median - 3 * std, median + 3 * std]
        self.clusterer.fit(np.array(pixels_values).reshape(-1, 1))
        pixel_labels = self.clusterer.labels_
        true_positives_label = pixel_labels[-1]  # label of the pixels close to 1.0
        false_positive_indices = [i for i in range(len(pixel_labels)) if pixel_labels[i] != true_positives_label]
        false_positive_indices = torch.tensor(false_positive_indices[:-1]).to(image.device)  # drop the appended sentinel
        if len(false_positive_indices) > 0:
            false_positive_x = x[false_positive_indices]
            false_positive_y = y[false_positive_indices]
            segmentation[false_positive_y, false_positive_x] = 0
        return segmentation
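# Shape sketch for ResDecoderBlock (illustrative, not part of the original
# file): a 512-channel bottleneck map is upsampled 2x and fused with a
# 256-channel skip connection; all sizes here are demo assumptions.
if __name__ == "__main__":
    dec = ResDecoderBlock(inp_channel=512, enc_inp_channel=256, out_channel=256)
    out = dec(torch.randn(1, 512, 16, 16), torch.randn(1, 256, 32, 32))
    print(out.shape)  # torch.Size([1, 256, 32, 32])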
@@ -0,0 +1,80 @@ | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from .loss import unsupervised_loss | |||
from .utils import UpSampler, DownSampler | |||
from mlassistant.core import ModelIO, Model | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
class URPC(Model): | |||
""" | |||
    The main design for the URPC architecture with UNet as the
    base network.
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.scale_channels = [64, 128, 128, 256, 512] # Filter numbers for each scale | |||
self.shape = (512, 512) # Change the probability map shape according to the mammo scans shape | |||
self.input_channels = 1 | |||
self.class_num = 2 | |||
        self.iter_counter = 13400  # non-zero start, presumably to resume the ramp-up schedule from a checkpoint
self.down_sampler = DownSampler(self.input_channels, self.scale_channels) | |||
self.up_sampler = UpSampler(self.scale_channels, self.class_num, self.shape) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
self.iter_counter += 1 | |||
down_sampled_x = self.down_sampler(mammo_x) | |||
out0, out1, out2, out3 = self.up_sampler(down_sampled_x) | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out0.shape | |||
        loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0], out_prob_map_shape[1], main_shape[1], main_shape[2])
        ground_truth = ground_truth.unsqueeze(dim=1)
        mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)
        dummy_mask = mask_channels.clone().to(mammo_x.device)
        dummy_net_out = out0.clone().to(mammo_x.device)
        dummy_mask[loss_channels == 0] = 0.0
        dummy_shape = dummy_mask.shape
        # an extra all-ones channel lets the semi-supervised focal loss ignore
        # the masked-out (unlabeled) pixels; height and width are indexed
        # separately so non-square inputs also work
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_yhat = out0 * loss_channels
        dummy_y = mask_channels * loss_channels
        focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 2)
        dice_ls = dice_loss(dummy_net_out, dummy_mask)
        ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
        # ramp the unsupervised weight up with the iteration count
        weight = 0.6 * ((self.iter_counter // 200) / 400)
unsup_loss = unsupervised_loss([out0[:, [1], ...], out1[:, [1], ...], | |||
out2[:, [1], ...], out3[:, [1], ...]]) | |||
output = { | |||
'pixel_probs': out0, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': focal_loss * (0.8 - weight) + unsup_loss * (0.2 + weight), | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice_ls, | |||
} | |||
return output | |||
# python main.py semi_supervised.segmentation.entrypoint_urpc -phase train -save_dir /home/dadbeh/Results -device cuda:2 -console_file false | |||
# python metric.py -L ../Results/TransUNetTopK/log.csv -M train_Seg_Sens_0 val_Seg_Sens_0 -M train_Seg_Sens_1 val_Seg_Sens_1 -M train_Seg_AvgSens val_Seg_AvgSens -M train_Obj_FP/P-1 val_Obj_FP/P-1 -M train_Obj_Sen-T-1 val_Obj_Sen-T-1 -M train_ls_Loss val_ls_Loss -M train_ls_focal_loss val_ls_focal_loss -M train_ls_ce_loss train_ls_ce_loss -M train_ls_dice_loss val_ls_dice_loss -M train_Obj_Prec-T-1 val_Obj_Prec-T-1 | |||
# python main.py semi_supervised.segmentation.entrypoint_urpc -phase saveModelOutput -samples_dir liveE15 -device cuda:2 -save_dir /home/dadbeh/Results -load_dir /home/dadbeh/TempResults -epoch 113 -console_file false | |||
# python main.py semi_supervised.segmentation.entrypoint_urpc -phase saveModelOutput -samples_dir liveE15 -device cuda:2 -save_dir /home/dadbeh/Results/TopKFocalCEWeightResults -load_dir /home/dadbeh/Results/TopKFocalCEWeightResults -epoch 37 -console_file false |
@@ -0,0 +1,92 @@ | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from .loss import unsupervised_loss | |||
from .resunet import ResunetSegmentor | |||
from mlassistant.core import ModelIO, Model | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
class URPC(Model): | |||
""" | |||
    The main design for the URPC architecture with a residual UNet
    (ResunetSegmentor) as the base network.
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.shape = (512, 512) # Change the probability map shape according to the mammo scans shape | |||
self.input_channels = 1 | |||
self.class_num = 2 | |||
self.iter_counter = 0 | |||
encoders_conv_kws = [ | |||
{"inp_channel": 1, "out_channel": 32, "n_conv_blocks": 2}, | |||
{"inp_channel": 33, "out_channel": 64, "n_conv_blocks": 2}, | |||
{"inp_channel": 65, "out_channel": 128, "n_conv_blocks": 3}, | |||
{"inp_channel": 129, "out_channel": 256, "n_conv_blocks": 3}, | |||
] | |||
bridge_conv_kws = {"inp_channel": 257, "out_channel": 512} | |||
decoders_conv_kws = [ | |||
{"inp_channel": 512, "enc_inp_channel": 256, "out_channel": 256}, | |||
{"inp_channel": 256, "enc_inp_channel": 128, "out_channel": 64}, | |||
{"inp_channel": 64, "enc_inp_channel": 64, "out_channel": 32}, | |||
{"inp_channel": 32, "enc_inp_channel": 32, "out_channel": 16}, | |||
] | |||
self.base_network = ResunetSegmentor( | |||
seg_num_labels=self.class_num, | |||
concat_original_image=True, | |||
encoders_conv_kws=encoders_conv_kws, | |||
decoders_conv_kws=decoders_conv_kws, | |||
bridge_conv_kws=bridge_conv_kws, | |||
) | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
self.iter_counter += 1 | |||
out3, out2, out1, out0 = self.base_network(mammo_x, None) | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out0.shape | |||
        loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0], out_prob_map_shape[1], main_shape[1], main_shape[2])
        ground_truth = ground_truth.unsqueeze(dim=1)
        mask_channels = torch.cat((1 - ground_truth, ground_truth), 1)
        dummy_mask = mask_channels.clone().to(mammo_x.device)
        dummy_net_out = out0.clone().to(mammo_x.device)
        dummy_mask[loss_channels == 0] = 0.0
        dummy_shape = dummy_mask.shape
        # an extra all-ones channel lets the semi-supervised focal loss ignore
        # the masked-out (unlabeled) pixels; height and width are indexed
        # separately so non-square inputs also work
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_yhat = out0 * loss_channels
        dummy_y = mask_channels * loss_channels
        focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4)
        dice_ls = dice_loss(dummy_net_out, dummy_mask)
        ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y)
        # ramp the unsupervised weight up with the iteration count
        weight = 0.6 * ((self.iter_counter // 200) / 200)
unsup_loss = unsupervised_loss([out0[:, [1], ...], out1[:, [1], ...], | |||
out2[:, [1], ...], out3[:, [1], ...]]) | |||
output = { | |||
'pixel_probs': out0, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': focal_loss * (0.8 - weight) + unsup_loss * (0.2 + weight), | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice_ls, | |||
} | |||
return output |
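# Worked illustration (not in the original file) of the label-masking scheme
# used in forward above: pixels whose loss_type is 0 are zeroed in every class
# channel, and the appended all-ones channel gives the semi-supervised focal
# loss a channel that is never masked out. All tensors are tiny demo values.
if __name__ == "__main__":
    loss_type = torch.tensor([[[1., 1.], [0., 0.]]])   # (B, H, W); bottom row unlabeled
    gt = torch.tensor([[[1., 0.], [1., 0.]]])          # (B, H, W)
    mask = torch.cat((1 - gt.unsqueeze(1), gt.unsqueeze(1)), 1)
    loss_channels = loss_type.unsqueeze(1).expand(1, 2, 2, 2)
    mask[loss_channels == 0] = 0.0
    mask = torch.cat((mask, torch.ones(1, 1, 2, 2)), 1)
    print(mask[0, :, 1, 0])  # unlabeled pixel: tensor([0., 0., 1.])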
@@ -0,0 +1,153 @@ | |||
from typing import Tuple, List | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torchvision import transforms | |||
class ConvolutionBlock(nn.Module): | |||
""" | |||
A block with two convolutional layers followed by batch norm | |||
layers and LeakyReLU as activation function. | |||
""" | |||
def __init__(self, input_channel: int, output_channel: int): | |||
super(ConvolutionBlock, self).__init__() | |||
self.conv = nn.Sequential( | |||
nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(output_channel), | |||
nn.LeakyReLU(), | |||
nn.Conv2d(output_channel, output_channel, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(output_channel), | |||
nn.LeakyReLU() | |||
) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
return self.conv(x) | |||
class DownSamplingBlock(nn.Module): | |||
""" | |||
A down-sampling block with a max pooling layer followed by | |||
a convolutional block. | |||
""" | |||
def __init__(self, input_channel: int, output_channel: int): | |||
super(DownSamplingBlock, self).__init__() | |||
self.down_sampler = nn.Sequential( | |||
nn.MaxPool2d(2), | |||
ConvolutionBlock(input_channel, output_channel) | |||
) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
return self.down_sampler(x) | |||
class UpSamplingBlock(nn.Module): | |||
""" | |||
An up-sampling block with a 2d transpose convolution layer | |||
for up-sampling followed by a ConvolutionBlock. | |||
""" | |||
def __init__(self, input_channel1, input_channel2, output_channel): | |||
super(UpSamplingBlock, self).__init__() | |||
self.up = nn.ConvTranspose2d(input_channel1, input_channel2, kernel_size=2, stride=2) | |||
self.conv = ConvolutionBlock(input_channel2 * 2, output_channel) | |||
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: | |||
x1 = self.up(x1) | |||
height = x1.size()[2] # Finding dimensions of the up-sampled image for cropping | |||
width = x1.size()[3] | |||
x2 = transforms.CenterCrop((height, width))(x2) # Cropping the image from the skip connection from the center | |||
x = torch.cat([x2, x1], dim=1) # Concatenating both images together | |||
return self.conv(x) | |||
class DownSampler(nn.Module): | |||
""" | |||
    The module used for the down-sampling path, consisting of an
    initial convolutional block followed by multiple down-sampling
    blocks.
""" | |||
def __init__(self, input_channel: int, scale_channels: list): | |||
super(DownSampler, self).__init__() | |||
self.input_channel = input_channel | |||
self.scale_channels = scale_channels | |||
self.conv = ConvolutionBlock(self.input_channel, self.scale_channels[0]) | |||
self.down_sampler1 = DownSamplingBlock(self.scale_channels[0], self.scale_channels[1]) | |||
self.down_sampler2 = DownSamplingBlock(self.scale_channels[1], self.scale_channels[2]) | |||
self.down_sampler3 = DownSamplingBlock(self.scale_channels[2], self.scale_channels[3]) | |||
self.down_sampler4 = DownSamplingBlock(self.scale_channels[3], self.scale_channels[4]) | |||
def forward(self, x: torch.Tensor) -> List: | |||
features = self.conv(x) | |||
down_sampled1 = self.down_sampler1(features) | |||
down_sampled2 = self.down_sampler2(down_sampled1) | |||
down_sampled3 = self.down_sampler3(down_sampled2) | |||
down_sampled4 = self.down_sampler4(down_sampled3) | |||
return [features, down_sampled1, down_sampled2, down_sampled3, down_sampled4] | |||
class UpSampler(nn.Module): | |||
""" | |||
    The module used for the up-sampling path, consisting of multiple
    up-sampling blocks followed by one convolutional layer per scale
    that produces that scale's probability map.
""" | |||
def __init__(self, scale_channels: list, class_num: int, shape: tuple): | |||
super(UpSampler, self).__init__() | |||
self.scale_channels = scale_channels | |||
self.class_num = class_num | |||
self.shape = shape | |||
self.up_sampler4 = UpSamplingBlock(scale_channels[4], scale_channels[3], scale_channels[3]) | |||
self.up_sampler3 = UpSamplingBlock(scale_channels[3], scale_channels[2], scale_channels[2]) | |||
self.up_sampler2 = UpSamplingBlock(scale_channels[2], scale_channels[1], scale_channels[1]) | |||
self.up_sampler1 = UpSamplingBlock(scale_channels[1], scale_channels[0], scale_channels[0]) | |||
self.probability_map_conv0 = nn.Sequential( | |||
nn.Conv2d(self.scale_channels[0], self.class_num, kernel_size=1) | |||
) | |||
self.probability_map_conv1 = nn.Sequential( | |||
nn.Conv2d(self.scale_channels[1], self.class_num, kernel_size=1) | |||
) | |||
self.probability_map_conv2 = nn.Sequential( | |||
nn.Conv2d(self.scale_channels[2], self.class_num, kernel_size=1) | |||
) | |||
self.probability_map_conv3 = nn.Sequential( | |||
nn.Conv2d(self.scale_channels[3], self.class_num, kernel_size=1) | |||
) | |||
    def forward(self, down_outputs: list) -> Tuple:
        """
        NOTE: The number suffixed to each variable name indicates its scale.
        """
        down_outputs0, down_outputs1, down_outputs2, down_outputs3, down_outputs4 = down_outputs
up_sampled4 = self.up_sampler4(down_outputs4, down_outputs3) | |||
        # NOTE: hard-coded for 512x512 inputs, where this scale's resolution is 64x64
        up_sampled4 = F.interpolate(up_sampled4, (64, 64))
output3 = self.probability_map_conv3(up_sampled4) | |||
interpolated_output3 = F.interpolate(output3, self.shape) | |||
up_sampled3 = self.up_sampler3(up_sampled4, down_outputs2) | |||
output2 = self.probability_map_conv2(up_sampled3) | |||
interpolated_output2 = F.interpolate(output2, self.shape) | |||
up_sampled2 = self.up_sampler2(up_sampled3, down_outputs1) | |||
output1 = self.probability_map_conv1(up_sampled2) | |||
interpolated_output1 = F.interpolate(output1, self.shape) | |||
up_sampled1 = self.up_sampler1(up_sampled2, down_outputs0) | |||
output0 = self.probability_map_conv0(up_sampled1) | |||
return torch.softmax(output0, dim=1), torch.softmax(interpolated_output1, dim=1), \ | |||
torch.softmax(interpolated_output2, dim=1), torch.softmax(interpolated_output3, dim=1) |
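# Shape check for the multi-scale head (illustrative, not in the original
# file). The 512x512 input matches the hard-coded 64x64 interpolation of the
# deepest scale; the channel widths are made-up small values.
if __name__ == "__main__":
    chans = [4, 8, 8, 16, 32]
    down = DownSampler(1, chans)
    up = UpSampler(chans, class_num=2, shape=(512, 512))
    outs = up(down(torch.randn(1, 1, 512, 512)))
    print([o.shape for o in outs])  # four maps, each torch.Size([1, 2, 512, 512])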
@@ -0,0 +1,111 @@ | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
import resnet
class DeepLabV3Plus(nn.Module): | |||
    def __init__(self, cfg):
        super(DeepLabV3Plus, self).__init__()
        # cfg is a dict that must provide at least 'nclass', the number of
        # output classes used by the classifier head below
        self.backbone = resnet.__dict__['resnet101'](pretrained=False)
low_channels = 256 | |||
high_channels = 2048 | |||
self.head = ASPPModule(high_channels, [6, 12, 18]) | |||
self.reduce = nn.Sequential(nn.Conv2d(low_channels, 48, 1, bias=False), | |||
nn.BatchNorm2d(48), | |||
nn.ReLU(True)) | |||
self.fuse = nn.Sequential(nn.Conv2d(high_channels // 8 + 48, 256, 3, padding=1, bias=False), | |||
nn.BatchNorm2d(256), | |||
nn.ReLU(True), | |||
nn.Conv2d(256, 256, 3, padding=1, bias=False), | |||
nn.BatchNorm2d(256), | |||
nn.ReLU(True)) | |||
self.classifier = nn.Conv2d(256, cfg['nclass'], 1, bias=True) | |||
def forward(self, x, need_fp=False): | |||
h, w = x.shape[-2:] | |||
feats = self.backbone.base_forward(x) | |||
c1, c4 = feats[0], feats[-1] | |||
if need_fp: | |||
outs = self._decode(torch.cat((c1, nn.Dropout2d(0.5)(c1))), | |||
torch.cat((c4, nn.Dropout2d(0.5)(c4)))) | |||
outs = F.interpolate(outs, size=(h, w), mode="bilinear", align_corners=True) | |||
out, out_fp = outs.chunk(2) | |||
return out, out_fp | |||
out = self._decode(c1, c4) | |||
out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True) | |||
return out | |||
def _decode(self, c1, c4): | |||
c4 = self.head(c4) | |||
c4 = F.interpolate(c4, size=c1.shape[-2:], mode="bilinear", align_corners=True) | |||
c1 = self.reduce(c1) | |||
feature = torch.cat([c1, c4], dim=1) | |||
feature = self.fuse(feature) | |||
out = self.classifier(feature) | |||
return out | |||
def ASPPConv(in_channels, out_channels, atrous_rate): | |||
block = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, | |||
dilation=atrous_rate, bias=False), | |||
nn.BatchNorm2d(out_channels), | |||
nn.ReLU(True)) | |||
return block | |||
class ASPPPooling(nn.Module): | |||
def __init__(self, in_channels, out_channels): | |||
super(ASPPPooling, self).__init__() | |||
self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1), | |||
nn.Conv2d(in_channels, out_channels, 1, bias=False), | |||
nn.BatchNorm2d(out_channels), | |||
nn.ReLU(True)) | |||
def forward(self, x): | |||
h, w = x.shape[-2:] | |||
pool = self.gap(x) | |||
return F.interpolate(pool, (h, w), mode="bilinear", align_corners=True) | |||
class ASPPModule(nn.Module): | |||
def __init__(self, in_channels, atrous_rates): | |||
super(ASPPModule, self).__init__() | |||
out_channels = in_channels // 8 | |||
rate1, rate2, rate3 = atrous_rates | |||
self.b0 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), | |||
nn.BatchNorm2d(out_channels), | |||
nn.ReLU(True)) | |||
self.b1 = ASPPConv(in_channels, out_channels, rate1) | |||
self.b2 = ASPPConv(in_channels, out_channels, rate2) | |||
self.b3 = ASPPConv(in_channels, out_channels, rate3) | |||
self.b4 = ASPPPooling(in_channels, out_channels) | |||
self.project = nn.Sequential(nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), | |||
nn.BatchNorm2d(out_channels), | |||
nn.ReLU(True)) | |||
def forward(self, x): | |||
feat0 = self.b0(x) | |||
feat1 = self.b1(x) | |||
feat2 = self.b2(x) | |||
feat3 = self.b3(x) | |||
feat4 = self.b4(x) | |||
y = torch.cat((feat0, feat1, feat2, feat3, feat4), 1) | |||
return self.project(y) |
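# Quick shape check for the ASPP head alone (illustrative, not in the original
# file); 2048 matches the high-level feature width of the resnet101 backbone.
if __name__ == "__main__":
    aspp = ASPPModule(2048, [6, 12, 18])
    y = aspp(torch.randn(1, 2048, 32, 32))
    print(y.shape)  # torch.Size([1, 256, 32, 32]) since out_channels = 2048 // 8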
@@ -0,0 +1,154 @@ | |||
import torch | |||
import torch.nn as nn | |||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | |||
padding=dilation, groups=groups, bias=False, dilation=dilation) | |||
def conv1x1(in_planes, out_planes, stride=1): | |||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
class Bottleneck(nn.Module): | |||
expansion = 4 | |||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | |||
base_width=64, dilation=1): | |||
super(Bottleneck, self).__init__() | |||
norm_layer = nn.BatchNorm2d | |||
width = int(planes * (base_width / 64.)) * groups | |||
self.conv1 = conv1x1(inplanes, width) | |||
self.bn1 = norm_layer(width) | |||
self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||
self.bn2 = norm_layer(width) | |||
self.conv3 = conv1x1(width, planes * self.expansion) | |||
self.bn3 = norm_layer(planes * self.expansion) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.downsample = downsample | |||
self.stride = stride | |||
def forward(self, x): | |||
identity = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
if self.downsample is not None: | |||
identity = self.downsample(x) | |||
out += identity | |||
out = self.relu(out) | |||
return out | |||
class ResNet(nn.Module): | |||
def __init__(self, block, layers, groups=1, width_per_group=64): | |||
super(ResNet, self).__init__() | |||
norm_layer = nn.BatchNorm2d | |||
self._norm_layer = norm_layer | |||
self.inplanes = 128 | |||
self.dilation = 1 | |||
        # output stride 16: only the last stage trades its stride for dilation
        replace_stride_with_dilation = [False, False, True]
self.groups = groups | |||
self.base_width = width_per_group | |||
self.conv1 = nn.Sequential( | |||
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False), | |||
norm_layer(64), | |||
nn.ReLU(inplace=True), | |||
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), | |||
norm_layer(64), | |||
nn.ReLU(inplace=True), | |||
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False), | |||
) | |||
self.bn1 = norm_layer(self.inplanes) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
self.layer1 = self._make_layer(block, 64, layers[0]) | |||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, | |||
dilate=replace_stride_with_dilation[0]) | |||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, | |||
dilate=replace_stride_with_dilation[1]) | |||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, | |||
dilate=replace_stride_with_dilation[2]) | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d): | |||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | |||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | |||
nn.init.constant_(m.weight, 1) | |||
nn.init.constant_(m.bias, 0) | |||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | |||
norm_layer = self._norm_layer | |||
downsample = None | |||
previous_dilation = self.dilation | |||
if dilate: | |||
self.dilation *= stride | |||
stride = 1 | |||
if stride != 1 or self.inplanes != planes * block.expansion: | |||
downsample = nn.Sequential( | |||
conv1x1(self.inplanes, planes * block.expansion, stride), | |||
norm_layer(planes * block.expansion), | |||
) | |||
        layers = list()
        # Bottleneck in this file hard-codes BatchNorm2d and accepts no
        # norm_layer argument, so none is passed here
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation))
return nn.Sequential(*layers) | |||
def base_forward(self, x): | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
x = self.maxpool(x) | |||
c1 = self.layer1(x) | |||
c2 = self.layer2(c1) | |||
c3 = self.layer3(c2) | |||
c4 = self.layer4(c3) | |||
return c1, c2, c3, c4 | |||
def _resnet(arch, block, layers, pretrained, **kwargs): | |||
model = ResNet(block, layers, **kwargs) | |||
if pretrained: | |||
pretrained_path = "pretrained/%s.pth" % arch | |||
state_dict = torch.load(pretrained_path) | |||
model.load_state_dict(state_dict, strict=False) | |||
return model | |||
def resnet50(pretrained=False, **kwargs): | |||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs) | |||
def resnet101(pretrained=False, **kwargs): | |||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, **kwargs) |
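# Feature-map sketch (illustrative, not part of the original file): with the
# deep stem and the dilated last stage, a 512x512 input yields strides
# 4/8/16/16 across the four returned feature maps.
if __name__ == "__main__":
    net = resnet50(pretrained=False)
    c1, c2, c3, c4 = net.base_forward(torch.randn(1, 3, 512, 512))
    print(c1.shape, c4.shape)  # (1, 256, 128, 128) and (1, 2048, 32, 32)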
@@ -0,0 +1,228 @@ | |||
import os
import numpy as np
import torch.nn
from typing import Dict, List, Any, Callable, Optional
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet18
from torchvision.utils import draw_segmentation_masks
from PIL import Image
from mlassistant.core import Model
from mlassistant.context.dataloader import DataloaderContext
from ......utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss
from .....segmentation.submodules.res_enc_block import ResEncoderBlock | |||
from .....segmentation.submodules.res_dec_block import ResDecoderBlock | |||
from ......enums.roi.roi_aggregation_type import ROIAggregation | |||
from .....full_segmentation.utils import sum_roi_4ch_to_1ch | |||
from .....segmentation.utils import cal_patch_pixel_labels_by_patch_label_and_roi, \ | |||
create_patches_label_based_on_roi, AssembleDisassemble, \ | |||
reduce_resolution, select_non_boundaries | |||
from .....segmentation.losses import cal_focal_loss, cal_weighted_focal_loss | |||
class ResUNet(Model): | |||
def __init__(self, | |||
seg_num_labels: int, | |||
concat_original_image: bool, | |||
encoders_conv_kws: Optional[List[Dict[str, Any]]], | |||
bridge_conv_kws: Optional[Dict[str, Any]], | |||
decoders_conv_kws: List[Dict[str, Any]], | |||
roi_agg_type: ROIAggregation = ROIAggregation.mass_vs_others, | |||
remove_boundaries: bool = False, | |||
pretrained_resnet18: bool = False, | |||
loss=(lambda log_pred, gt: cal_weighted_focal_loss(log_pred, gt, 4)), | |||
use_dropout: bool = False) -> None: | |||
super(ResUNet, self).__init__() | |||
assert encoders_conv_kws is not None or pretrained_resnet18, \ | |||
"for creating encoder part, either of number of channels or using pretrained version of resnet18 should be determined" | |||
assert not concat_original_image or not pretrained_resnet18, \ | |||
"concatenating is currently not supported when using pretrained version" | |||
num_encoders = 6 if pretrained_resnet18 else len(encoders_conv_kws) | |||
self.halving_depth = num_encoders - len(decoders_conv_kws) + 1 | |||
self._roi_agg_type: ROIAggregation = roi_agg_type | |||
self._remove_boundaries: bool = remove_boundaries | |||
self._loss = loss | |||
self._concat_original_image: bool = concat_original_image | |||
self.encoders: nn.ModuleList = None | |||
self.bridge: nn.Sequential = None | |||
self._create_encoder(pretrained_resnet18, encoders_conv_kws) | |||
self._create_bridge(bridge_conv_kws) | |||
self.decoders = nn.ModuleList([ResDecoderBlock(**kwargs) | |||
for kwargs in decoders_conv_kws]) | |||
last_channel = decoders_conv_kws[-1]['out_channel'] | |||
if not use_dropout: | |||
self.decider = nn.Sequential( | |||
nn.Conv2d(last_channel, seg_num_labels, kernel_size=1), | |||
nn.LogSoftmax(dim=1)) | |||
else: | |||
self.decider = nn.Sequential( | |||
nn.Dropout(0.3), | |||
nn.Conv2d(last_channel, seg_num_labels, kernel_size=1), | |||
nn.LogSoftmax(dim=1)) | |||
def _create_encoder(self, | |||
pretrained: bool, | |||
encoders_conv_kws: Optional[List[Dict[str, Any]]]) -> None: | |||
if pretrained: | |||
resnet = resnet18(pretrained=True) | |||
self._inp_channel = 3 | |||
self.encoders = nn.ModuleList([]) | |||
self.encoders.append( | |||
nn.Sequential( | |||
resnet.conv1, | |||
resnet.bn1, | |||
resnet.relu, | |||
resnet.maxpool)) | |||
for l in [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]: | |||
self.encoders.append(l) | |||
else: | |||
self._inp_channel = encoders_conv_kws[0]['inp_channel'] | |||
self.encoders = nn.ModuleList([ResEncoderBlock(**encoders_conv_kws[0])]) | |||
encoders_conv_kws.pop(0) | |||
self.encoders += nn.ModuleList([nn.Sequential( | |||
nn.AvgPool2d(kernel_size=2, stride=2), | |||
ResEncoderBlock(**kwargs) | |||
) | |||
for kwargs in encoders_conv_kws]) | |||
def _create_bridge(self, | |||
                       bridge_conv_kws: Optional[Dict[str, Any]]) -> None:
        if bridge_conv_kws is None:
            self.bridge = nn.AvgPool2d(kernel_size=2, stride=2)
else: | |||
self.bridge = nn.Sequential( | |||
nn.AvgPool2d(kernel_size=2, stride=2), | |||
ResEncoderBlock(**bridge_conv_kws)) | |||
def _forward(self, x): | |||
x = x.repeat_interleave(self._inp_channel, 1) | |||
encoders_outs = [] | |||
out = x | |||
x_ = x | |||
for enc in self.encoders: | |||
out = enc(out) | |||
encoders_outs.append(out) | |||
if self._concat_original_image: | |||
x_ = F.interpolate(x_, out.shape[2:], mode='bilinear', align_corners=True) | |||
out = torch.cat([out, x_], dim=1) | |||
out = self.bridge(out) | |||
for dec in self.decoders: | |||
out = dec(out, encoders_outs.pop(-1)) | |||
return self.decider(out) | |||
class ResunetSegmentor(ResUNet): | |||
def __init__(self, | |||
seg_num_labels: int, | |||
concat_original_image: bool, | |||
encoders_conv_kws: List[Dict[str, Any]], | |||
bridge_conv_kws: Dict[str, Any], | |||
                 decoders_conv_kws: List[Dict[str, Any]],
                 log_transform=False,
                 is_base=True,
                 base_saving_dir=None,
                 gamma=2):
        """
        Args:
            seg_num_labels (int): number of segmentation labels, e.g. 2 (background vs nipple)
            concat_original_image (bool): whether to concatenate the resized input image to each encoder output
            encoders_conv_kws (List[Dict[str, Any]]): keyword arguments for the encoder blocks
            bridge_conv_kws (Dict[str, Any]): keyword arguments for the bridge block
            decoders_conv_kws (List[Dict[str, Any]]): keyword arguments for the decoder blocks
        """
self.is_base = is_base | |||
self.batch_count = 0 | |||
super(ResunetSegmentor, self).__init__( | |||
seg_num_labels, | |||
concat_original_image, | |||
encoders_conv_kws, | |||
bridge_conv_kws, | |||
decoders_conv_kws, | |||
) | |||
self.seg_num_labels = seg_num_labels | |||
self.log_transform = log_transform | |||
self.dropout_2 = torch.nn.Dropout(0.4) | |||
self.dropout_1 = torch.nn.Dropout(0.2) | |||
last_channel = decoders_conv_kws[-1]["out_channel"] | |||
self.decider = torch.nn.Sequential( | |||
torch.nn.Conv2d(last_channel, seg_num_labels, kernel_size=1), | |||
torch.nn.LogSoftmax(dim=1), | |||
) | |||
self.gamma = gamma | |||
self.base_model = None | |||
self.base_saving_dir = base_saving_dir | |||
self.debug = self.base_saving_dir is not None | |||
if self.debug: | |||
os.makedirs(self.base_saving_dir , exist_ok=True) | |||
def set_base_model(self,base_model:Model): | |||
self.base_model = base_model | |||
return | |||
def _forward(self, x, need_fp=False): | |||
encoders_outs = [] | |||
out = x | |||
x_ = x | |||
for enc in self.encoders: | |||
out = enc(out) | |||
encoders_outs.append(out) | |||
out = self.dropout_1(out) | |||
if self._concat_original_image: | |||
x_ = F.interpolate(x_, out.shape[2:], mode='bilinear', align_corners=True) | |||
out = torch.cat([out, x_], dim=1) | |||
final_encoder_out = out | |||
final_encoder_out = final_encoder_out.flatten(2) | |||
out = self.dropout_2(out) | |||
        out = self.bridge(out)
        if need_fp:
            # double the bottleneck batch once so the clean and the
            # Dropout2d-perturbed streams share a single decoder pass;
            # without this the doubled skip connections below would not
            # match the bottleneck batch size
            out = torch.cat((out, nn.Dropout2d(0.5)(out)))
        for dec in self.decoders:
            encoder_out = encoders_outs.pop(-1)
            if need_fp:
                encoder_out = torch.cat((encoder_out, nn.Dropout2d(0.5)(encoder_out)))
            out = dec(out, encoder_out)
            out = self.dropout_1(out)
final_segmentation = self.decider(out) | |||
if need_fp: | |||
return final_segmentation.chunk(2) | |||
return final_segmentation | |||
    def forward(self,
                mammo_x: torch.Tensor,
                need_fp=False):
        # _forward returns a (clean, feature-perturbed) pair of log-probability
        # maps when need_fp is set, and a single map otherwise
        log_pixel_probs = self._forward(mammo_x, need_fp)
        if need_fp:
            return tuple(torch.exp(p).to(mammo_x.device) for p in log_pixel_probs)
        return torch.exp(log_pixel_probs).to(mammo_x.device)
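# Minimal sketch of the feature-perturbation (need_fp) trick used above
# (illustrative, not in the original file): the clean and the
# Dropout2d-perturbed copies travel through a layer as one doubled batch
# and are split apart at the end. Shapes are demo assumptions.
if __name__ == "__main__":
    feat = torch.randn(2, 8, 16, 16)
    doubled = torch.cat((feat, nn.Dropout2d(0.5)(feat)))  # (4, 8, 16, 16)
    head = nn.Conv2d(8, 2, kernel_size=1)
    out, out_fp = head(doubled).chunk(2)                  # each (2, 2, 16, 16)
    print(out.shape, out_fp.shape)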
@@ -0,0 +1,174 @@ | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from torch.distributions.uniform import Uniform | |||
def kaiming_normal_init_weight(model):
    # NOTE: matches 3-D layers only; a no-op for the 2-D blocks below unless
    # extended to nn.Conv2d / nn.BatchNorm2d (same for sparse_init_weight)
    for m in model.modules():
if isinstance(m, nn.Conv3d): | |||
torch.nn.init.kaiming_normal_(m.weight) | |||
elif isinstance(m, nn.BatchNorm3d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
return model | |||
def sparse_init_weight(model): | |||
for m in model.modules(): | |||
if isinstance(m, nn.Conv3d): | |||
torch.nn.init.sparse_(m.weight, sparsity=0.1) | |||
elif isinstance(m, nn.BatchNorm3d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
return model | |||
class ConvBlock(nn.Module): | |||
"""two convolution layers with batch norm and leaky relu""" | |||
def __init__(self, in_channels, out_channels, dropout_p): | |||
super(ConvBlock, self).__init__() | |||
self.conv_conv = nn.Sequential( | |||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(out_channels), | |||
nn.LeakyReLU(), | |||
nn.Dropout(dropout_p), | |||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), | |||
nn.BatchNorm2d(out_channels), | |||
nn.LeakyReLU() | |||
) | |||
def forward(self, x): | |||
return self.conv_conv(x) | |||
class DownBlock(nn.Module): | |||
"""Downsampling followed by ConvBlock""" | |||
def __init__(self, in_channels, out_channels, dropout_p): | |||
super(DownBlock, self).__init__() | |||
self.maxpool_conv = nn.Sequential( | |||
nn.MaxPool2d(2), | |||
ConvBlock(in_channels, out_channels, dropout_p) | |||
) | |||
def forward(self, x): | |||
return self.maxpool_conv(x) | |||
class UpBlock(nn.Module): | |||
"""Upssampling followed by ConvBlock""" | |||
def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, | |||
bilinear=True): | |||
super(UpBlock, self).__init__() | |||
self.bilinear = bilinear | |||
if bilinear: | |||
self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) | |||
self.up = nn.Upsample( | |||
scale_factor=2, mode='bilinear', align_corners=True) | |||
else: | |||
self.up = nn.ConvTranspose2d( | |||
in_channels1, in_channels2, kernel_size=2, stride=2) | |||
self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) | |||
def forward(self, x1, x2): | |||
if self.bilinear: | |||
x1 = self.conv1x1(x1) | |||
x1 = self.up(x1) | |||
x = torch.cat([x2, x1], dim=1) | |||
return self.conv(x) | |||
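# A quick shape sanity-check sketch for UpBlock (bilinear path): x1 comes from the
# deeper stage at half resolution, x2 is the skip connection; the 1x1 conv aligns
# channels before upsampling so the concat has in_channels2 * 2 channels.
def _demo_upblock():
    up = UpBlock(in_channels1=64, in_channels2=32, out_channels=32, dropout_p=0.0)
    x1 = torch.randn(1, 64, 8, 8)    # deeper feature map
    x2 = torch.randn(1, 32, 16, 16)  # skip connection at twice the resolution
    return up(x1, x2).shape          # torch.Size([1, 32, 16, 16])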
class Encoder(nn.Module): | |||
def __init__(self, params): | |||
super(Encoder, self).__init__() | |||
self.params = params | |||
self.in_chns = self.params['in_chns'] | |||
self.ft_chns = self.params['feature_chns'] | |||
self.n_class = self.params['class_num'] | |||
self.bilinear = self.params['bilinear'] | |||
self.dropout = self.params['dropout'] | |||
assert (len(self.ft_chns) == 5) | |||
self.in_conv = ConvBlock( | |||
self.in_chns, self.ft_chns[0], self.dropout[0]) | |||
self.down1 = DownBlock( | |||
self.ft_chns[0], self.ft_chns[1], self.dropout[1]) | |||
self.down2 = DownBlock( | |||
self.ft_chns[1], self.ft_chns[2], self.dropout[2]) | |||
self.down3 = DownBlock( | |||
self.ft_chns[2], self.ft_chns[3], self.dropout[3]) | |||
self.down4 = DownBlock( | |||
self.ft_chns[3], self.ft_chns[4], self.dropout[4]) | |||
def forward(self, x): | |||
x0 = self.in_conv(x) | |||
x1 = self.down1(x0) | |||
x2 = self.down2(x1) | |||
x3 = self.down3(x2) | |||
x4 = self.down4(x3) | |||
return [x0, x1, x2, x3, x4] | |||
class Decoder(nn.Module): | |||
def __init__(self, params): | |||
super(Decoder, self).__init__() | |||
self.params = params | |||
self.in_chns = self.params['in_chns'] | |||
self.ft_chns = self.params['feature_chns'] | |||
self.n_class = self.params['class_num'] | |||
self.bilinear = self.params['bilinear'] | |||
assert (len(self.ft_chns) == 5) | |||
self.up1 = UpBlock( | |||
self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) | |||
self.up2 = UpBlock( | |||
self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) | |||
self.up3 = UpBlock( | |||
self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) | |||
self.up4 = UpBlock( | |||
self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) | |||
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, | |||
kernel_size=3, padding=1) | |||
def forward(self, feature): | |||
x0 = feature[0] | |||
x1 = feature[1] | |||
x2 = feature[2] | |||
x3 = feature[3] | |||
x4 = feature[4] | |||
x = self.up1(x4, x3) | |||
x = self.up2(x, x2) | |||
x = self.up3(x, x1) | |||
x = self.up4(x, x0) | |||
output = self.out_conv(x) | |||
return output | |||
class UNet(nn.Module): | |||
def __init__(self, in_chns, class_num): | |||
super(UNet, self).__init__() | |||
params = {'in_chns': in_chns, | |||
'feature_chns': [64, 128, 128, 256, 512], | |||
'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], | |||
'class_num': class_num, | |||
'bilinear': False, | |||
'acti_func': 'relu'} | |||
self.encoder = Encoder(params) | |||
self.decoder = Decoder(params) | |||
def forward(self, x, need_fp=False): | |||
feature = self.encoder(x) | |||
if need_fp: | |||
outs = self.decoder([torch.cat((feat, nn.Dropout2d(0.5)(feat))) for feat in feature]) | |||
return outs.chunk(2) | |||
output = self.decoder(feature) | |||
return output |
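# A minimal usage sketch for UNet: with need_fp=True the decoder sees every encoder
# feature twice (clean and Dropout2d-perturbed), and chunk(2) separates the two
# resulting predictions, as used by UniMatch's feature-perturbation stream.
if __name__ == "__main__":
    net = UNet(in_chns=1, class_num=2)
    x = torch.randn(2, 1, 64, 64)
    logits = net(x)                       # torch.Size([2, 2, 64, 64])
    pred, pred_fp = net(x, need_fp=True)  # clean and feature-perturbed predictions
    print(logits.shape, pred.shape, pred_fp.shape)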
@@ -0,0 +1,147 @@ | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from torchvision import transforms | |||
from copy import deepcopy | |||
import numpy as np | |||
import random | |||
from .Models.unet import UNet | |||
from .transform import random_rot_flip, random_rotate | |||
from mlassistant.core import ModelIO, Model | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
class DiceLoss(nn.Module): | |||
def __init__(self, n_classes): | |||
super(DiceLoss, self).__init__() | |||
self.n_classes = n_classes | |||
    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # boolean mask of pixels labeled with class i
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
    def _dice_loss(self, score, target, ignore):
        target = target.float()
        smooth = 1e-5
        if ignore is None:
            ignore = torch.zeros_like(target)  # no ignore mask: keep every pixel
        keep = ignore != 1
        intersect = torch.sum(score[keep] * target[keep])
        y_sum = torch.sum(target[keep] * target[keep])
        z_sum = torch.sum(score[keep] * score[keep])
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss
def forward(self, inputs, target, weight=None, softmax=False, ignore=None): | |||
if softmax: | |||
inputs = torch.softmax(inputs, dim=1) | |||
target = self._one_hot_encoder(target) | |||
if weight is None: | |||
weight = [1] * self.n_classes | |||
assert inputs.size() == target.size(), 'predict & target shape do not match' | |||
class_wise_dice = [] | |||
loss = 0.0 | |||
for i in range(0, self.n_classes): | |||
dice = self._dice_loss(inputs[:, i], target[:, i], ignore) | |||
class_wise_dice.append(1.0 - dice.item()) | |||
loss += dice * weight[i] | |||
return loss / self.n_classes | |||
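# A small sanity-check sketch for DiceLoss: a perfect prediction should give a loss
# close to 0. `target` holds integer class ids and is one-hot encoded internally.
def _demo_dice_loss():
    dice = DiceLoss(n_classes=2)
    target = torch.randint(0, 2, (1, 1, 8, 8))                # integer labels, B 1 H W
    perfect = torch.cat((1 - target, target), dim=1).float()  # matching probabilities
    return dice(perfect, target.float())                      # ~0 (up to the smooth term)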
class FixMatchModel(nn.Module): | |||
def __init__(self): | |||
super(FixMatchModel, self).__init__() | |||
self.base_model = UNet(1, 2) | |||
self.criterion_dice = DiceLoss(n_classes=2) | |||
def forward(self, x): | |||
weak_img = self._get_weak_augs(x) | |||
strong_img = self._get_strong_augs(weak_img) | |||
pred_u_w = self.base_model(weak_img.to(x.device)).softmax(dim=1) | |||
pred_u_s = self.base_model(strong_img.to(x.device)).softmax(dim=1) | |||
mask_u_w = pred_u_w.argmax(dim=1) | |||
loss_u_s = self.criterion_dice(pred_u_s, mask_u_w.unsqueeze(1).float()) | |||
return self.base_model(x).softmax(dim=1), loss_u_s | |||
    def _get_strong_augs(self, img):
        img_s = deepcopy(img)
        if random.random() < 0.8:
            # only brightness/contrast are jittered: hue/saturation jitter requires
            # 3-channel input, while these mammograms are single-channel
            img_s = transforms.ColorJitter(brightness=0.25, contrast=0.25)(img_s)
        if random.random() < 0.5:
            sigma = np.random.uniform(0.1, 2.0)
            img_s = transforms.functional.gaussian_blur(img_s, kernel_size=(5, 5), sigma=sigma)
        return img_s
def _get_weak_augs(self, img): | |||
if random.random() > 0.5: | |||
img = random_rot_flip(img) | |||
elif random.random() > 0.5: | |||
img = random_rotate(img) | |||
if isinstance(img, torch.Tensor): | |||
return img | |||
return torch.from_numpy(img).float() | |||
class FixMatch(Model): | |||
""" | |||
FixMatch architecture. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.fix_match = FixMatchModel() | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
out_prob_map, unsup_loss = self.fix_match(mammo_x) | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out_prob_map.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = out_prob_map.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
        dummy_shape = dummy_mask.shape
        # append an all-ones "ignore" channel; H and W are indexed separately so
        # non-square inputs work as well
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4) | |||
output = { | |||
'pixel_probs': out_prob_map, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': (focal_loss + unsup_loss) / 2, | |||
'focal_loss': focal_loss, | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice, | |||
} | |||
return output |
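# Note on the masking scheme above: loss_type is 1 for annotated samples and 0
# otherwise. Zeroing both class channels of dummy_mask wherever loss_type == 0 and
# then appending an all-ones channel to both mask and prediction appears designed to
# route unannotated pixels to a dummy "ignore" class, so that
# balanced_focal_cross_entropy_loss_semi only scores the annotated ones.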
@@ -0,0 +1,153 @@ | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from torchvision import transforms | |||
from copy import deepcopy | |||
import numpy as np | |||
import random | |||
from .....models.semi_supervised.base_line.Resunet.Heavy.Models.resunet import RESUNet as RESUNETSEGMENTOR | |||
from .Models.unet import UNet | |||
from .transform import random_rot_flip, random_rotate | |||
from mlassistant.core import ModelIO, Model | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
class DiceLoss(nn.Module): | |||
def __init__(self, n_classes): | |||
super(DiceLoss, self).__init__() | |||
self.n_classes = n_classes | |||
    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # boolean mask of pixels labeled with class i
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
    def _dice_loss(self, score, target, ignore):
        target = target.float()
        smooth = 1e-5
        if ignore is None:
            ignore = torch.zeros_like(target)  # no ignore mask: keep every pixel
        keep = ignore != 1
        intersect = torch.sum(score[keep] * target[keep])
        y_sum = torch.sum(target[keep] * target[keep])
        z_sum = torch.sum(score[keep] * score[keep])
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss
def forward(self, inputs, target, weight=None, softmax=False, ignore=None): | |||
if softmax: | |||
inputs = torch.softmax(inputs, dim=1) | |||
target = self._one_hot_encoder(target) | |||
if weight is None: | |||
weight = [1] * self.n_classes | |||
assert inputs.size() == target.size(), 'predict & target shape do not match' | |||
class_wise_dice = [] | |||
loss = 0.0 | |||
for i in range(0, self.n_classes): | |||
dice = self._dice_loss(inputs[:, i], target[:, i], ignore) | |||
class_wise_dice.append(1.0 - dice.item()) | |||
loss += dice * weight[i] | |||
return loss / self.n_classes | |||
class FixMatchModel(nn.Module): | |||
def __init__(self): | |||
super(FixMatchModel, self).__init__() | |||
self.base_model = RESUNETSEGMENTOR(1, 2) | |||
self.criterion_dice = DiceLoss(n_classes=2) | |||
def forward(self, x): | |||
weak_img = self._get_weak_augs(x) | |||
strong_img = self._get_strong_augs(weak_img) | |||
pred_u_w = self.base_model(weak_img.to(x.device)).softmax(dim=1) | |||
pred_u_s = self.base_model(strong_img.to(x.device)).softmax(dim=1) | |||
        # keep only samples whose pseudo-label confidence reaches the FixMatch
        # threshold (applied per sample here rather than per pixel)
        indicator = []
        for i in range(len(pred_u_w)):
            if torch.max(pred_u_w[i]) >= 0.95:
                indicator.append(i)
        pred_u_w = pred_u_w[indicator, ...]
        pred_u_s = pred_u_s[indicator, ...]
mask_u_w = pred_u_w.argmax(dim=1) | |||
loss_u_s = self.criterion_dice(pred_u_s, mask_u_w.unsqueeze(1).float()) | |||
return self.base_model(x).softmax(dim=1), loss_u_s | |||
    def _get_strong_augs(self, img):
        img_s = deepcopy(img)
        if random.random() < 0.8:
            # only brightness/contrast are jittered: hue/saturation jitter requires
            # 3-channel input, while these mammograms are single-channel
            img_s = transforms.ColorJitter(brightness=0.25, contrast=0.25)(img_s)
        if random.random() < 0.5:
            sigma = np.random.uniform(0.1, 2.0)
            img_s = transforms.functional.gaussian_blur(img_s, kernel_size=(5, 5), sigma=sigma)
        return img_s
def _get_weak_augs(self, img): | |||
if random.random() > 0.5: | |||
img = random_rot_flip(img) | |||
elif random.random() > 0.5: | |||
img = random_rotate(img) | |||
if isinstance(img, torch.Tensor): | |||
return img | |||
return torch.from_numpy(img).float() | |||
class FixMatch(Model): | |||
""" | |||
FixMatch architecture. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.fix_match = FixMatchModel() | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
out_prob_map, unsup_loss = self.fix_match(mammo_x) | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out_prob_map.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = out_prob_map.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
        dummy_shape = dummy_mask.shape
        # append an all-ones "ignore" channel; H and W are indexed separately so
        # non-square inputs work as well
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4) | |||
output = { | |||
'pixel_probs': out_prob_map, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': (focal_loss + unsup_loss) / 2, | |||
'focal_loss': focal_loss, | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice, | |||
} | |||
return output |
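# A standalone sketch of the per-sample confidence filter used in FixMatchModel.forward
# above: only samples whose most confident pixel reaches the threshold keep
# contributing to the unsupervised loss (FixMatch proper filters per pixel).
# _demo_confidence_filter is a hypothetical helper introduced here for illustration.
def _demo_confidence_filter(pred_u_w: torch.Tensor, pred_u_s: torch.Tensor, thr: float = 0.95):
    keep = [i for i in range(len(pred_u_w)) if torch.max(pred_u_w[i]) >= thr]
    return pred_u_w[keep, ...], pred_u_s[keep, ...]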
@@ -0,0 +1,40 @@ | |||
import numpy as np | |||
import random | |||
import torch | |||
from torchvision import transforms | |||
from scipy import ndimage | |||
def random_rot_flip(img):
    k = np.random.randint(0, 4)
    img = transforms.functional.rotate(img, angle=k * 90)
    # flip one of the two spatial dimensions (H or W), counted from the end so that
    # batch/channel dimensions are never flipped
    axis = int(np.random.randint(0, 2)) - 2
    img = torch.flip(img, [axis])
    return img
def random_rotate(img): | |||
angle = np.random.randint(-20, 20) | |||
img = transforms.functional.rotate(img, angle=angle) | |||
return img | |||
def obtain_cutmix_box(img_size, p=0.5, size_min=0.02, size_max=0.4, ratio_1=0.3, ratio_2=1/0.3): | |||
mask = torch.zeros(img_size, img_size) | |||
if random.random() > p: | |||
return mask | |||
size = np.random.uniform(size_min, size_max) * img_size * img_size | |||
while True: | |||
ratio = np.random.uniform(ratio_1, ratio_2) | |||
cutmix_w = int(np.sqrt(size / ratio)) | |||
cutmix_h = int(np.sqrt(size * ratio)) | |||
x = np.random.randint(0, img_size) | |||
y = np.random.randint(0, img_size) | |||
if x + cutmix_w <= img_size and y + cutmix_h <= img_size: | |||
break | |||
mask[y:y + cutmix_h, x:x + cutmix_w] = 1 | |||
return mask |
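# A usage sketch for obtain_cutmix_box: the returned {0,1} mask selects a random
# rectangle, and two unlabeled images can be mixed with it CutMix-style.
if __name__ == "__main__":
    img_a = torch.randn(1, 64, 64)
    img_b = torch.randn(1, 64, 64)
    box = obtain_cutmix_box(img_size=64, p=1.0)  # p=1.0 forces a box for the demo
    mixed = img_a * (1 - box) + img_b * box      # paste the box region of img_b onto img_a
    print(mixed.shape, box.sum().item())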
@@ -0,0 +1,160 @@ | |||
import torch | |||
from torch import nn | |||
import torch.nn.functional as F | |||
from torchvision import transforms | |||
from copy import deepcopy | |||
import math | |||
import numpy as np | |||
import os | |||
from PIL import Image | |||
import random | |||
from scipy.ndimage import zoom
from scipy import ndimage | |||
from .Models.unet import UNet | |||
from .transform import random_rot_flip, random_rotate, obtain_cutmix_box | |||
from mlassistant.core import ModelIO, Model | |||
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi | |||
from .....utils.generalized_dice import dice_loss | |||
class DiceLoss(nn.Module): | |||
def __init__(self, n_classes): | |||
super(DiceLoss, self).__init__() | |||
self.n_classes = n_classes | |||
    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # boolean mask of pixels labeled with class i
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
    def _dice_loss(self, score, target, ignore):
        target = target.float()
        smooth = 1e-5
        if ignore is None:
            ignore = torch.zeros_like(target)  # no ignore mask: keep every pixel
        keep = ignore != 1
        intersect = torch.sum(score[keep] * target[keep])
        y_sum = torch.sum(target[keep] * target[keep])
        z_sum = torch.sum(score[keep] * score[keep])
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss
def forward(self, inputs, target, weight=None, softmax=False, ignore=None): | |||
if softmax: | |||
inputs = torch.softmax(inputs, dim=1) | |||
target = self._one_hot_encoder(target) | |||
if weight is None: | |||
weight = [1] * self.n_classes | |||
assert inputs.size() == target.size(), 'predict & target shape do not match' | |||
class_wise_dice = [] | |||
loss = 0.0 | |||
for i in range(0, self.n_classes): | |||
dice = self._dice_loss(inputs[:, i], target[:, i], ignore) | |||
class_wise_dice.append(1.0 - dice.item()) | |||
loss += dice * weight[i] | |||
return loss / self.n_classes | |||
class UniMatchModel(nn.Module): | |||
def __init__(self): | |||
super(UniMatchModel, self).__init__() | |||
self.base_model = UNet(1, 2) | |||
self.criterion_dice = DiceLoss(n_classes=2) | |||
def forward(self, x): | |||
weak_img = self._get_weak_augs(x) | |||
strong_img1, strong_img2 = self._get_strong_augs(weak_img) | |||
pred_u_w, pred_fp = self.base_model(weak_img.to(x.device), True) | |||
pred_u_s1, pred_u_s2 = self.base_model(torch.cat((strong_img1.to(x.device), strong_img2.to(x.device)))).chunk(2) | |||
mask_u_w = pred_u_w.argmax(dim=1) | |||
loss_u_s1 = self.criterion_dice(pred_u_s1.softmax(dim=1), | |||
mask_u_w.unsqueeze(1).float()) | |||
loss_u_s2 = self.criterion_dice(pred_u_s2.softmax(dim=1), | |||
mask_u_w.unsqueeze(1).float()) | |||
loss_u_w_fp = self.criterion_dice(pred_fp.softmax(dim=1), mask_u_w.unsqueeze(1).float()) | |||
loss = loss_u_s1 * 0.25 + loss_u_s2 * 0.25 + loss_u_w_fp * 0.5 | |||
return self.base_model(x).softmax(dim=1), loss | |||
    def _get_strong_augs(self, img):
        img_s1, img_s2 = deepcopy(img), deepcopy(img)
        # only brightness/contrast are jittered: hue/saturation jitter requires
        # 3-channel input, while these mammograms are single-channel
        if random.random() < 0.8:
            img_s1 = transforms.ColorJitter(brightness=0.5, contrast=0.5)(img_s1)
        if random.random() < 0.5:
            sigma = np.random.uniform(0.1, 2.0)
            img_s1 = transforms.functional.gaussian_blur(img_s1, kernel_size=(5, 5), sigma=sigma)
        if random.random() < 0.8:
            img_s2 = transforms.ColorJitter(brightness=0.5, contrast=0.5)(img_s2)
        if random.random() < 0.5:
            sigma = np.random.uniform(0.1, 2.0)
            img_s2 = transforms.functional.gaussian_blur(img_s2, kernel_size=(5, 5), sigma=sigma)
        return img_s1, img_s2
def _get_weak_augs(self, img): | |||
if random.random() > 0.5: | |||
img = random_rot_flip(img) | |||
elif random.random() > 0.5: | |||
img = random_rotate(img) | |||
if isinstance(img, torch.Tensor): | |||
return img | |||
return torch.from_numpy(img).float() | |||
class UniMatch(Model): | |||
""" | |||
UniMatch architecture. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.uni_match = UniMatchModel() | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
out_prob_map, unsup_loss = self.uni_match(mammo_x) | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out_prob_map.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = out_prob_map.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
        dummy_shape = dummy_mask.shape
        # append an all-ones "ignore" channel; H and W are indexed separately so
        # non-square inputs work as well
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4) | |||
output = { | |||
'pixel_probs': out_prob_map, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': (focal_loss + unsup_loss) / 2, | |||
'focal_loss': focal_loss, | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice, | |||
} | |||
return output |
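# A minimal smoke-test sketch for UniMatchModel: forward returns the softmax
# prediction on the raw batch plus the scalar unsupervised consistency loss that
# combines the two strong views and the feature-perturbed view (0.25 / 0.25 / 0.5).
if __name__ == "__main__":
    model = UniMatchModel()
    x = torch.rand(2, 1, 64, 64)  # unlabeled single-channel batch in [0, 1]
    probs, loss_u = model(x)
    print(probs.shape, float(loss_u))  # torch.Size([2, 2, 64, 64]), scalar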
@@ -0,0 +1,187 @@ | |||
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from copy import deepcopy
import math
import numpy as np
import os
from PIL import Image
import random
from scipy.ndimage import zoom
from scipy import ndimage
from .Models.resunet import ResunetSegmentor
from .transform import random_rot_flip, random_rotate, obtain_cutmix_box
from mlassistant.core import ModelIO, Model
from .....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from .....utils.generalized_dice import dice_loss
class DiceLoss(nn.Module): | |||
def __init__(self, n_classes): | |||
super(DiceLoss, self).__init__() | |||
self.n_classes = n_classes | |||
    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # boolean mask of pixels labeled with class i
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
    def _dice_loss(self, score, target, ignore):
        target = target.float()
        smooth = 1e-5
        if ignore is None:
            ignore = torch.zeros_like(target)  # no ignore mask: keep every pixel
        keep = ignore != 1
        intersect = torch.sum(score[keep] * target[keep])
        y_sum = torch.sum(target[keep] * target[keep])
        z_sum = torch.sum(score[keep] * score[keep])
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss
def forward(self, inputs, target, weight=None, softmax=False, ignore=None): | |||
if softmax: | |||
inputs = torch.softmax(inputs, dim=1) | |||
target = self._one_hot_encoder(target) | |||
if weight is None: | |||
weight = [1] * self.n_classes | |||
assert inputs.size() == target.size(), 'predict & target shape do not match' | |||
class_wise_dice = [] | |||
loss = 0.0 | |||
for i in range(0, self.n_classes): | |||
dice = self._dice_loss(inputs[:, i], target[:, i], ignore) | |||
class_wise_dice.append(1.0 - dice.item()) | |||
loss += dice * weight[i] | |||
return loss / self.n_classes | |||
class UniMatchModel(nn.Module): | |||
def __init__(self): | |||
super(UniMatchModel, self).__init__() | |||
encoders_conv_kws = [ | |||
{"inp_channel": 1, "out_channel": 32, "n_conv_blocks": 2}, | |||
{"inp_channel": 33, "out_channel": 64, "n_conv_blocks": 2}, | |||
{"inp_channel": 65, "out_channel": 128, "n_conv_blocks": 3}, | |||
{"inp_channel": 129, "out_channel": 256, "n_conv_blocks": 3}, | |||
] | |||
bridge_conv_kws = {"inp_channel": 257, "out_channel": 512} | |||
decoders_conv_kws = [ | |||
{"inp_channel": 512, "enc_inp_channel": 256, "out_channel": 256}, | |||
{"inp_channel": 256, "enc_inp_channel": 128, "out_channel": 64}, | |||
{"inp_channel": 64, "enc_inp_channel": 64, "out_channel": 32}, | |||
{"inp_channel": 32, "enc_inp_channel": 32, "out_channel": 16}, | |||
] | |||
self.base_model = ResunetSegmentor( | |||
seg_num_labels=2, | |||
concat_original_image=True, | |||
encoders_conv_kws=encoders_conv_kws, | |||
decoders_conv_kws=decoders_conv_kws, | |||
bridge_conv_kws=bridge_conv_kws, | |||
) | |||
self.criterion_dice = DiceLoss(n_classes=2) | |||
def forward(self, x): | |||
weak_img = self._get_weak_augs(x) | |||
strong_img1, strong_img2 = self._get_strong_augs(weak_img) | |||
pred_u_w, pred_fp = self.base_model(weak_img.to(x.device), True) | |||
pred_u_s1, pred_u_s2 = self.base_model(torch.cat((strong_img1.to(x.device), strong_img2.to(x.device)))).chunk(2) | |||
mask_u_w = pred_u_w.argmax(dim=1) | |||
loss_u_s1 = self.criterion_dice(pred_u_s1.softmax(dim=1), | |||
mask_u_w.unsqueeze(1).float()) | |||
loss_u_s2 = self.criterion_dice(pred_u_s2.softmax(dim=1), | |||
mask_u_w.unsqueeze(1).float()) | |||
loss_u_w_fp = self.criterion_dice(pred_fp.softmax(dim=1), mask_u_w.unsqueeze(1).float()) | |||
loss = loss_u_s1 * 0.25 + loss_u_s2 * 0.25 + loss_u_w_fp * 0.5 | |||
return self.base_model(x).softmax(dim=1), loss | |||
    def _get_strong_augs(self, img):
        img_s1, img_s2 = deepcopy(img), deepcopy(img)
        # only brightness/contrast are jittered: hue/saturation jitter requires
        # 3-channel input, while these mammograms are single-channel
        if random.random() < 0.8:
            img_s1 = transforms.ColorJitter(brightness=0.5, contrast=0.5)(img_s1)
        if random.random() < 0.5:
            sigma = np.random.uniform(0.1, 2.0)
            img_s1 = transforms.functional.gaussian_blur(img_s1, kernel_size=(5, 5), sigma=sigma)
        if random.random() < 0.8:
            img_s2 = transforms.ColorJitter(brightness=0.5, contrast=0.5)(img_s2)
        if random.random() < 0.5:
            sigma = np.random.uniform(0.1, 2.0)
            img_s2 = transforms.functional.gaussian_blur(img_s2, kernel_size=(5, 5), sigma=sigma)
        return img_s1, img_s2
def _get_weak_augs(self, img): | |||
if random.random() > 0.5: | |||
img = random_rot_flip(img) | |||
elif random.random() > 0.5: | |||
img = random_rotate(img) | |||
if isinstance(img, torch.Tensor): | |||
return img | |||
return torch.from_numpy(img).float() | |||
class UniMatch(Model): | |||
""" | |||
UniMatch architecture. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.uni_match = UniMatchModel() | |||
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
loss_type = mammo_loss_and_gt[:, 0, ...] | |||
ground_truth = mammo_loss_and_gt[:, 1, ...] * loss_type | |||
out_prob_map, unsup_loss = self.uni_match(mammo_x) | |||
main_shape = ground_truth.shape | |||
out_prob_map_shape = out_prob_map.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],out_prob_map_shape[1],main_shape[1], main_shape[2]) | |||
ground_truth = ground_truth.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1 - ground_truth, ground_truth), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = out_prob_map.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
        dummy_shape = dummy_mask.shape
        # append an all-ones "ignore" channel; H and W are indexed separately so
        # non-square inputs work as well
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = out_prob_map * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
focal_loss = balanced_focal_cross_entropy_loss_semi(dummy_net_out, dummy_mask, 4) | |||
output = { | |||
'pixel_probs': out_prob_map, | |||
'org_pixel_labels': mammo_loss_and_gt, | |||
'loss': (focal_loss + unsup_loss) / 2, | |||
'focal_loss': focal_loss, | |||
'ce_loss': ce_loss, | |||
'dice_loss': dice, | |||
} | |||
return output |
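# Note on the encoder channel bookkeeping above: with concat_original_image=True the
# (downsampled) single-channel input is re-concatenated after every encoder stage,
# which is presumably why each inp_channel is the previous out_channel plus one
# (33 = 32 + 1, 65 = 64 + 1, ...) and the bridge sees 257 = 256 + 1 channels.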
@@ -0,0 +1,194 @@ | |||
from copy import deepcopy
import math
from random import uniform
from typing import List
import time
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from mlassistant.core import Model, ModelIO
from ....utils.losses.balanced_focal_cross_entropy import balanced_focal_cross_entropy_loss_semi
from ....utils.generalized_dice import dice_loss
from ....utils.losses.ent_losses import entropy_loss_normalized
from ....utils.losses.tverskyLoss import tversky_loss
from ....utils.dynamic_num_generator.num_generator import init_num_generator1, init_num_generator2
from ....utils.dynamic_num_generator.wrapper import dynamic_num_generator
def random_dropout(p_min, p_max):
    """returns a hook that resamples the dropout probability uniformly on each call."""
    def assign(module: nn.Dropout, _: torch.Tensor) -> None:
        module.p = uniform(p_min, p_max)
    return assign
def static_dropout(prob: float):
    """returns a hook that pins the dropout probability to a fixed value."""
    def assign(module: nn.Dropout, _: torch.Tensor) -> None:
        module.p = prob
    return assign
class UncertaintyAwareStudentMeanTeacherNet(Model): | |||
""" adapted from the paper 'Uncertainty-aware Self-ensembling Model for Semi-supervised | |||
3D Left Atrium Segmentation'. link: https://arxiv.org/pdf/1907.07034.pdf | |||
Args: | |||
alpha (float): exponential moving average decay | |||
total_num_iteration (int): total number of iteration this run will have | |||
base_model (Model): base networks for both student and teacher | |||
x_arg (str): argument name of `forward` indicating images, shape: B ... | |||
x_roi_arg (str): argument name of `forward` indicating ground-truth ROIs, shape: B ... | |||
x_annotated_arg (str): argument name of `forward` indicating being labeled or not for inputs, shape: B | |||
uncertainty_ratio_thd (float): the initial ratio threshold for uncertainty pruning | |||
std (float): standard deviation of gaussian noise to be applied to the input | |||
mean (float): mean of gaussian noise to be applied to the input | |||
T (int): number of perturbations for teacher net | |||
lambda_coef (float): lambda coefficient for calculating unsupervised loss | |||
use_supervised_in_consistency_checking (bool): if true then use supervised samples in unsupervised task (consistency checking) | |||
""" | |||
def train(self,mode : bool=True): | |||
super().train(mode) | |||
self.epochNUM += 1 | |||
# with torch.no_grad(): | |||
# if self.epochNUM % 2 == 0: | |||
# self.update_teacher_params() | |||
self.mode = mode | |||
print(mode) | |||
def __init__(self, | |||
alpha: float, | |||
base_model: Model, | |||
std: float = math.sqrt(0.0001), | |||
mean: float = 0.0, | |||
T: int=8, | |||
gamma = 2, smooth=1, alphat=0.7, beta=0.3) -> None: | |||
super(UncertaintyAwareStudentMeanTeacherNet, self).__init__() | |||
self.epochNUM = 1 | |||
assert alpha >= 0 and alpha <= 1, f'alpha should be in range [0, 1]; got {alpha} instead.' | |||
assert T > 0, f'number of perturbations to be applied for teacher net should be at least one; got {T} instead.' | |||
self._T: int = T | |||
self._alpha: float = alpha | |||
self._mean: float = mean | |||
self._std: float = std | |||
self.smooth = smooth | |||
self.alphatt = alphat | |||
self.beta = beta | |||
self._student: Model = base_model | |||
self._teacher: Model = deepcopy(self._student) | |||
self.gamma = gamma | |||
        # Gaussian ramp-up schedules over ~200 epochs:
        # consistency weight grows from ~0 towards 0.1
        self.uncercoef = lambda t: torch.tensor(0.1 * np.exp(-5 * (1 - t / 200) ** 2))
        # maximum entropy of a 2-class prediction
        self.H = np.log(2)
        # uncertainty threshold grows from 3H/4 towards H
        self.threshold = lambda t: torch.tensor((self.H / 4) * np.exp(-5 * (1 - t / 200) ** 2) + self.H * 3 / 4)
self.flag = True | |||
self.mode = True | |||
def _update_teacher_params(self, | |||
_: Model, | |||
__: torch.Tensor) -> None: | |||
"""a backward hook to update teacher network parameters, using exponential moving | |||
average method.""" | |||
for s_p, t_p in zip(self._student.parameters(), self._teacher.parameters()): | |||
t_p.data = self._alpha * t_p.data + (1 - self._alpha) * s_p.data | |||
    def update_teacher_params(self) -> None:
        """updates teacher network parameters as an exponential moving average of the
        student parameters (called directly at forward time rather than as a hook)."""
        for s_p, t_p in zip(self._student.parameters(), self._teacher.parameters()):
            t_p.data = self._alpha * t_p.data + (1 - self._alpha) * s_p.data
    def _apply_gaussian_noise(self,
                              input) -> torch.Tensor:
        # multiplicative gaussian noise: x + x * noise
        noise = torch.randn(input.size()) * self._std + self._mean
        noise = noise.to(input.device)
        return input + noise * input
def forward(self, | |||
mammo_x: torch.Tensor, | |||
mammo_loss_and_gt: torch.Tensor) -> ModelIO: | |||
"""first forwards supervised samples through student net, then all samples through both nets. """ | |||
with torch.no_grad(): | |||
if self.mode and self.epochNUM > 1: | |||
self.update_teacher_params() | |||
output = dict(loss=0.0) | |||
assert mammo_loss_and_gt.shape[1] == 2 , "loss and gt is not in expected shape (B,2,size,size)" | |||
loss_type = mammo_loss_and_gt[:,0,:,:] | |||
mask = mammo_loss_and_gt[:,1,:,:] | |||
        network_output = self._student.forward(mammo_x).softmax(dim=1)  # B C H W class probabilities
        network_output_soft = network_output  # already softmax probabilities
output['pixel_probs'] = network_output_soft | |||
main_shape = mask.shape | |||
network_output_shape = network_output.shape | |||
loss_channels = loss_type.unsqueeze(dim=1).expand(main_shape[0],network_output_shape[1],main_shape[1], main_shape[2]) | |||
mask = mask.unsqueeze(dim = 1) | |||
mask_channels = torch.cat((1-mask, mask), 1) | |||
dummy_mask = mask_channels.clone().to(mammo_x.device) | |||
dummy_net_out = network_output_soft.clone().to(mammo_x.device) | |||
dummy_mask[loss_channels == 0] = 0.0 | |||
        dummy_shape = dummy_mask.shape
        # append an all-ones "ignore" channel; H and W are indexed separately so
        # non-square inputs work as well
        dummy_mask = torch.cat((dummy_mask, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
        dummy_net_out = torch.cat((dummy_net_out, torch.ones((dummy_shape[0], 1, dummy_shape[2], dummy_shape[3])).to(mammo_x.device)), 1)
dummy_yhat = network_output_soft * loss_channels | |||
dummy_y = mask_channels * loss_channels | |||
focal_loss = balanced_focal_cross_entropy_loss_semi( | |||
dummy_net_out, | |||
dummy_mask, | |||
focal_gamma=self.gamma) | |||
dice = dice_loss(dummy_net_out,dummy_mask) | |||
ce_loss = F.binary_cross_entropy(dummy_yhat, dummy_y) | |||
tversky = tversky_loss(inputs = dummy_yhat[:,1,...], targets = dummy_y[:,1,...], smooth=self.smooth, alpha = self.alphatt, beta= self.beta) | |||
with torch.no_grad(): | |||
teacher_outs: List[torch.Tensor] = list() | |||
for _ in range(self._T): | |||
teacher_outs.append(self._teacher.forward( | |||
self._apply_gaussian_noise(mammo_x)).softmax(dim=1)) | |||
out_t_u = (torch.stack(teacher_outs, dim=2)).mean(dim=2) # B C H' W' | |||
            uncertainty = -1 * torch.sum(out_t_u * torch.log(out_t_u + 1e-6), 1)  # predictive entropy, B H' W' (eps guards log(0))
mse = torch.sum((network_output_soft - out_t_u) ** 2, dim=1) | |||
# consider only most certain pixels for mse loss | |||
mse_ = torch.zeros_like(mse).to(mammo_x.device) | |||
UThreshhold = self.threshold(int(self.epochNUM/2)) | |||
mask_mse = uncertainty < UThreshhold | |||
mse_[mask_mse] += mse[mask_mse] | |||
active_pxs = torch.sum(mask_mse.long(), dim=[1, 2]) + 1 # B | |||
mse_loss = torch.sum(mse_ / active_pxs[:, None, None], dim=[1, 2]) | |||
mse_loss = torch.mean(mse_loss) | |||
mse_coef = self.uncercoef(int(self.epochNUM/2)).to(mammo_x.device) | |||
        output['ce_loss'] = ce_loss
        output['loss'] = focal_loss + mse_coef * mse_loss
        output['diceloss'] = dice
        output['tverskyloss'] = tversky
        output['focalloss'] = focal_loss
        output['suploss'] = focal_loss
        output['org_pixel_labels'] = mammo_loss_and_gt
        output['mse_loss'] = mse_loss
        output['mse_coef_loss'] = mse_coef.clone()  # already tensors; avoids torch.tensor(tensor) warnings
        output['Uthreshhold_loss'] = UThreshhold.clone()
        output['total_loss'] = output['loss']
return output |
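# Standalone sketches of the two mechanics above. First, the EMA teacher update: with
# decay alpha each teacher parameter moves a (1 - alpha) step toward the student.
# Second, the uncertainty gating: per-pixel predictive entropy of the averaged teacher
# predictions is compared with the ramped threshold, and only low-entropy (confident)
# pixels contribute to the consistency (mse) term. Shapes and threshold are demo values.
if __name__ == "__main__":
    alpha = 0.99
    t_p, s_p = torch.tensor([1.0]), torch.tensor([2.0])
    t_p = alpha * t_p + (1 - alpha) * s_p
    print(t_p)  # tensor([1.0100]) -> one EMA step

    out_t_u = torch.softmax(torch.randn(2, 2, 4, 4), dim=1)           # averaged teacher probs
    uncertainty = -torch.sum(out_t_u * torch.log(out_t_u + 1e-6), 1)  # entropy, B H W
    mask_mse = uncertainty < 0.9 * float(np.log(2))                   # keep low-entropy pixels
    print(mask_mse.float().mean())                                    # fraction of pixels kept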