from copy import deepcopy

import hydra
import lovely_tensors as lt
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, open_dict

from metrics.c_swklf import CSWKLF
from tools import get_pretty_logger

# (Import whatever Trainer or runner you have)
from training.trainer import Trainer

logger = get_pretty_logger(__name__)
lt.monkey_patch()


@hydra.main(config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    # Detect CUDA
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    common_params = cfg.common
    training_params = deepcopy(cfg.training)

    # Drop model_trainer from the copied cfg.training node; open_dict temporarily
    # lifts OmegaConf's struct flag so the deletion is allowed.
    with open_dict(training_params):
        del training_params.model_trainer

    common_params = instantiate(common_params)
    training_params = instantiate(training_params)

    # Dataset instantiation
    train_dataset = instantiate(cfg.data.train)
    val_dataset = instantiate(cfg.data.valid)

    model = instantiate(
        cfg.model,
        num_entities=train_dataset.num_entities,
        num_relations=train_dataset.num_relations,
        triples_factory=train_dataset.triples_factory,
        device=device,
    )
    model = model.to(device)

    model_trainer = instantiate(
        cfg.training.model_trainer,
        model=model,
        training_params=training_params,
        triples_factory=train_dataset.triples_factory,
        mapped_triples=train_dataset.triples,
        device=model.device,
    )

    # Instantiate the Trainer with the model_trainer
    trainer = Trainer(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        model_trainer=model_trainer,
        common_params=common_params,
        training_params=training_params,
        device=model.device,
    )

    if cfg.common.evaluate_only:
        trainer.evaluate()
        c_swklf_metric = CSWKLF(
            dataset=val_dataset,
            model=model,
        )
        c_swklf_value = c_swklf_metric.compute()
        logger.info(f"CSWKLF Metric Value: {c_swklf_value}")
    else:
        trainer.run()


if __name__ == "__main__":
    main()
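
# A hypothetical sketch of the configs/config.yaml layout this entry point assumes.
# Only the key structure (common.evaluate_only, training.model_trainer, data.train,
# data.valid, model) is taken from the script above; the _target_ paths and extra
# parameters shown here are illustrative placeholders, not the project's real config.
#
#   common:
#     evaluate_only: false
#   training:
#     num_epochs: 100                                   # hypothetical parameter
#     model_trainer:
#       _target_: training.model_trainer.ModelTrainer   # hypothetical path
#   data:
#     train:
#       _target_: data.dataset.KGDataset                # hypothetical path
#       split: train
#     valid:
#       _target_: data.dataset.KGDataset                # hypothetical path
#       split: valid
#   model:
#     _target_: models.my_model.MyModel                 # hypothetical path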