Meta-learning approach for solving the cold-start problem

model_training.py

import os
import torch
import pickle
import random

from MeLU import MeLU
from options import config, states


def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_filename=None):
    if config['use_cuda']:
        melu.cuda()

    training_set_size = len(total_dataset)
    melu.train()
    for _ in range(num_epoch):
        # Reshuffle the task list every epoch.
        random.shuffle(total_dataset)
        num_batch = int(training_set_size / batch_size)
        # Unzip the tasks into support inputs/labels and query inputs/labels.
        a, b, c, d = zip(*total_dataset)
        for i in range(num_batch):
            try:
                # Slice out the current meta-batch of tasks.
                supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
                supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
                query_xs = list(c[batch_size * i:batch_size * (i + 1)])
                query_ys = list(d[batch_size * i:batch_size * (i + 1)])
            except IndexError:
                continue
            # Meta-update on this batch of tasks with config['inner'] local update steps.
            melu.global_update(supp_xs, supp_ys, query_xs, query_ys, config['inner'])

    if model_save:
        torch.save(melu.state_dict(), model_filename)
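
For context, here is a hedged sketch of how training() might be driven. Everything below other than the imports from this repository is an assumption made for illustration: the pickled per-task file layout, the models.pkl name, and the 'batch_size'/'num_epoch' keys read from config are not defined by this file.

# Hypothetical driver: file layout and config keys are assumptions, not repo code.
import glob
import pickle

from MeLU import MeLU
from options import config
from model_training import training

melu = MeLU(config)

# Each task is a (support_x, support_y, query_x, query_y) tuple of tensors,
# matching the structure that training() unpacks with zip(*total_dataset).
total_dataset = []
for path in glob.glob("data/tasks/task_*.pkl"):  # assumed location of prepared tasks
    with open(path, "rb") as f:
        total_dataset.append(pickle.load(f))

training(
    melu,
    total_dataset,
    batch_size=config['batch_size'],  # assumed config entries
    num_epoch=config['num_epoch'],
    model_save=True,
    model_filename="models.pkl",
)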
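
The per-batch call to melu.global_update is where the meta-learning happens: local (inner-loop) adaptation on each support set, followed by a global (outer-loop) step evaluated on the query sets. Below is a minimal, self-contained MAML-style sketch of that pattern on a toy linear model; it is not the MeLU implementation, and the regression loss, learning rates, and tensor shapes are placeholders.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy model standing in for the recommender; MeLU's actual network differs.
model = nn.Linear(10, 1)
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
local_lr = 1e-2  # inner-loop step size (placeholder value)

def global_update_sketch(supp_xs, supp_ys, query_xs, query_ys, inner_steps):
    meta_loss = 0.0
    for sx, sy, qx, qy in zip(supp_xs, supp_ys, query_xs, query_ys):
        # Inner loop: adapt a per-task copy of the weights on the support set.
        fast = {name: p for name, p in model.named_parameters()}
        for _ in range(inner_steps):
            loss = F.mse_loss(F.linear(sx, fast['weight'], fast['bias']), sy)
            grads = torch.autograd.grad(loss, list(fast.values()), create_graph=True)
            fast = {name: w - local_lr * g
                    for (name, w), g in zip(fast.items(), grads)}
        # Outer loop: score the adapted weights on the query set.
        meta_loss = meta_loss + F.mse_loss(
            F.linear(qx, fast['weight'], fast['bias']), qy)
    # One global step on the accumulated query losses.
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()

# Example call with random placeholder tasks.
tasks = 4
supp_xs = [torch.rand(5, 10) for _ in range(tasks)]
supp_ys = [torch.rand(5, 1) for _ in range(tasks)]
query_xs = [torch.rand(8, 10) for _ in range(tasks)]
query_ys = [torch.rand(8, 1) for _ in range(tasks)]
global_update_sketch(supp_xs, supp_ys, query_xs, query_ys, inner_steps=1)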