@@ -48,7 +48,7 @@ head = torch.nn.Sequential(fc1,fc2,linear_out) | |||
# META LEARNING | |||
print("META LEARNING PHASE") | |||
head = l2l.algorithms.MetaSGD(head, lr=config['local_lr']) | |||
head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True) | |||
head.cuda() | |||
# Setup optimization |