Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fast_adapt.py 787B

123456789101112131415161718192021222324
  1. import torch
  2. import pickle
  3. def fast_adapt(
  4. learn,
  5. adaptation_data,
  6. evaluation_data,
  7. adaptation_labels,
  8. evaluation_labels,
  9. adaptation_steps,
  10. get_predictions=False):
  11. for step in range(adaptation_steps):
  12. temp = learn(adaptation_data, adaptation_labels, training=True)
  13. train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
  14. learn.adapt(train_error)
  15. predictions = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
  16. adaptation_labels=adaptation_labels)
  17. valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
  18. if get_predictions:
  19. return valid_error, predictions
  20. return valid_error