| @@ -48,7 +48,7 @@ head = torch.nn.Sequential(fc1,fc2,linear_out) | |||
| # META LEARNING | |||
| print("META LEARNING PHASE") | |||
| head = l2l.algorithms.MetaSGD(head, lr=config['local_lr']) | |||
| head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True) | |||
| head.cuda() | |||
| # Setup optimization | |||