
train.py

'''
This script handles the training process.
'''
import argparse
import math
import time

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

import transformer.Constants as Constants
from transformer.Models import Decoder
from transformer.Optim import ScheduledOptim
from DataLoader_Feat import DataLoader

# CUDA device index used by the .cuda() calls below.
CUDA = 1
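
# Example invocation (flag names come from the argparse setup in main() below;
# the values shown are just the defaults, adjust as needed):
#   python train.py -epoch 50 -batch_size 4 -save_model Lastfm_test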

def get_performance(crit, pred, gold, smoothing=False, num_class=None):
    ''' Apply label smoothing if needed '''
    # TODO: label smoothing is not implemented yet; see the sketch below.
    if smoothing:
        assert bool(num_class)
        eps = 0.1
        gold = gold * (1 - eps) + (1 - gold) * eps / num_class
        raise NotImplementedError

    loss = crit(pred, gold.contiguous().view(-1))

    # Token-level accuracy, ignoring PAD positions.
    pred = pred.max(1)[1]
    gold = gold.contiguous().view(-1)
    n_correct = pred.data.eq(gold.data)
    n_correct = n_correct.masked_select(gold.ne(Constants.PAD).data).sum()

    return loss, n_correct
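
# A minimal sketch of the label smoothing the TODO above refers to (not wired
# into the training path; the helper name is ours, and `eps` and the
# `eps / num_class` split mirror the values in get_performance): build a
# smoothed one-hot target, then take the negative log-likelihood against
# log-probabilities, masking out PAD positions.
def label_smoothed_loss(pred, gold, num_class, eps=0.1):
    gold = gold.contiguous().view(-1)
    one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / num_class
    log_prb = nn.functional.log_softmax(pred, dim=1)
    loss = -(one_hot * log_prb).sum(dim=1)  # per-position cross entropy
    return loss.masked_select(gold.ne(Constants.PAD)).sum()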

def train_epoch(model, training_data, crit, optimizer, opt):
    ''' Epoch operation in training phase '''
    model.train()

    total_loss = 0
    n_total_words = 0
    n_total_correct = 0

    for batch in tqdm(
            training_data, mininterval=2,
            desc=' - (Training) ', leave=False):

        # prepare data
        tgt = batch
        if opt.cuda:
            tgt = tgt.cuda(CUDA)
        gold = tgt[:, 1:]
        # Initial hidden state; the size is dictated by the Decoder implementation.
        h0 = torch.zeros(1, batch.size(0), 8 * opt.d_word_vec)
        if opt.cuda:
            h0 = h0.cuda(CUDA)

        # forward
        optimizer.zero_grad()
        pred, *_ = model(tgt, h0)

        # backward
        loss, n_correct = get_performance(crit, pred, gold)
        loss.backward()

        # update parameters
        optimizer.step()
        optimizer.update_learning_rate()

        # note keeping
        n_words = gold.data.ne(Constants.PAD).sum()
        n_total_words += n_words.item()
        n_total_correct += n_correct.item()
        total_loss += loss.item()

    return total_loss / n_total_words, n_total_correct / n_total_words
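
# Note: train_epoch returns (summed cross entropy / #non-PAD tokens, token
# accuracy). Per-word perplexity, if wanted, is math.exp(train_loss); the
# otherwise-unused `math` import above would serve that purpose.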

def train(model, training_data, crit, optimizer, opt):
    ''' Start training '''
    for epoch_i in range(opt.epoch):
        print('[ Epoch', epoch_i, ']')

        start = time.time()
        train_loss, train_accu = train_epoch(model, training_data, crit, optimizer, opt)
        print(' - (Training) accuracy: {accu:3.3f} %, '
              'elapse: {elapse:3.3f} min'.format(
                  accu=100 * train_accu,
                  elapse=(time.time() - start) / 60))

        model_state_dict = model.state_dict()
        checkpoint = {
            'model': model_state_dict,
            'settings': opt,
            'epoch': epoch_i}

        # Only the final epoch's checkpoint is written to disk.
        if epoch_i == (opt.epoch - 1):
            model_name = 'pa-our-twitter' + opt.save_model + str(epoch_i) + '.chkpt'
            torch.save(checkpoint, model_name)
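
# A checkpoint saved above can be restored later along these lines (a sketch;
# the Decoder keyword arguments mirror the construction in main() below):
#   checkpoint = torch.load(model_name)
#   opt = checkpoint['settings']
#   model = Decoder(opt.user_size, d_k=opt.d_k, d_v=opt.d_v,
#                   d_model=opt.d_model, d_word_vec=opt.d_word_vec,
#                   d_inner_hid=opt.d_inner_hid, n_head=opt.n_head,
#                   kernel_size=opt.window_size, dropout=opt.dropout)
#   model.load_state_dict(checkpoint['model'])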

def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-epoch', type=int, default=50)
    parser.add_argument('-batch_size', type=int, default=4)
    parser.add_argument('-d_model', type=int, default=1)
    parser.add_argument('-d_inner_hid', type=int, default=64)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)
    parser.add_argument('-window_size', type=int, default=3)
    parser.add_argument('-finit', type=int, default=0)
    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_warmup_steps', type=int, default=1000)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')
    parser.add_argument('-save_model', default='Lastfm_test')
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda  # honour the -no_cuda flag
    opt.d_word_vec = opt.d_model
    if opt.cuda:
        torch.cuda.set_device(0)

    # ========= Preparing DataLoader =========#
    train_data = DataLoader(use_valid=False, load_dict=True, batch_size=opt.batch_size, cuda=opt.cuda)
    opt.user_size = train_data.user_size

    # ========= Preparing Model =========#
    decoder = Decoder(
        opt.user_size,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_word_vec=opt.d_word_vec,
        d_inner_hid=opt.d_inner_hid,
        n_head=opt.n_head,
        kernel_size=opt.window_size,
        dropout=opt.dropout)
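
    # ScheduledOptim is assumed to implement the Transformer "Noam" schedule:
    #   lr = d_model^(-0.5) * min(step^(-0.5), step * n_warmup_steps^(-1.5)),
    # i.e. linear warmup for n_warmup_steps updates, then inverse-sqrt decay;
    # update_learning_rate() in train_epoch advances this schedule each step.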
    optimizer = ScheduledOptim(
        optim.Adam(
            decoder.parameters(),
            betas=(0.9, 0.98), eps=1e-09),
        opt.d_model, opt.n_warmup_steps)

    def get_criterion(user_size):
        ''' Cross entropy with zero weight on the PAD and EOS tokens '''
        weight = torch.ones(user_size)
        weight[Constants.PAD] = 0
        weight[Constants.EOS] = 0
        # reduction='sum': the total loss is normalised per word in train_epoch.
        return nn.CrossEntropyLoss(weight, reduction='sum')

    crit = get_criterion(train_data.user_size)

    if opt.cuda:
        decoder = decoder.cuda(CUDA)
        crit = crit.cuda(CUDA)

    train(decoder, train_data, crit, optimizer, opt)


if __name__ == '__main__':
    main()