|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- import numpy as np
- import torch
- import os
-
class AvgMeter:
    """Keep a running, count-weighted average of a scalar metric.

    After each ``update`` the instance exposes:

    * ``sum``   -- the accumulated ``val * count`` total,
    * ``count`` -- the accumulated ``count`` total,
    * ``avg``   -- ``sum / count`` (``0`` while ``count`` is zero).
    """

    def __init__(self, name="Metric"):
        """Create an empty meter labelled ``name`` (shown by ``__repr__``)."""
        self.name = name
        self.reset()

    def reset(self):
        """Zero out the accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold ``val`` into the average with weight ``count``.

        Args:
            val: Value of the metric for this step (e.g. a batch mean).
            count: Weight of ``val``, typically the number of samples
                it was computed over.
        """
        self.count += count
        self.sum += val * count
        # A zero total count (e.g. an empty batch reported before any real
        # samples) would otherwise raise ZeroDivisionError; keep avg at 0.
        self.avg = self.sum / self.count if self.count else 0

    def __repr__(self):
        """Return ``"<name>: <avg>"`` with the average shown to 4 decimals."""
        return f"{self.name}: {self.avg:.4f}"
-
-
def print_lr(optimizer):
    """Print one line per parameter group: its label and learning rate.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            that each contain an ``'lr'`` entry.

    A group's ``'name'`` entry is used as its label when present. Groups
    without one are labelled ``group_<index>`` instead of raising
    ``KeyError``, since ``'name'`` is only there if the caller added it.
    """
    for index, param_group in enumerate(optimizer.param_groups):
        label = param_group.get('name', f'group_{index}')
        print(label, param_group['lr'])
-
-
class CheckpointSaving:
    """Save a model's ``state_dict`` whenever validation accuracy improves.

    Call the instance once per validation pass with the current accuracy
    and the model. The first call always saves; later calls save only when
    the accuracy is strictly greater than the best one seen so far.
    """

    def __init__(self, path='checkpoint.pt', verbose=True, trace_func=print):
        """
        Args:
            path: Destination handed to ``torch.save``.
            verbose: If true, report every save through ``trace_func``.
            trace_func: Callable that receives the log message.
        """
        self.best_score = None
        self.val_acc_max = 0
        self.path = path
        self.verbose = verbose
        self.trace_func = trace_func

    def __call__(self, val_acc, model):
        """Save ``model`` if ``val_acc`` beats the best accuracy so far."""
        if self._is_improvement(val_acc):
            self.best_score = val_acc
            self.save_checkpoint(val_acc, model)

    def _is_improvement(self, val_acc):
        """Return whether ``val_acc`` should replace the stored best score."""
        best = self.best_score
        if best is None:
            return True
        if best != best:
            # The stored best is NaN (only NaN is unequal to itself). Every
            # ``>`` comparison against NaN is False, so without this branch
            # no later checkpoint would ever be saved. Any real value wins.
            return val_acc == val_acc
        return val_acc > best

    def save_checkpoint(self, val_acc, model):
        """Write ``model.state_dict()`` to ``self.path`` and record ``val_acc``."""
        if self.verbose:
            self.trace_func(
                f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}).  Model saved ...')
        torch.save(model.state_dict(), self.path)
        self.val_acc_max = val_acc
-
-
class EarlyStopping:
    """Signal when validation loss has stopped improving.

    Call the instance once per validation pass. ``early_stop`` becomes
    ``True`` after ``patience`` consecutive calls without an improvement
    of at least ``delta``. This class only tracks the loss; it does not
    save checkpoints.
    """

    def __init__(self, patience=10, verbose=False, delta=0.000001, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience: Number of consecutive non-improving calls tolerated
                before ``early_stop`` is set.
            verbose: If true, report every improvement through ``trace_func``.
            delta: Minimum decrease in loss that counts as an improvement.
            path: Stored on the instance but not used by this class.
            trace_func: Callable that receives the log messages.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.Inf`` was removed in NumPy 2.0; the builtin float is equivalent.
        self.val_loss_min = float("inf")
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``.

        Args:
            val_loss: Validation loss of the current pass (lower is better).
            model: Accepted for interface compatibility; not used.
        """
        # Work with the negated loss so that "higher is better".
        score = -val_loss

        # Only NaN is unequal to itself. A NaN loss must never count as an
        # improvement: once stored as best_score every later comparison
        # would be False and the patience counter would reset forever.
        is_nan = score != score
        improved = not is_nan and (
            self.best_score is None
            or not (score < self.best_score + self.delta)
        )

        if improved:
            if self.verbose:
                self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
            self.best_score = score
            self.val_loss_min = val_loss
            self.counter = 0
            return

        self.counter += 1
        # Reported regardless of ``verbose``.
        self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True
-
-
def make_directory(path):
    """Ensure the directory ``path`` exists and print the outcome.

    Missing parent directories are created as well, and a directory that
    already exists counts as success rather than being reported as a
    failure. Any ``OSError`` (e.g. permission denied, or ``path`` exists
    but is a regular file) is reported on stdout and not re-raised.

    Args:
        path: Directory to create.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        print("Creation of the directory failed")
    else:
        print("Successfully created the directory")
|