|
12345678910111213141516171819202122232425 |
- import torch
- import pickle
-
def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False):
    """Adapt ``learn`` on a support set, then score it on a query set with MSE.

    Runs ``adaptation_steps`` inner-loop updates. Each update computes the mean
    squared error of ``learn(adaptation_data)`` against ``adaptation_labels``
    and passes that loss to ``learn.adapt``. The adapted learner is then
    evaluated on ``evaluation_data``.

    Args:
        learn: Callable model exposing an ``adapt(loss)`` method. This is
            presumably a learn2learn-style clone (e.g. ``MAML.clone()``);
            verify against callers.
        adaptation_data: Input fed to ``learn`` during adaptation.
        evaluation_data: Input fed to ``learn`` for evaluation.
        adaptation_labels: Targets for ``adaptation_data``. Flattened before
            the loss, so ``(N,)`` and ``(N, 1)`` are treated alike.
        evaluation_labels: Targets for ``evaluation_data``. Flattened before
            the loss.
        adaptation_steps: Number of ``learn.adapt`` calls (0 skips adaptation).
        get_predictions: If True, also return the raw evaluation predictions.

    Returns:
        The evaluation MSE as a scalar tensor, or ``(valid_error, predictions)``
        when ``get_predictions`` is True. ``predictions`` keeps whatever shape
        ``learn`` produced.
    """
    for _ in range(adaptation_steps):
        adapt_preds = learn(adaptation_data)
        # Flatten BOTH sides. Flattening only the predictions lets a (N, 1)
        # label tensor broadcast against (N,) into an (N, N) loss.
        # reshape (unlike view) also works on non-contiguous outputs.
        train_error = torch.nn.functional.mse_loss(
            adapt_preds.reshape(-1), adaptation_labels.reshape(-1))
        learn.adapt(train_error)

    predictions = learn(evaluation_data)
    valid_error = torch.nn.functional.mse_loss(
        predictions.reshape(-1), evaluation_labels.reshape(-1))

    if get_predictions:
        return valid_error, predictions
    return valid_error
|