BSc project of Parham Saremi. The goal of the project was to detect the geographical region of a dish using textual and visual features extracted from its recipe and ingredients.

sam.py (2.4 KB)

import torch


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
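The sketch below shows one way this SAM class is typically driven in a training loop, using the explicit first_step/second_step pattern: compute gradients at the current weights, perturb to the sharpness-maximizing point, recompute gradients there, then apply the base optimizer's update at the original weights. It is a minimal illustration, not part of this repository: the model, loss function, and batch used here are stand-in placeholders.

import torch

# Stand-in model, loss, and data; only the optimizer wiring matters here.
model = torch.nn.Linear(10, 2)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))

# First forward-backward pass: gradients at w, then move to "w + e(w)".
loss_fn(model(inputs), targets).backward()
optimizer.first_step(zero_grad=True)

# Second forward-backward pass at the perturbed weights, then the actual update at w.
loss_fn(model(inputs), targets).backward()
optimizer.second_step(zero_grad=True)

Alternatively, optimizer.step(closure) can be called with a closure that performs the full forward-backward pass, since step() re-evaluates the loss internally between the two steps.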