You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train_new.py 6.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. '''
  2. This script handling the training process.
  3. '''
  4. import argparse
  5. import math
  6. import time
  7. from tqdm import tqdm
  8. import torch
  9. import torch.nn as nn
  10. import torch.optim as optim
  11. import transformer.Constants as Constants
  12. from transformer.MyModel import EncoderDecoder , Encoder , Decoder , Generator ,BahdanauAttention
  13. from transformer.Optim import ScheduledOptim
  14. from DataLoader_Feat import DataLoader
  15. import numpy as np
  16. import pickle
# Path of the pickled index -> embedding-vector lookup loaded in train().
# NOTE(review): machine-specific absolute path — consider promoting this to a
# command-line option so the script runs outside this one environment.
idx2vec_addr= '/home/rafie/NeuralDiffusionModel-master/data/weibo2/idx2vec.pickle'
  18. def get_performance(crit, pred, gold,opt, smoothing=False, num_class=None):
  19. ''' Apply label smoothing if needed '''
  20. # TODO: Add smoothing
  21. if smoothing:
  22. assert bool(num_class)
  23. eps = 0.1
  24. gold = gold * (1 - eps) + (1 - gold) * eps / num_class
  25. raise NotImplementedError
  26. # print(pred.shape,gold.shape)
  27. loss = crit(pred.view(-1, pred.size(-1)), gold.contiguous().view(-1))
  28. pred = pred.view(-1,opt.user_size).max(1)[1]
  29. gold = gold.contiguous().view(-1)
  30. # print('preddddddddddd\n',pred,'golddddddddd\n',gold)
  31. n_correct = pred.view(-1).data.eq(gold.data)
  32. # print(n_correct)
  33. n_correct = n_correct.masked_select(gold.ne(Constants.PAD).data).sum()
  34. # print('cooooooorrrecttttt',n_correct)
  35. return loss, n_correct
  36. def train_epoch(model, training_data, crit, optimizer, opt,idx2vec):
  37. ''' Epoch operation in training phase'''
  38. model.train()
  39. total_loss = 0
  40. n_total_words = 0
  41. n_total_correct = 0
  42. for (src,src_mask,src_lengths,trg,trg_y,trg_mask,trg_lengths) in tqdm(
  43. training_data, mininterval=2,
  44. desc=' - (Training) ', leave=False):
  45. # forward
  46. optimizer.zero_grad()
  47. # print('*****************************************\n',src,src_mask,src_lengths,trg,trg_y,trg_mask,trg_lengths)
  48. # src = torch.FloatTensor([[ idx2vec[int(idx.data.cpu().numpy())] for idx in src[i]] for i in range(src.size(0))]).cuda(0)
  49. _,_,pred = model(src, trg, src_mask, trg_mask, src_lengths, trg_lengths)
  50. pred = model.generator(pred)
  51. # backward
  52. gold = trg_y
  53. loss, n_correct = get_performance(crit, pred, gold,opt)
  54. loss.backward()
  55. # update parameters
  56. optimizer.step()
  57. # optimizer.update_learning_rate()
  58. # note keeping
  59. n_words = gold.data.ne(Constants.PAD).sum()
  60. n_total_words += n_words
  61. n_total_correct += n_correct
  62. total_loss += loss.item()
  63. # total_loss = total_loss.data.cpu().numpy()
  64. n_total_correct = n_total_correct.data.cpu().numpy()
  65. n_total_words = n_total_words.data.cpu().numpy()
  66. return total_loss/n_total_words, n_total_correct/n_total_words
  67. def train(model, training_data, crit, optimizer, opt):
  68. ''' Start training '''
  69. with open(idx2vec_addr, 'rb') as handle:
  70. idx2vec = pickle.load(handle)
  71. for epoch_i in range(opt.epoch):
  72. print('[ Epoch', epoch_i, ']')
  73. start = time.time()
  74. train_loss, train_accu = train_epoch(model, training_data, crit, optimizer , opt,idx2vec)
  75. print(' - (Training) accuracy: {accu:3.3f} %, '\
  76. 'elapse: {elapse:3.3f} min'.format(
  77. accu=100*train_accu,
  78. elapse=(time.time()-start)/60))
  79. model_state_dict = model.state_dict()
  80. checkpoint = {
  81. 'model': model_state_dict,
  82. 'settings': opt,
  83. 'epoch': epoch_i}
  84. if opt.save_model and epoch_i%10==0:
  85. model_name = opt.save_model + str(epoch_i)+'.chkpt'
  86. torch.save(checkpoint, model_name)
  87. def main():
  88. ''' Main function'''
  89. # torch.ma
  90. parser = argparse.ArgumentParser()
  91. parser.add_argument('-epoch', type=int, default=20)
  92. parser.add_argument('-batch_size', type=int, default=32)
  93. parser.add_argument('-d_model', type=int, default=64)
  94. parser.add_argument('-d_inner_hid', type=int, default=64)
  95. parser.add_argument('-d_k', type=int, default=64)
  96. parser.add_argument('-d_v', type=int, default=64)
  97. parser.add_argument('-window_size',type=int, default=3)
  98. parser.add_argument('-finit',type=int, default=0)
  99. parser.add_argument('-n_head', type=int, default=8)
  100. parser.add_argument('-n_warmup_steps', type=int, default=1000)
  101. parser.add_argument('-dropout', type=float, default=0.5)
  102. parser.add_argument('-embs_share_weight', action='store_true')
  103. parser.add_argument('-proj_share_weight', action='store_true')
  104. parser.add_argument('-save_model', default='Lastfm_test')
  105. parser.add_argument('-no_cuda', action='store_true')
  106. opt = parser.parse_args()
  107. opt.cuda = not opt.no_cuda
  108. opt.d_word_vec = opt.d_model
  109. #========= Preparing DataLoader =========#
  110. train_data = DataLoader(use_valid=False, load_dict=False, batch_size=opt.batch_size, cuda=opt.cuda)
  111. opt.user_size = train_data.user_size
  112. #========= Preparing Model =========#
  113. # decoder = Decoder(
  114. # opt.user_size,
  115. # d_k=opt.d_k,
  116. # d_v=opt.d_v,
  117. # d_model=opt.d_model,
  118. # d_word_vec=opt.d_word_vec,
  119. # d_inner_hid=opt.d_inner_hid,
  120. # n_head=opt.n_head,
  121. # kernel_size=opt.window_size,
  122. # dropout=opt.dropout)
  123. attention = BahdanauAttention(opt.d_inner_hid)
  124. model = EncoderDecoder(
  125. Encoder(opt.d_word_vec, opt.d_inner_hid, num_layers=1, dropout=opt.dropout),
  126. Decoder(opt.d_word_vec, opt.d_inner_hid, attention, num_layers=1, dropout=opt.dropout),
  127. nn.Embedding(opt.user_size, opt.d_word_vec),
  128. nn.Embedding(opt.user_size, opt.d_word_vec),
  129. Generator(opt.d_inner_hid, opt.user_size))
  130. optimizer = ScheduledOptim(
  131. optim.Adam(
  132. model.parameters(),
  133. betas=(0.9, 0.98), eps=1e-09),
  134. opt.d_model, opt.n_warmup_steps)
  135. # optimizer = optim.Adam(model.parameters(),0.01 ,weight_decay=0.1)
  136. def get_criterion(user_size):
  137. ''' With PAD token zero weight '''
  138. weight = torch.ones(user_size)
  139. weight[Constants.PAD] = 0
  140. weight[Constants.EOS] = 0
  141. return nn.CrossEntropyLoss(weight, size_average=False)
  142. crit = get_criterion(train_data.user_size)
  143. if opt.cuda:
  144. model = model.cuda(0)
  145. crit = crit.cuda(0)
  146. train(model, train_data, crit, optimizer, opt)
# Script entry point: run training when invoked directly.
if __name__ == '__main__':
    main()