You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

ranking.py 877B

123456789101112131415161718192021222324252627282930
  1. import torch
  2. from models.base_model import ERModel
  3. def _rank(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  4. return (scores > scores[target]).sum() + 1
  5. def simple_ranking(model: ERModel, triples: torch.Tensor) -> torch.Tensor:
  6. h, r, t = triples[:, 0], triples[:, 1], triples[:, 2]
  7. E = model.num_entities
  8. ranks = []
  9. with torch.no_grad():
  10. for hi, ri, ti in zip(h, r, t):
  11. tails = torch.arange(E, device=triples.device)
  12. triples = torch.stack([hi.repeat(E), ri.repeat(E), tails], dim=1)
  13. scores = model(triples)
  14. ranks.append(_rank(scores, ti))
  15. return torch.tensor(ranks, device=triples.device, dtype=torch.float32)
  16. def mrr(r: torch.Tensor) -> torch.Tensor:
  17. return (1 / r).mean()
  18. def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor:
  19. return (r <= k).float().mean()