import torch
import pickle


def fast_adapt(
    learn,
    adaptation_data,
    evaluation_data,
    adaptation_labels,
    evaluation_labels,
    adaptation_steps,
    get_predictions=False,
):
    """Adapt ``learn`` on an adaptation set, then score it on an evaluation set.

    Runs ``adaptation_steps`` inner-loop updates. Each step predicts on
    ``adaptation_data``, computes the mean-squared error against
    ``adaptation_labels`` and passes that loss to ``learn.adapt``. The
    adapted learner is then evaluated on ``evaluation_data``.

    NOTE(review): assumes ``learn`` is a callable model exposing an
    ``adapt(loss)`` method that updates its parameters (e.g. a learn2learn
    ``MAML`` clone) -- confirm against callers.

    Predictions and labels are both flattened before computing the loss.
    Without this, ``(N,)`` predictions compared with ``(N, 1)`` labels would
    broadcast to an ``(N, N)`` comparison and produce a wrong loss.

    Args:
        learn: Callable learner with an ``adapt(loss)`` method.
        adaptation_data: Inputs used for the inner-loop updates.
        evaluation_data: Inputs used to evaluate the adapted learner.
        adaptation_labels: Target tensor for ``adaptation_data``.
        evaluation_labels: Target tensor for ``evaluation_data``.
        adaptation_steps: Number of ``learn.adapt`` calls to perform.
        get_predictions: If true, also return the evaluation predictions.

    Returns:
        ``valid_error``, the mean-squared error on the evaluation set (a
        scalar tensor, since ``mse_loss`` defaults to ``reduction='mean'``).
        If ``get_predictions`` is true, returns ``(valid_error, predictions)``
        instead; ``predictions`` is the learner's raw, unflattened output.
    """
    for _ in range(adaptation_steps):
        adaptation_predictions = learn(adaptation_data)
        # reshape (not view) also works for non-contiguous outputs.
        train_error = torch.nn.functional.mse_loss(
            adaptation_predictions.reshape(-1), adaptation_labels.reshape(-1)
        )
        learn.adapt(train_error)

    predictions = learn(evaluation_data)
    valid_error = torch.nn.functional.mse_loss(
        predictions.reshape(-1), evaluation_labels.reshape(-1)
    )
    if get_predictions:
        return valid_error, predictions
    return valid_error