123456789101112131415161718192021222324 |
- import torch
- import pickle
-
-
def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False):
    """Adapt ``learn`` on an adaptation set, then score it on an evaluation set.

    Runs ``adaptation_steps`` inner-loop updates. Each step calls ``learn`` in
    training mode on the adaptation data/labels, computes the mean-squared
    error against ``adaptation_labels`` and passes that loss to
    ``learn.adapt``. Afterwards ``learn`` is called with ``training=False`` on
    ``evaluation_data`` (``None`` in place of labels, with the adaptation set
    forwarded as keyword arguments), and the MSE against ``evaluation_labels``
    is computed.

    NOTE(review): ``learn`` is presumably a cloned meta-learner wrapper whose
    ``__call__`` accepts the extra ``training`` / ``adaptation_*`` keywords --
    confirm against the caller.

    Args:
        learn: Callable model that also exposes ``adapt(loss)``.
        adaptation_data: Inputs used for the inner-loop updates.
        evaluation_data: Inputs on which the adapted model is scored.
        adaptation_labels: Regression targets for ``adaptation_data``.
        evaluation_labels: Regression targets for ``evaluation_data``.
        adaptation_steps: Number of ``learn.adapt`` calls; 0 skips adaptation.
        get_predictions: If true, also return the evaluation predictions.

    Returns:
        The evaluation MSE tensor, or ``(valid_error, predictions)`` when
        ``get_predictions`` is true. ``predictions`` is returned exactly as
        produced by ``learn`` (not flattened).
    """
    for _ in range(adaptation_steps):
        adapt_preds = learn(adaptation_data, adaptation_labels, training=True)
        # Flatten BOTH sides: flattening only the predictions lets an (N, 1)
        # target broadcast against an (N,) prediction into an (N, N) loss.
        train_error = torch.nn.functional.mse_loss(
            adapt_preds.reshape(-1), adaptation_labels.reshape(-1))
        learn.adapt(train_error)

    predictions = learn(evaluation_data, None, training=False,
                        adaptation_data=adaptation_data,
                        adaptation_labels=adaptation_labels)
    valid_error = torch.nn.functional.mse_loss(
        predictions.reshape(-1), evaluation_labels.reshape(-1))

    if get_predictions:
        return valid_error, predictions
    return valid_error
|