You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

single_prompt.py 4.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import torch
  2. import torch.nn as nn
  3. from .gumbal_switch import GumbalSwitch
  class SingleSuperSimplePrompt(nn.Module):
      """Prompt module whose output is one trainable tensor.

      The tensor is registered as the parameter ``sadcl_prompt`` and starts
      out as a detached copy of the tensor given to the constructor.
      """

      def __init__(self, pretrained_emb):
          """Register a detached copy of ``pretrained_emb`` as the prompt."""
          super().__init__()
          initial_value = pretrained_emb.detach().clone()
          self.sadcl_prompt = nn.parameter.Parameter(initial_value)

      def forward(self):
          """Return the prompt parameter as-is; no computation is applied."""
          return self.sadcl_prompt

      def use_pretrained_tokens(self, new_tokens):
          """Overwrite the prompt's data with the first row of ``new_tokens``.

          ``new_tokens`` must have exactly one row, and its second dimension
          must equal the first dimension of the current prompt. Both
          conditions are checked with ``assert``.
          """
          row_count = new_tokens.shape[0]
          assert row_count == 1
          prompt_width = self.sadcl_prompt.data.shape[0]
          assert new_tokens.shape[1] == prompt_width
          # Swap the underlying storage; the Parameter object itself is kept.
          replacement = new_tokens[0].detach().clone()
          self.sadcl_prompt.data = replacement
  class SingleSimplePrompt(nn.Module):
      """Prompt expressed as a stored embedding plus a learned offset.

      Two parameters of identical shape are registered: ``pretrained_emb``
      (a detached copy of the constructor argument) and ``sadcl_emb_diff``
      (initialised to zeros). The module's output is their sum.
      """

      def __init__(self, pretrained_emb):
          """Store a copy of ``pretrained_emb`` and a zero offset beside it."""
          super().__init__()
          base_copy = pretrained_emb.detach().clone()
          zero_offset = torch.zeros_like(pretrained_emb)
          self.pretrained_emb = nn.parameter.Parameter(base_copy)
          self.sadcl_emb_diff = nn.parameter.Parameter(zero_offset)

      def forward(self):
          """Return ``pretrained_emb + sadcl_emb_diff``."""
          shifted = self.pretrained_emb + self.sadcl_emb_diff
          return shifted
  class SingleResidualPrompt(nn.Module):
      """Prompt produced by passing (embedding + offset) through a residual MLP.

      Registered members:
        * ``pretrained_emb``  - detached copy of the constructor argument.
        * ``sadcl_emb_diff``  - offset of the same shape, initialised to zeros.
        * ``sadcl_mlp``       - Linear -> ReLU -> Linear -> LayerNorm, sized
          from ``pretrained_emb.size(0)`` and ``mlp_size``.
      """

      def __init__(self, pretrained_emb, mlp_size):
          """Build the parameters and the MLP.

          The MLP's input/output width is ``pretrained_emb.size(0)``, so the
          embedding is presumably 1-D - TODO confirm against callers.
          """
          super().__init__()
          emb_width = pretrained_emb.size(0)

          self.pretrained_emb = nn.parameter.Parameter(
              pretrained_emb.detach().clone()
          )
          self.sadcl_emb_diff = nn.parameter.Parameter(
              torch.zeros_like(pretrained_emb)
          )

          # Layers are created in network order so the sub-module indices
          # (0, 2, 3 carry weights) and the weight-initialisation order stay
          # fixed.
          layers = [
              nn.Linear(emb_width, mlp_size),
              nn.ReLU(),
              nn.Linear(mlp_size, emb_width),
              nn.LayerNorm(emb_width),
          ]
          self.sadcl_mlp = nn.Sequential(*layers)

      def forward(self):
          """Return ``x + sadcl_mlp(x)`` where ``x = pretrained_emb + sadcl_emb_diff``."""
          base = self.pretrained_emb + self.sadcl_emb_diff
          residual = self.sadcl_mlp(base)
          return base + residual
  class SingleCombPrompt(nn.Module):
      """Prompt built as a weighted combination of embedding rows.

      ``forward`` returns ``coeff @ mat`` where ``mat`` is
      ``pretrained_embs + sadcl_embs_diff`` and ``coeff`` is ``sadcl_coeff``
      (one coefficient per row). After ``use_pretrained_tokens`` has been
      called, extra rows (``pretrained_tokens``) and their coefficients
      (``sadcl_coeff_pretrained``) are appended to both before the product.

      ``pretrained_embs`` is expected to be 2-D: its ``shape[1]`` is compared
      against the extra tokens and its rows are combined by the 1-D
      coefficient vector.
      """

      def __init__(
          self,
          pretrained_embs,
          softmax=False,
          use_pretrained_mode='simple',
          tempreture=1.0,
      ):
          """Create the combination coefficients and embedding parameters.

          Args:
              pretrained_embs: tensor whose rows are combined; a detached copy
                  is stored as a parameter.
              softmax: if true, ``forward`` applies a softmax over the
                  coefficients (not usable together with pretrained tokens).
              use_pretrained_mode: one of ``'simple'``, ``'gumbal'``,
                  ``'softmax'``; selects how the coefficients of tokens added
                  by ``use_pretrained_tokens`` are produced.
              tempreture: divisor applied to the extra coefficients before the
                  softmax in ``'softmax'`` mode.
          """
          super().__init__()
          assert use_pretrained_mode in ['simple', 'gumbal', 'softmax']

          # Sample on CPU in float32 (as the legacy ``torch.FloatTensor``
          # constructor did, so seeded runs draw the same values), then align
          # with the embeddings so ``coeff @ mat`` works for any dtype/device.
          # TODO: maybe another init
          initial_coeff = torch.empty(
              pretrained_embs.size(0), dtype=torch.float32, device='cpu'
          ).uniform_(-0.5, 0.5)
          self.sadcl_coeff = nn.parameter.Parameter(
              initial_coeff.to(
                  device=pretrained_embs.device, dtype=pretrained_embs.dtype
              )
          )
          self.pretrained_embs = nn.parameter.Parameter(
              pretrained_embs.detach().clone()
          )
          self.sadcl_embs_diff = nn.parameter.Parameter(
              torch.zeros_like(pretrained_embs)
          )

          self.use_pretrained = False
          self.softmax = softmax
          self.use_pretrained_mode = use_pretrained_mode
          self.tempreture = tempreture

      def use_pretrained_tokens(self, new_tokens):
          """Append ``new_tokens`` as extra rows to be combined in ``forward``.

          ``new_tokens.shape[1]`` must equal ``pretrained_embs.shape[1]``
          (checked with ``assert``). A detached copy is stored as the
          parameter ``pretrained_tokens`` and one coefficient per new row is
          created according to ``use_pretrained_mode``.
          """
          assert new_tokens.shape[1] == self.pretrained_embs.data.shape[1]
          self.use_pretrained = True

          # These tensors are created after construction, possibly after the
          # module was moved or cast; align them with the tensors they are
          # concatenated with in ``forward``.
          self.pretrained_tokens = nn.parameter.Parameter(
              new_tokens.detach().clone().to(
                  device=self.pretrained_embs.device,
                  dtype=self.pretrained_embs.dtype,
              )
          )

          token_count = new_tokens.size(0)
          coeff_device = self.sadcl_coeff.device
          coeff_dtype = self.sadcl_coeff.dtype

          if self.use_pretrained_mode == 'simple':
              # Every extra token starts with weight 0.5.
              self.sadcl_coeff_pretrained = nn.parameter.Parameter(
                  torch.full(
                      (token_count,), 0.5,
                      device=coeff_device, dtype=coeff_dtype,
                  )
              )
          elif self.use_pretrained_mode == 'gumbal':
              # GumbalSwitch is defined in a sibling module; it is built
              # exactly as before and manages its own tensors.
              self.sadcl_coeff_pretrained = GumbalSwitch(new_tokens.shape[0])
          elif self.use_pretrained_mode == 'softmax':
              # Equal logits, so the softmax starts out uniform.
              self.sadcl_coeff_pretrained = nn.parameter.Parameter(
                  torch.full(
                      (token_count,), 1.,
                      device=coeff_device, dtype=coeff_dtype,
                  )
              )

      def get_pretrained_coeff(self):
          """Return the coefficients for the tokens added via ``use_pretrained_tokens``.

          Asserts that ``use_pretrained_tokens`` has been called. Raises
          ``ValueError`` if ``use_pretrained_mode`` holds an unknown value
          (previously this case silently returned ``None``).
          """
          assert self.use_pretrained
          mode = self.use_pretrained_mode
          if mode == 'simple':
              return self.sadcl_coeff_pretrained
          if mode == 'gumbal':
              return self.sadcl_coeff_pretrained()
          if mode == 'softmax':
              return torch.softmax(
                  self.sadcl_coeff_pretrained / self.tempreture, dim=0
              )
          raise ValueError(f'unknown use_pretrained_mode: {mode!r}')

      def forward(self):
          """Return the coefficient-weighted combination of all embedding rows."""
          if self.softmax:
              # Checked up front so no concatenation work is done first.
              assert (not self.use_pretrained), 'This feature is not compatible with use_pretrained'

          coeff = self.sadcl_coeff
          mat = self.pretrained_embs + self.sadcl_embs_diff

          if self.use_pretrained:
              coeff = torch.cat((coeff, self.get_pretrained_coeff()), dim=0)
              mat = torch.cat((mat, self.pretrained_tokens), dim=0)

          if self.softmax:
              coeff = torch.nn.functional.softmax(coeff, dim=0)

          return coeff @ mat