MeLU project implemented with l2l, using MetaSGD instead of MAML.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

evidence_candidate.py 2.7KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import os
  2. import torch
  3. import pickle
  4. from MeLU import MeLU
  5. from options import config
  6. def selection(melu, master_path, topk):
  7. if not os.path.exists("{}/scores/".format(master_path)):
  8. os.mkdir("{}/scores/".format(master_path))
  9. if config['use_cuda']:
  10. melu.cuda()
  11. melu.eval()
  12. target_state = 'warm_state'
  13. dataset_size = int(len(os.listdir("{}/{}".format(master_path, target_state))) / 4)
  14. grad_norms = {}
  15. for j in list(range(dataset_size)):
  16. support_xs = pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, target_state, j), "rb"))
  17. support_ys = pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, target_state, j), "rb"))
  18. item_ids = []
  19. with open("{}/log/{}/supp_x_{}_u_m_ids.txt".format(master_path, target_state, j), "r") as f:
  20. for line in f.readlines():
  21. item_id = line.strip().split()[1]
  22. item_ids.append(item_id)
  23. for support_x, support_y, item_id in zip(support_xs, support_ys, item_ids):
  24. support_x = support_x.view(1, -1)
  25. support_y = support_y.view(1, -1)
  26. norm = melu.get_weight_avg_norm(support_x, support_y, config['inner'])
  27. try:
  28. grad_norms[item_id]['discriminative_value'] += norm.item()
  29. grad_norms[item_id]['popularity_value'] += 1
  30. except:
  31. grad_norms[item_id] = {
  32. 'discriminative_value': norm.item(),
  33. 'popularity_value': 1
  34. }
  35. d_value_max = 0
  36. p_value_max = 0
  37. for item_id in grad_norms.keys():
  38. grad_norms[item_id]['discriminative_value'] /= grad_norms[item_id]['popularity_value']
  39. if grad_norms[item_id]['discriminative_value'] > d_value_max:
  40. d_value_max = grad_norms[item_id]['discriminative_value']
  41. if grad_norms[item_id]['popularity_value'] > p_value_max:
  42. p_value_max = grad_norms[item_id]['popularity_value']
  43. for item_id in grad_norms.keys():
  44. grad_norms[item_id]['discriminative_value'] /= float(d_value_max)
  45. grad_norms[item_id]['popularity_value'] /= float(p_value_max)
  46. grad_norms[item_id]['final_score'] = grad_norms[item_id]['discriminative_value'] * grad_norms[item_id]['popularity_value']
  47. movie_info = {}
  48. with open("./movielens/ml-1m/movies_extrainfos.dat", encoding="utf-8") as f:
  49. for line in f.readlines():
  50. tmp = line.strip().split("::")
  51. movie_info[tmp[0]] = "{} ({})".format(tmp[1], tmp[2])
  52. evidence_candidates = []
  53. for item_id, value in list(sorted(grad_norms.items(), key=lambda x: x[1]['final_score'], reverse=True))[:topk]:
  54. evidence_candidates.append((movie_info[item_id], value['final_score']))
  55. return evidence_candidates