Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fast_adapt.py 798B

1234567891011121314151617181920212223242526
  1. import torch
  2. import pickle
  3. def fast_adapt(
  4. learn,
  5. adaptation_data,
  6. evaluation_data,
  7. adaptation_labels,
  8. evaluation_labels,
  9. adaptation_steps,
  10. get_predictions = False):
  11. for step in range(adaptation_steps):
  12. temp = learn(adaptation_data)
  13. train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
  14. learn.adapt(train_error)
  15. predictions = learn(evaluation_data)
  16. # loss = torch.nn.MSELoss(reduction='mean')
  17. # valid_error = loss(predictions, evaluation_labels)
  18. valid_error = torch.nn.functional.mse_loss(predictions.view(-1),evaluation_labels)
  19. if get_predictions:
  20. return valid_error,predictions
  21. return valid_error