1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- from copy import deepcopy
-
- import hydra
- import lovely_tensors as lt
- import torch
- from hydra.utils import instantiate
- from omegaconf import DictConfig
-
- from metrics.c_swklf import CSWKLF
- from tools import get_pretty_logger
-
- # (Import whatever Trainer or runner you have)
- from training.trainer import Trainer
-
# Module-level logger built by the project's tools.get_pretty_logger helper.
logger = get_pretty_logger(__name__)

# Apply lovely_tensors' global patches to torch (presumably its compact
# tensor repr used when tensors are logged/printed — confirm if relied upon).
lt.monkey_patch()
-
-
# NOTE(review): no `version_base` is passed to @hydra.main; newer Hydra
# releases warn about this and fall back to legacy defaults. Left unchanged
# because setting it alters run-directory / working-directory behavior.
@hydra.main(config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Build datasets, model and trainers from the Hydra config, then train or evaluate.

    Args:
        cfg: The composed Hydra config. This function reads ``cfg.common``
            (including ``cfg.common.evaluate_only``), ``cfg.training``
            (including ``cfg.training.model_trainer``), ``cfg.data.train``,
            ``cfg.data.valid`` and ``cfg.model``.

    If ``cfg.common.evaluate_only`` is truthy, runs ``trainer.evaluate()`` and
    then logs the CSWKLF metric computed on the validation dataset; otherwise
    runs ``trainer.run()``.
    """
    # Function-scope import so this fix needs no change to the module imports.
    from omegaconf import open_dict

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `model_trainer` needs the model, so it is instantiated separately below;
    # strip it from the generic training params. A deep copy is used so that
    # cfg.training.model_trainer is still available later. Hydra sets the
    # struct flag on the config it passes in (the copy inherits it through its
    # parent), and a plain `del` on a struct DictConfig raises, so the node is
    # opened for the deletion.
    training_cfg = deepcopy(cfg.training)
    with open_dict(training_cfg):
        del training_cfg.model_trainer

    common_params = instantiate(cfg.common)
    training_params = instantiate(training_cfg)

    # Dataset instantiation.
    train_dataset = instantiate(cfg.data.train)
    val_dataset = instantiate(cfg.data.valid)

    model = instantiate(
        cfg.model,
        num_entities=train_dataset.num_entities,
        num_relations=train_dataset.num_relations,
        triples_factory=train_dataset.triples_factory,
        device=device,
    )
    # The constructor already receives `device`; the explicit move is kept in
    # case the model does not place its parameters itself.
    model = model.to(device)

    model_trainer = instantiate(
        cfg.training.model_trainer,
        model=model,
        training_params=training_params,
        triples_factory=train_dataset.triples_factory,
        mapped_triples=train_dataset.triples,
        device=model.device,
    )

    # The outer Trainer drives the model-specific model_trainer.
    trainer = Trainer(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        model_trainer=model_trainer,
        common_params=common_params,
        training_params=training_params,
        device=model.device,
    )

    if cfg.common.evaluate_only:
        trainer.evaluate()
        # NOTE(review): confirm CSWKLF.compute() handles eval mode / no_grad
        # itself; nothing here switches the model out of training mode.
        c_swklf_metric = CSWKLF(dataset=val_dataset, model=model)
        c_swklf_value = c_swklf_metric.compute()
        logger.info(f"CSWKLF Metric Value: {c_swklf_value}")
    else:
        trainer.run()
-
-
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on `main` composes
    # the config (config_name="config" under config_path="configs", plus any
    # command-line overrides) and passes it in as `cfg`.
    main()
|