import torch
import pickle


def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False):
    """Adapt ``learn`` on a support set, then return its MSE on a query set.

    For ``adaptation_steps`` iterations, ``learn`` is called on the
    adaptation data (``training=True``). The mean-squared error against
    ``adaptation_labels`` is passed to ``learn.adapt``. Afterwards the
    adapted learner is evaluated once on ``evaluation_data``
    (``training=False``, labels passed as ``None``). It also receives the
    adaptation set as keyword arguments, and the MSE against
    ``evaluation_labels`` is computed.

    NOTE(review): ``learn`` presumably is a learn2learn-style cloned learner
    whose ``adapt(loss)`` performs an inner-loop update; confirm against
    the caller.

    Args:
        learn: Callable learner exposing ``adapt(loss)``.
        adaptation_data: Inputs used for the adaptation (inner-loop) steps.
        evaluation_data: Inputs used for the final evaluation.
        adaptation_labels: Regression targets for ``adaptation_data``.
        evaluation_labels: Regression targets for ``evaluation_data``.
        adaptation_steps: Number of ``learn.adapt`` calls; 0 evaluates
            the learner unadapted.
        get_predictions: If True, also return the raw evaluation
            predictions.

    Returns:
        The evaluation MSE tensor, or ``(valid_error, predictions)`` when
        ``get_predictions`` is True. ``predictions`` is returned exactly as
        produced by ``learn`` (not flattened).
    """
    # Flatten both model output and targets before the loss. Flattening only
    # the output would let mse_loss broadcast (N,) against (N, 1) labels into
    # an (N, N) matrix and silently return a wrong loss. reshape (unlike view)
    # also tolerates non-contiguous tensors.
    flat_adaptation_labels = adaptation_labels.reshape(-1)
    for _ in range(adaptation_steps):
        train_output = learn(adaptation_data, adaptation_labels, training=True)
        train_error = torch.nn.functional.mse_loss(
            train_output.reshape(-1), flat_adaptation_labels)
        learn.adapt(train_error)

    predictions = learn(evaluation_data, None, training=False,
                        adaptation_data=adaptation_data,
                        adaptation_labels=adaptation_labels)
    valid_error = torch.nn.functional.mse_loss(
        predictions.reshape(-1), evaluation_labels.reshape(-1))
    if get_predictions:
        return valid_error, predictions
    return valid_error