Browse Source

Add files via upload

main
rucv 3 years ago
commit
7d4ada243e
No account linked to committer's email address
1 changed files with 116 additions and 0 deletions
  1. 116
    0
      train.py

+ 116
- 0
train.py View File

#Adopted from the ACSNet

import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
import datasets
from utils.metrics import evaluate
from opt import opt
from utils.comm import generate_model
from utils.loss import DeepSupervisionLoss, BceDiceLoss
from utils.metrics import Metrics
import torch.nn as nn


def valid(model, valid_dataloader, total_batch):

model.eval()

# Metrics_logger initialization
metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2',
'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean'])

with torch.no_grad():
bar = tqdm(enumerate(valid_dataloader), total=total_batch)
for i, data in bar:
img, gt = data['image'], data['label']

if opt.use_gpu:
img = img.cuda()
gt = gt.cuda()

output = model(img)
_recall, _specificity, _precision, _F1, _F2, \
_ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate(output, gt, 0.5)

metrics.update(recall= _recall, specificity= _specificity, precision= _precision,
F1= _F1, F2= _F2, ACC_overall= _ACC_overall, IoU_poly= _IoU_poly,
IoU_bg= _IoU_bg, IoU_mean= _IoU_mean
)

metrics_result = metrics.mean(total_batch)

return metrics_result


def train():

model = generate_model(opt)
#model = nn.DataParallel(model)

# load data
train_data = getattr(datasets, opt.dataset)(opt.root, opt.train_data_dir, mode='train')
train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
valid_data = getattr(datasets, opt.dataset)(opt.root, opt.valid_data_dir, mode='valid')
valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=opt.num_workers)
val_total_batch = int(len(valid_data) / 1)


# load optimizer and scheduler
optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.mt, weight_decay=opt.weight_decay)

lr_lambda = lambda epoch: 1.0 - pow((epoch / opt.nEpoch), opt.power)
scheduler = LambdaLR(optimizer, lr_lambda)

# train
print('Start training')
print('---------------------------------\n')

for epoch in range(opt.nEpoch):
print('------ Epoch', epoch + 1)
model.train()
total_batch = int(len(train_data) / opt.batch_size)
bar = tqdm(enumerate(train_dataloader), total=total_batch)

for i, data in bar:
img = data['image']
gt = data['label']


if opt.use_gpu:
img = img.cuda()
gt = gt.cuda()

optimizer.zero_grad()
output = model(img)

#loss = BceDiceLoss()(output, gt)
loss = DeepSupervisionLoss(output, gt)
loss.backward()

optimizer.step()
bar.set_postfix_str('loss: %.5s' % loss.item())

scheduler.step()

metrics_result = valid(model, valid_dataloader, val_total_batch)

print("Valid Result:")
print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f,'
' F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f'
% (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'],
metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'],
metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean']))

if ((epoch + 1) % opt.ckpt_period == 0):
torch.save(model.state_dict(), './checkpoints/exp' + str(opt.expID)+"/ck_{}.pth".format(epoch + 1))


if __name__ == '__main__':

if opt.mode == 'train':
print('---PolpySeg Train---')
train()

print('Done')

Loading…
Cancel
Save