Adapted to the MovieLens dataset

torch_utils.py

  1. """
  2. Utility functions for torch.
  3. """
  4. import torch
  5. from torch import nn, optim
  6. from torch.optim.optimizer import Optimizer
  7. ### class
class MyAdagrad(Optimizer):
    """My modification of the Adagrad optimizer that allows specifying an initial
    accumulator value. This mimics the behavior of the default Adagrad implementation
    in TensorFlow. The default PyTorch Adagrad uses 0 for the initial accumulator value.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        lr_decay (float, optional): learning rate decay (default: 0)
        init_accu_value (float, optional): initial accumulator value (default: 0.1)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """
    def __init__(self, params, lr=1e-2, lr_decay=0, init_accu_value=0.1, weight_decay=0):
        defaults = dict(lr=lr, lr_decay=lr_decay, init_accu_value=init_accu_value,
                        weight_decay=weight_decay)
        super(MyAdagrad, self).__init__(params, defaults)

        # initialize the squared-gradient accumulator to init_accu_value instead of 0
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['sum'] = torch.ones(p.data.size()).type_as(p.data) * init_accu_value

    def share_memory(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['sum'].share_memory_()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                state['step'] += 1

                if group['weight_decay'] != 0:
                    if p.grad.data.is_sparse:
                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                    grad = grad.add(group['weight_decay'], p.data)

                clr = group['lr'] / (1 + (state['step'] - 1) * group['lr_decay'])

                if p.grad.data.is_sparse:
                    grad = grad.coalesce()  # the update is non-linear so indices must be unique
                    grad_indices = grad._indices()
                    grad_values = grad._values()
                    size = torch.Size([x for x in grad.size()])

                    def make_sparse(values):
                        constructor = type(p.grad.data)
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor()
                        return constructor(grad_indices, values, size)

                    state['sum'].add_(make_sparse(grad_values.pow(2)))
                    std = state['sum']._sparse_mask(grad)
                    std_values = std._values().sqrt_().add_(1e-10)
                    p.data.add_(-clr, make_sparse(grad_values / std_values))
                else:
                    state['sum'].addcmul_(1, grad, grad)
                    std = state['sum'].sqrt().add_(1e-10)
                    p.data.addcdiv_(-clr, grad, std)

        return loss
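
# Usage sketch (illustrative, not part of the original file): MyAdagrad is a
# drop-in replacement for torch.optim.Adagrad whose accumulator starts at
# init_accu_value instead of 0. The model below is a stand-in:
#
#   model = nn.Linear(10, 2)
#   optimizer = MyAdagrad(model.parameters(), lr=0.01, init_accu_value=0.1)
#   loss = model(torch.randn(4, 10)).sum()
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()
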
### torch specific functions
def get_optimizer(name, parameters, lr, l2=0):
    if name == 'sgd':
        return torch.optim.SGD(parameters, lr=lr, weight_decay=l2)
    elif name in ['adagrad', 'myadagrad']:
        # use my own adagrad to allow for init accumulator value
        return MyAdagrad(parameters, lr=lr, init_accu_value=0.1, weight_decay=l2)
    elif name == 'adam':
        return torch.optim.Adam(parameters, weight_decay=l2)  # use default lr
    elif name == 'adamax':
        return torch.optim.Adamax(parameters, weight_decay=l2)  # use default lr
    elif name == 'adadelta':
        return torch.optim.Adadelta(parameters, lr=lr, weight_decay=l2)
    else:
        raise Exception("Unsupported optimizer: {}".format(name))
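
# Example (illustrative): get_optimizer('sgd', model.parameters(), lr=0.1, l2=1e-5)
# returns a torch.optim.SGD instance. Note that for 'adam' and 'adamax' the lr
# argument is ignored and PyTorch's default learning rate is used.
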
def change_lr(optimizer, new_lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

def flatten_indices(seq_lens, width):
    flat = []
    for i, l in enumerate(seq_lens):
        for j in range(l):
            flat.append(i * width + j)
    return flat

def set_cuda(var, cuda):
    if cuda:
        return var.cuda()
    return var

def keep_partial_grad(grad, topk):
    """
    Keep only the topk rows of grads.
    """
    assert topk < grad.size(0)
    grad.data[topk:].zero_()
    return grad
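
# Sketch of intended use (names below are assumptions, not from this file):
# keep_partial_grad can be registered as a gradient hook so that only the first
# `topk` rows of a parameter (e.g. an embedding matrix) are updated:
#
#   emb = nn.Embedding(vocab_size, dim)
#   emb.weight.register_hook(lambda g: keep_partial_grad(g, topk))
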
### model IO
def save(model, optimizer, opt, filename):
    params = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'config': opt
    }
    try:
        torch.save(params, filename)
    except BaseException:
        print("[ Warning: model saving failed. ]")
def load(model, optimizer, filename):
    try:
        dump = torch.load(filename)
    except BaseException:
        print("[ Fail: model loading failed. ]")
        raise  # re-raise so we do not continue with an undefined checkpoint
    if model is not None:
        model.load_state_dict(dump['model'])
    if optimizer is not None:
        optimizer.load_state_dict(dump['optimizer'])
    opt = dump['config']
    return model, optimizer, opt

def load_config(filename):
    try:
        dump = torch.load(filename)
    except BaseException:
        print("[ Fail: model loading failed. ]")
        raise
    return dump['config']
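
Putting the helpers together, a minimal training-step sketch (the model, data, and
`opt` dict below are placeholders for illustration, not defined in this repository):

    import torch
    from torch import nn
    import torch_utils

    model = nn.Linear(10, 2)
    opt = {'optim': 'sgd', 'lr': 0.01}
    optimizer = torch_utils.get_optimizer(opt['optim'], model.parameters(), opt['lr'])

    x, y = torch.randn(32, 10), torch.randn(32, 2)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    torch_utils.change_lr(optimizer, opt['lr'] * 0.9)         # decay the learning rate
    torch_utils.save(model, optimizer, opt, 'checkpoint.pt')  # weights + optimizer + config
    model, optimizer, opt = torch_utils.load(model, optimizer, 'checkpoint.pt')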