import torch

from models.base_model import ERModel

# Accepted values for the ``ties`` argument of ``simple_ranking``.
_TIE_POLICIES = ("optimistic", "pessimistic", "realistic")


def _rank(
    scores: torch.Tensor, target: torch.Tensor, ties: str = "optimistic"
) -> torch.Tensor:
    """Return the 1-based rank of ``scores[target]`` among ``scores``.

    Higher scores are treated as better. ``ties`` controls how candidates
    scoring exactly the same as the target are handled:

    * ``"optimistic"``: ties are placed behind the target (rank = 1 + number
      of strictly better candidates). This is the original behaviour.
    * ``"pessimistic"``: ties are placed ahead of the target (rank = number
      of candidates scoring >= the target, target included).
    * ``"realistic"``: the mean of the optimistic and pessimistic ranks.

    Raises:
        ValueError: if ``ties`` is not one of the policies above.
    """
    target_score = scores[target]
    optimistic = (scores > target_score).sum() + 1
    if ties == "optimistic":
        return optimistic
    # Count of candidates that are not worse than the target, target included.
    pessimistic = (scores >= target_score).sum()
    if ties == "pessimistic":
        return pessimistic
    if ties == "realistic":
        return (optimistic + pessimistic) / 2
    raise ValueError(f"ties must be one of {_TIE_POLICIES}, got {ties!r}")


def simple_ranking(
    model: ERModel, triples: torch.Tensor, *, ties: str = "optimistic"
) -> torch.Tensor:
    """Rank each true tail against every entity as a candidate tail.

    For each row ``(h, r, t)`` of ``triples``, the model scores ``(h, r, e)``
    for every entity id ``e`` in ``range(model.num_entities)``. The rank of
    ``t`` within those scores is then recorded. No filtering of other
    known-true triples is applied.

    Args:
        model: scoring model. Assumed (not checked here) to accept an
            ``(num_entities, 3)`` tensor of triples and return one score per
            row, higher meaning more plausible.
        triples: tensor whose columns 0, 1, 2 hold head, relation and tail ids.
        ties: tie-breaking policy passed to ``_rank``; the default
            ``"optimistic"`` reproduces the previous behaviour.

    Returns:
        A 1-D float32 tensor of ranks (one per input row) on
        ``triples.device``. The tensor is empty if ``triples`` has no rows.

    Raises:
        ValueError: if ``ties`` is not a supported policy.
    """
    if ties not in _TIE_POLICIES:
        raise ValueError(f"ties must be one of {_TIE_POLICIES}, got {ties!r}")

    device = triples.device
    num_entities = model.num_entities
    heads, relations, tails = triples[:, 0], triples[:, 1], triples[:, 2]
    # Every entity is a candidate tail; this is identical for all queries.
    candidate_tails = torch.arange(num_entities, device=device)

    ranks = []
    with torch.no_grad():
        for head, relation, tail in zip(heads, relations, tails):
            # Use a distinct name so the ``triples`` argument is not shadowed.
            candidates = torch.stack(
                [head.repeat(num_entities), relation.repeat(num_entities), candidate_tails],
                dim=1,
            )
            scores = model(candidates)
            ranks.append(_rank(scores, tail, ties))

    if not ranks:
        return torch.empty(0, device=device, dtype=torch.float32)
    # Stack the 0-d rank tensors on-device instead of converting each to a
    # Python scalar via torch.tensor(list_of_tensors).
    return torch.stack(ranks).to(device=device, dtype=torch.float32)


def mrr(r: torch.Tensor) -> torch.Tensor:
    """Return the mean reciprocal rank of the ranks in ``r``."""
    return (1 / r).mean()


def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor:
    """Return the fraction of ranks in ``r`` that are ``<= k``."""
    return (r <= k).float().mean()