Adapted to the MovieLens dataset.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

TaNP_training.py 1.5KB

123456789101112131415161718192021222324252627282930313233343536
  1. import os
  2. import torch
  3. import pickle
  4. import random
  5. from eval import testing
  6. def training(trainer, opt, train_dataset, test_dataset, batch_size, num_epoch, model_save=True, model_filename=None, logger=None):
  7. training_set_size = len(train_dataset)
  8. for epoch in range(num_epoch):
  9. random.shuffle(train_dataset)
  10. num_batch = int(training_set_size / batch_size)
  11. a, b, c, d = zip(*train_dataset)
  12. trainer.train()
  13. all_C_distribs = []
  14. for i in range(num_batch):
  15. try:
  16. supp_xs = list(a[batch_size*i:batch_size*(i+1)])
  17. supp_ys = list(b[batch_size*i:batch_size*(i+1)])
  18. query_xs = list(c[batch_size*i:batch_size*(i+1)])
  19. query_ys = list(d[batch_size*i:batch_size*(i+1)])
  20. except IndexError:
  21. continue
  22. train_loss, batch_C_distribs = trainer.global_update(supp_xs, supp_ys, query_xs, query_ys)
  23. all_C_distribs.append(batch_C_distribs)
  24. P5, NDCG5, MAP5, P7, NDCG7, MAP7, P10, NDCG10, MAP10 = testing(trainer, opt, test_dataset)
  25. logger.log(
  26. "{}\t{:.6f}\t TOP-5 {:.4f}\t{:.4f}\t{:.4f}\t TOP-7: {:.4f}\t{:.4f}\t{:.4f}"
  27. "\t TOP-10: {:.4f}\t{:.4f}\t{:.4f}".
  28. format(epoch, train_loss, P5, NDCG5, MAP5, P7, NDCG7, MAP7, P10, NDCG10, MAP10))
  29. if epoch == (num_epoch-1):
  30. with open('output_att', 'wb') as fp:
  31. pickle.dump(all_C_distribs, fp)
  32. if model_save:
  33. torch.save(trainer.state_dict(), model_filename)