Meta-learning approach for solving the cold-start problem

model_training.py 1.1KB

import os
import torch
import pickle
import random

from MeLU import MeLU
from options import config, states


def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_filename=None):
    # Meta-train MeLU; each element of total_dataset is one task
    # given as a (support_x, support_y, query_x, query_y) tuple.
    if config['use_cuda']:
        melu.cuda()

    print("mode: " + str(config['use_cuda']))
    training_set_size = len(total_dataset)
    melu.train()
    for epoch in range(num_epoch):
        # Reshuffle the tasks at the start of every epoch.
        random.shuffle(total_dataset)
        num_batch = int(training_set_size / batch_size)
        a, b, c, d = zip(*total_dataset)
        for i in range(num_batch):
            print("training - epoch:{} - batch:{}".format(epoch, i))
            try:
                # Slice out one batch of support/query sets.
                supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
                supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
                query_xs = list(c[batch_size * i:batch_size * (i + 1)])
                query_ys = list(d[batch_size * i:batch_size * (i + 1)])
            except IndexError:
                continue
            # One meta-update over the batch; config['inner'] sets the
            # number of inner (local) adaptation steps.
            melu.global_update(supp_xs, supp_ys, query_xs, query_ys, config['inner'])
            del supp_xs, supp_ys, query_xs, query_ys

    if model_save:
        torch.save(melu.state_dict(), model_filename)
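
For reference, a minimal sketch of how training() might be invoked, assuming total_dataset was pickled as a list of (support_x, support_y, query_x, query_y) task tuples matching the zip(*total_dataset) unpacking above. The pickle path, the MeLU(config) constructor call, the batch_size/num_epoch config keys, and the output filename are illustrative assumptions, not taken from this file:

import pickle

from MeLU import MeLU
from model_training import training
from options import config

if __name__ == "__main__":
    # Hypothetical path: tasks serialized as a list of
    # (support_x, support_y, query_x, query_y) tuples.
    with open("training_dataset.pkl", "rb") as f:
        total_dataset = pickle.load(f)

    # Assumed constructor signature; the real MeLU class may differ.
    melu = MeLU(config)

    training(
        melu,
        total_dataset,
        batch_size=config.get('batch_size', 32),  # assumed config key
        num_epoch=config.get('num_epoch', 20),    # assumed config key
        model_save=True,
        model_filename="models.pkl",              # hypothetical output path
    )

Because training() shuffles and slices total_dataset in place each epoch, a plain Python list of task tuples is the expected container here.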