123456789101112131415161718192021222324252627282930 |
- import torch
-
- from models.base_model import ERModel
-
-
- def _rank(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
- return (scores > scores[target]).sum() + 1
-
-
def simple_ranking(model: ERModel, triples: torch.Tensor) -> torch.Tensor:
    """Compute the raw (unfiltered) tail-prediction rank of every triple.

    For each ``(h, r, t)`` row of *triples*, every entity id in
    ``range(model.num_entities)`` is scored as a candidate tail. The rank
    of the true tail ``t`` among those scores is then computed with
    ``_rank``. No other known-true triples are filtered out of the
    candidate set.

    Args:
        model: Scoring model. It is called on an ``(E, 3)`` tensor of
            candidate triples, where ``E == model.num_entities``.
            NOTE(review): assumed to return one score per row, with
            higher meaning better -- TODO confirm against ``ERModel``.
        triples: Tensor whose columns 0, 1 and 2 are the head, relation
            and tail ids of each query triple.

    Returns:
        A 1-D float32 tensor with one rank per row of *triples*, on the
        same device as *triples*. It is empty if *triples* has no rows.
    """
    heads, relations, tails = triples[:, 0], triples[:, 1], triples[:, 2]
    num_entities = model.num_entities
    device = triples.device
    # Every entity is a candidate tail. This tensor is identical for every
    # query, so build it once instead of once per loop iteration.
    candidate_tails = torch.arange(num_entities, device=device)
    ranks = []

    with torch.no_grad():
        for head, relation, tail in zip(heads, relations, tails):
            # Use a distinct name so the ``triples`` argument is not
            # shadowed inside the loop.
            candidates = torch.stack(
                [
                    head.repeat(num_entities),
                    relation.repeat(num_entities),
                    candidate_tails,
                ],
                dim=1,
            )
            scores = model(candidates)
            ranks.append(_rank(scores, tail))

    # torch.stack rejects an empty list, so handle zero queries explicitly.
    if not ranks:
        return torch.empty(0, device=device, dtype=torch.float32)
    return torch.stack(ranks).to(device=device, dtype=torch.float32)
-
-
- def mrr(r: torch.Tensor) -> torch.Tensor:
- return (1 / r).mean()
-
-
def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor:
    """Return Hits@k: the fraction of ranks in *r* that are at most *k*.

    The result is a float32 scalar tensor.
    """
    within_top_k = r.le(k)
    return within_top_k.to(torch.float32).mean()
|