@@ -22,6 +22,9 @@ def batch_constructor(config, batch):
 
 def train_epoch(config, model, train_loader, optimizer, scalar):
     loss_meter = AvgMeter('train')
+    c_loss_meter = AvgMeter('train')
+    s_loss_meter = AvgMeter('train')
+
 
     # tqdm_object = tqdm(train_loader, total=len(train_loader))
     targets = []
@@ -32,7 +35,7 @@ def train_epoch(config, model, train_loader, optimizer, scalar):
 
         with torch.cuda.amp.autocast():
             output, score = model(batch)
-            loss = calculate_loss(model, score, batch['label'])
+            loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
 
         scalar.scale(loss).backward()
         torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
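
Note: for the unpacking above to work, calculate_loss (not shown in this patch) has to return the combined loss together with its two component terms. A minimal sketch of the assumed contract follows; the criteria used here are placeholders, not the repository's actual loss functions.

import torch.nn.functional as F

def calculate_loss(model, score, label):
    # model is accepted only to match the call site; unused in this placeholder.
    # Placeholder component terms -- the real criteria live elsewhere in the repo.
    c_loss = F.cross_entropy(score, label)
    s_loss = F.cross_entropy(score, label, label_smoothing=0.1)
    loss = c_loss + s_loss  # combined value that scalar.scale(loss).backward() runs on
    return loss, c_loss, s_loss
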
@@ -44,6 +47,8 @@ def train_epoch(config, model, train_loader, optimizer, scalar):
 
         count = batch["id"].size(0)
         loss_meter.update(loss.detach(), count)
+        c_loss_meter.update(c_loss.detach(), count)
+        s_loss_meter.update(s_loss.detach(), count)
 
         # tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
 
@@ -53,11 +58,15 @@ def train_epoch(config, model, train_loader, optimizer, scalar):
         target = batch['label'].detach()
         targets.append(target)
 
-    return loss_meter, targets, predictions
+    losses = (loss_meter, s_loss_meter, c_loss_meter)
+
+    return losses, targets, predictions
 
 
 def validation_epoch(config, model, validation_loader):
     loss_meter = AvgMeter('validation')
+    c_loss_meter = AvgMeter('validation')
+    s_loss_meter = AvgMeter('validation')
 
     targets = []
     predictions = []
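
Note: both epoch functions lean on the AvgMeter helper, which is not part of this patch. Assuming it follows the usual running-average pattern (an update(value, count) method plus an .avg attribute, which lr_scheduler.step(validation_loss[0].avg) below relies on), it would look roughly like the sketch below; the repository's actual class may differ in details.

class AvgMeter:
    def __init__(self, name):
        # name is kept only for logging; the averaging state starts at zero.
        self.name = name
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, value, count):
        # Accumulate a (detached) batch loss weighted by the batch size.
        self.sum += float(value) * count
        self.count += count
        self.avg = self.sum / self.count
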
@@ -66,10 +75,12 @@ def validation_epoch(config, model, validation_loader):
         batch = batch_constructor(config, batch)
 
         with torch.no_grad():
             output, score = model(batch)
-            loss = calculate_loss(model, score, batch['label'])
+            loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
 
         count = batch["id"].size(0)
         loss_meter.update(loss.detach(), count)
+        c_loss_meter.update(c_loss.detach(), count)
+        s_loss_meter.update(s_loss.detach(), count)
 
         # tqdm_object.set_postfix(validation_loss=loss_meter.avg)
@@ -79,7 +90,9 @@ def validation_epoch(config, model, validation_loader):
         target = batch['label'].detach()
         targets.append(target)
 
-    return loss_meter, targets, predictions
+    losses = (loss_meter, s_loss_meter, c_loss_meter)
+
+    return losses, targets, predictions
 
 
 def supervised_train(config, train_loader, validation_loader, trial=None):
@@ -131,13 +144,13 @@ def supervised_train(config, train_loader, validation_loader, trial=None):
         train_accuracy = multiclass_acc(train_truth, train_pred)
         validation_accuracy = multiclass_acc(validation_truth, validation_pred)
         print_lr(optimizer)
-        print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy)
+        print('Training Loss:', train_loss[0], 'Training Accuracy:', train_accuracy)
         print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy)
 
         if lr_scheduler:
-            lr_scheduler.step(validation_loss.avg)
+            lr_scheduler.step(validation_loss[0].avg)
         if early_stopping:
-            early_stopping(validation_loss.avg, model)
+            early_stopping(validation_loss[0].avg, model)
             if early_stopping.early_stop:
                 print("Early stopping")
                 break
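
Note: callers that used to receive a single meter from train_epoch/validation_epoch now get a tuple ordered (loss_meter, s_loss_meter, c_loss_meter), which is why supervised_train indexes validation_loss[0] above. A hedged usage sketch for any other call site, using the names already present in this patch:

# losses is the (loss_meter, s_loss_meter, c_loss_meter) tuple returned above.
losses, targets, predictions = validation_epoch(config, model, validation_loader)
total_meter, s_loss_meter, c_loss_meter = losses
lr_scheduler.step(total_meter.avg)  # same value as validation_loss[0].avg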