12345678910111213141516171819202122232425 |
- import torch
- import pickle
-
-
def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False):
    """Adapt ``learn`` on the adaptation set, then score it on the evaluation set.

    For each of ``adaptation_steps`` iterations the learner is run on
    ``adaptation_data``, the mean-squared error against ``adaptation_labels``
    is computed, and that loss is handed to ``learn.adapt``. Afterwards the
    learner is run once on ``evaluation_data`` and the mean-squared error
    against ``evaluation_labels`` is returned.

    Args:
        learn: Callable invoked as ``learn(data, labels, training=...)`` that
            also exposes ``adapt(loss)``.
            NOTE(review): presumably a cloned meta-learner whose ``adapt``
            updates it in place -- confirm against the caller.
        adaptation_data: Inputs passed to ``learn`` during adaptation.
        evaluation_data: Inputs passed to ``learn`` for the final evaluation.
        adaptation_labels: Regression targets for ``adaptation_data``; also
            forwarded to ``learn`` itself during adaptation. Only accessed
            when ``adaptation_steps`` is positive.
        evaluation_labels: Regression targets for ``evaluation_data``. They
            are not forwarded to ``learn`` (``None`` is passed instead).
        adaptation_steps: Number of adaptation iterations to run.
        get_predictions: When true, also return the evaluation predictions.

    Returns:
        The evaluation MSE tensor, or ``(valid_error, predictions)`` when
        ``get_predictions`` is true. ``predictions`` is returned exactly as
        produced by ``learn`` (not flattened).
    """
    mse_loss = torch.nn.functional.mse_loss

    for _ in range(adaptation_steps):
        outputs = learn(adaptation_data, adaptation_labels, training=True)
        # Flatten BOTH operands: comparing a flat (N,) output with (N, 1)
        # labels would make mse_loss broadcast to (N, N) and silently
        # return the wrong error.
        train_error = mse_loss(outputs.reshape(-1), adaptation_labels.reshape(-1))
        learn.adapt(train_error)

    predictions = learn(evaluation_data, None, training=False)
    # reshape (rather than view) also accepts non-contiguous tensors.
    valid_error = mse_loss(predictions.reshape(-1), evaluation_labels.reshape(-1))

    if get_predictions:
        return valid_error, predictions
    return valid_error
|