You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

_utils.py 1.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  def prefix_dict_keys(prefix, input_dict):
      """Return a new dict whose keys are rewritten as ``f"{prefix}_{key}"``.

      Values are carried over unchanged and *input_dict* is not modified.
      If two original keys render to the same prefixed string, the one
      iterated last wins.
      """
      renamed = {}
      for original_key, value in input_dict.items():
          renamed[f"{prefix}_{original_key}"] = value
      return renamed
  def print_system_info():
      """Print the Python version and the versions of optional ML libraries.

      Each optional library is probed best-effort. If it cannot be imported,
      or its version cannot be read, a "not found" line is printed instead
      of raising. When torch is unavailable the function returns early, and
      transformers / Adapterhub are not probed.

      Returns:
          None. All output goes to stdout via ``print``.
      """
      from platform import python_version
      print(f"Python version is: {python_version()}")

      # `except Exception` rather than a bare `except:` keeps each probe
      # best-effort without swallowing KeyboardInterrupt / SystemExit.
      try:
          import sklearn
          sklearn_version = sklearn.__version__
      except Exception:
          print("Scikit-learn not found!!!")
      else:
          print(f"Scikit-learn version is: {sklearn_version}")

      try:
          import torch
          torch_version = torch.__version__
      except Exception:
          print("Torch not found!!!")
          return
      print(f"Torch version is: {torch_version}")

      # The CUDA query is kept outside the import guard so that a driver
      # problem is not misreported as "Torch not found". The old
      # `device_count() >= 0` check was always true; `> 0` states the intent.
      try:
          if torch.cuda.is_available() and torch.cuda.device_count() > 0:
              gpu_name = torch.cuda.get_device_name(0)
          else:
              gpu_name = None
      except Exception:
          gpu_name = None
      if gpu_name is None:
          print("Torch is using CPU")
      else:
          print(f"Nvidia device is: {gpu_name}")

      try:
          import transformers
          transformers_version = transformers.__version__
      except Exception:
          print("Transformers not found!!!")
          return
      print(f"Transformers version is: {transformers_version}")

      # The adapter-transformers fork exposes `transformers.adapters`; the
      # stock package does not, so an attribute/import failure means absent.
      try:
          adapters_version = transformers.adapters.__version__
      except Exception:
          print("Adapterhub not found!!!")
      else:
          print(f"Adapterhub version is: {adapters_version}")
  def silent_logs():
      """Quiet the W&B / transformers / datasets / accelerate output.

      The environment variables are exported first, before any of the
      libraries below is imported. After that:
      - the transformers logger is set to FATAL;
      - datasets progress bars are disabled and its verbosity set to error;
      - the logger in ``accelerate.utils.other`` is set to level 50.

      Requires ``transformers``, ``datasets`` and ``accelerate`` to be
      importable. Otherwise the import error propagates, by which time the
      environment variables have already been set.
      """
      import os
      quiet_env = {
          "WANDB_SILENT": "true",
          # "TRANSFORMERS_VERBOSITY": "fatal",  # deliberately left disabled
          "TRANSFORMERS_NO_ADVISORY_WARNINGS": "1",
          "ACCELERATE_LOG_LEVEL": "CRITICAL",
      }
      os.environ.update(quiet_env)

      import transformers
      from transformers.utils import logging as hf_logging
      hf_logging.set_verbosity(transformers.logging.FATAL)

      from datasets.utils.logging import disable_progress_bar, set_verbosity_error
      disable_progress_bar()
      set_verbosity_error()

      import accelerate.utils.other as accelerate_other
      # 50 is the numeric value of the standard library's logging.CRITICAL.
      accelerate_other.logger.setLevel(50)
  def sp_encode(data):
      """Serialize *data* to JSON and return it Base32-encoded, as ``bytes``.

      ``json.dumps`` escapes non-ASCII characters by default, so the
      intermediate text is pure ASCII before it is encoded to bytes.
      """
      import base64
      import json
      json_bytes = json.dumps(data).encode()
      encoded = base64.b32encode(json_bytes)
      return encoded
  def sp_decode(encoded_data):
      """Base32-decode *encoded_data* and parse the result as JSON.

      This is the inverse of ``sp_encode``. *encoded_data* may be ``bytes``
      or an ASCII ``str``, since ``base64.b32decode`` accepts either.
      """
      import base64
      import json
      decoded_bytes = base64.b32decode(encoded_data)
      json_text = decoded_bytes.decode()
      return json.loads(json_text)