|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- import os
- import torch
- import pickle
-
- from MeLU import MeLU
- from options import config
-
-
def _accumulate_grad_norms(melu, master_path, target_state, inner):
    """Sum per-item gradient norms and occurrence counts over all support sets.

    For every task ``j`` under ``<master_path>/<target_state>`` this loads the
    pickled support inputs/targets and the matching id log, then calls
    ``melu.get_weight_avg_norm`` once per support example.

    Returns:
        dict mapping item id (str, 2nd column of the id log) to a dict with
        ``'discriminative_value'`` (sum of norms) and ``'popularity_value'``
        (number of support examples referencing the item).
    """
    state_dir = os.path.join(master_path, target_state)
    log_dir = os.path.join(master_path, "log", target_state)
    # Four files are stored per task in state_dir (presumably support/query
    # x/y pickles -- TODO confirm against the dataset generator).
    dataset_size = len(os.listdir(state_dir)) // 4

    grad_norms = {}
    for j in range(dataset_size):
        # NOTE(review): pickle can execute arbitrary code; only load trusted,
        # locally generated dataset files here.
        with open(os.path.join(state_dir, "supp_x_{}.pkl".format(j)), "rb") as f:
            support_xs = pickle.load(f)
        with open(os.path.join(state_dir, "supp_y_{}.pkl".format(j)), "rb") as f:
            support_ys = pickle.load(f)

        item_ids = []
        with open(os.path.join(log_dir, "supp_x_{}_u_m_ids.txt".format(j)), "r") as f:
            for line in f:
                fields = line.split()
                if fields:  # tolerate blank lines instead of raising IndexError
                    item_ids.append(fields[1])

        for support_x, support_y, item_id in zip(support_xs, support_ys, item_ids):
            norm = melu.get_weight_avg_norm(
                support_x.view(1, -1), support_y.view(1, -1), inner
            )
            entry = grad_norms.setdefault(
                item_id, {'discriminative_value': 0.0, 'popularity_value': 0}
            )
            entry['discriminative_value'] += norm.item()
            entry['popularity_value'] += 1
    return grad_norms


def _normalize_scores(grad_norms):
    """Convert accumulated sums into normalized scores and ``final_score`` (in place).

    The discriminative value becomes the mean norm per occurrence. Both it and
    the popularity count are then scaled by their maximum over all items, and
    ``final_score`` is their product.
    """
    for entry in grad_norms.values():
        entry['discriminative_value'] /= entry['popularity_value']

    d_value_max = max((e['discriminative_value'] for e in grad_norms.values()), default=0)
    p_value_max = max((e['popularity_value'] for e in grad_norms.values()), default=0)

    for entry in grad_norms.values():
        # Skip scaling when the maximum is not positive (e.g. every norm is 0),
        # which would otherwise raise ZeroDivisionError.
        if d_value_max > 0:
            entry['discriminative_value'] /= float(d_value_max)
        if p_value_max > 0:
            entry['popularity_value'] /= float(p_value_max)
        entry['final_score'] = entry['discriminative_value'] * entry['popularity_value']


def _load_movie_info(movie_info_path):
    """Map movie id -> ``"<field1> (<field2>)"`` from a ``::``-separated file.

    Field 1 is used as the title and field 2 (presumably the year -- confirm)
    goes in parentheses. Blank lines are skipped.
    """
    movie_info = {}
    with open(movie_info_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            tmp = line.split("::")
            movie_info[tmp[0]] = "{} ({})".format(tmp[1], tmp[2])
    return movie_info


def selection(melu, master_path, topk, *, target_state='warm_state',
              movie_info_path="./movielens/ml-1m/movies_extrainfos.dat"):
    """Return the ``topk`` items ranked by normalized gradient-norm x popularity.

    Side effects: ensures ``<master_path>/scores/`` exists, moves ``melu`` to
    CUDA when ``config['use_cuda']`` is set, and puts it in eval mode.

    Args:
        melu: model exposing ``get_weight_avg_norm(x, y, inner)``.
        master_path: dataset root containing ``<target_state>/`` and ``log/``.
        topk: maximum number of items to return.
        target_state: dataset split sub-directory to score (default
            ``'warm_state'``).
        movie_info_path: ``::``-separated file mapping item ids to titles.

    Returns:
        list of ``(title_string, final_score)`` tuples, highest score first.

    Raises:
        KeyError: if a scored item id is missing from ``movie_info_path``.
    """
    os.makedirs(os.path.join(master_path, "scores"), exist_ok=True)
    if config['use_cuda']:
        melu.cuda()
    melu.eval()

    grad_norms = _accumulate_grad_norms(melu, master_path, target_state, config['inner'])
    _normalize_scores(grad_norms)
    movie_info = _load_movie_info(movie_info_path)

    ranked = sorted(grad_norms.items(), key=lambda kv: kv[1]['final_score'], reverse=True)
    return [(movie_info[item_id], value['final_score']) for item_id, value in ranked[:topk]]
|