You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 2.1KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from copy import deepcopy
  2. import hydra
  3. import lovely_tensors as lt
  4. import torch
  5. from hydra.utils import instantiate
  6. from omegaconf import DictConfig
  7. from metrics.c_swklf import CSWKLF
  8. from tools import get_pretty_logger
  9. # (Import whatever Trainer or runner you have)
  10. from training.trainer import Trainer
# Module-level logger; main() uses it to report the CSWKLF metric.
logger = get_pretty_logger(__name__)
# Import-time global side effect: lovely_tensors patches torch's tensor
# printing. This runs whenever this module is imported, not only when it is
# executed as a script.
lt.monkey_patch()
  13. @hydra.main(config_path="configs", config_name="config")
  14. def main(cfg: DictConfig) -> None:
  15. # Detect CUDA
  16. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  17. common_params = cfg.common
  18. training_params = deepcopy(cfg.training)
  19. # drop the model_trainer from cfg.trainer
  20. del training_params.model_trainer
  21. common_params = instantiate(common_params)
  22. training_params = instantiate(training_params)
  23. # dataset instantiation
  24. train_dataset = instantiate(cfg.data.train)
  25. val_dataset = instantiate(cfg.data.valid)
  26. model = instantiate(
  27. cfg.model,
  28. num_entities=train_dataset.num_entities,
  29. num_relations=train_dataset.num_relations,
  30. triples_factory=train_dataset.triples_factory,
  31. device=device,
  32. )
  33. model = model.to(device)
  34. model_trainer = instantiate(
  35. cfg.training.model_trainer,
  36. model=model,
  37. training_params=training_params,
  38. triples_factory=train_dataset.triples_factory,
  39. mapped_triples=train_dataset.triples,
  40. device=model.device,
  41. )
  42. # Instantiate the Trainer with the model_trainer
  43. trainer = Trainer(
  44. train_dataset=train_dataset,
  45. val_dataset=val_dataset,
  46. model_trainer=model_trainer,
  47. common_params=common_params,
  48. training_params=training_params,
  49. device=model.device,
  50. )
  51. if cfg.common.evaluate_only:
  52. trainer.evaluate()
  53. c_swklf_metric = CSWKLF(
  54. dataset=val_dataset,
  55. model=model,
  56. )
  57. c_swklf_value = c_swklf_metric.compute()
  58. logger.info(f"CSWKLF Metric Value: {c_swklf_value}")
  59. else:
  60. trainer.run()
# Script entry point. The @hydra.main decorator on main() composes the config
# and passes it in, so main() is called without arguments here.
if __name__ == "__main__":
    main()