import torch
import torch.nn as nn

from .gumbal_switch import GumbalSwitch


class SingleSuperSimplePrompt(nn.Module):
    """A single soft-prompt token held directly as a trainable parameter."""

    def __init__(self, pretrained_emb):
        super().__init__()

        # Trainable copy of the provided embedding vector.
        self.sadcl_prompt = nn.parameter.Parameter(
            pretrained_emb.detach().clone()
        )

    def forward(self):
        return self.sadcl_prompt

    def use_pretrained_tokens(self, new_tokens):
        # Overwrite the prompt in place with a single pretrained token of matching size.
        assert new_tokens.shape[0] == 1
        assert new_tokens.shape[1] == self.sadcl_prompt.data.shape[0]
        self.sadcl_prompt.data = new_tokens[0].detach().clone()
class SingleSimplePrompt(nn.Module):
    """A single prompt token parameterized as a pretrained embedding plus a learned offset."""

    def __init__(self, pretrained_emb):
        super().__init__()

        self.pretrained_emb = nn.parameter.Parameter(
            pretrained_emb.detach().clone()
        )

        # Learned offset, initialized to zero so training starts from the pretrained token.
        self.sadcl_emb_diff = nn.parameter.Parameter(
            torch.zeros_like(pretrained_emb)
        )

    def forward(self):
        return self.pretrained_emb + self.sadcl_emb_diff
class SingleResidualPrompt(nn.Module):
    """A single prompt token with a residual bottleneck MLP on top of the embedding."""

    def __init__(self, pretrained_emb, mlp_size):
        super().__init__()

        self.pretrained_emb = nn.parameter.Parameter(
            pretrained_emb.detach().clone()
        )

        # Learned offset, initialized to zero.
        self.sadcl_emb_diff = nn.parameter.Parameter(
            torch.zeros_like(pretrained_emb)
        )

        # Bottleneck MLP applied as a residual branch in forward().
        self.sadcl_mlp = nn.Sequential(
            nn.Linear(pretrained_emb.size(0), mlp_size),
            nn.ReLU(),
            nn.Linear(mlp_size, pretrained_emb.size(0)),
            nn.LayerNorm(pretrained_emb.size(0))
        )

    def forward(self):
        input_prompt = self.pretrained_emb + self.sadcl_emb_diff
        return input_prompt + self.sadcl_mlp(input_prompt)
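# Minimal usage sketch (illustrative only; the sizes below are assumptions):
#
#   prompt = SingleResidualPrompt(torch.randn(768), mlp_size=128)
#   token = prompt()   # (768,): (emb + diff) + MLP(emb + diff), a residual reparameterization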
class SingleCombPrompt(nn.Module):
    """A single prompt token built as a learned linear combination of several pretrained embeddings."""

    def __init__(self, pretrained_embs, softmax=False, use_pretrained_mode='simple', tempreture=1.0):
        super().__init__()

        # One combination coefficient per pretrained embedding.
        self.sadcl_coeff = nn.parameter.Parameter(
            torch.FloatTensor(pretrained_embs.size(0)).uniform_(-0.5, 0.5)  # maybe another init
        )

        self.pretrained_embs = nn.parameter.Parameter(
            pretrained_embs.detach().clone()
        )

        # Learned per-embedding offsets, initialized to zero.
        self.sadcl_embs_diff = nn.parameter.Parameter(
            torch.zeros_like(pretrained_embs)
        )
        self.use_pretrained = False

        self.softmax = softmax

        assert use_pretrained_mode in ['simple', 'gumbal', 'softmax']
        self.use_pretrained_mode = use_pretrained_mode

        self.tempreture = tempreture

    def use_pretrained_tokens(self, new_tokens):
        # Register extra pretrained prompt tokens and the coefficients that gate them.
        assert new_tokens.shape[1] == self.pretrained_embs.data.shape[1]

        self.use_pretrained = True
        self.pretrained_tokens = nn.parameter.Parameter(
            new_tokens.detach().clone()
        )

        if self.use_pretrained_mode == 'simple':
            self.sadcl_coeff_pretrained = nn.parameter.Parameter(
                torch.full(size=(new_tokens.size(0),), fill_value=0.5)
            )
        elif self.use_pretrained_mode == 'gumbal':
            self.sadcl_coeff_pretrained = GumbalSwitch(new_tokens.shape[0])
        elif self.use_pretrained_mode == 'softmax':
            self.sadcl_coeff_pretrained = nn.parameter.Parameter(
                torch.full(size=(new_tokens.size(0),), fill_value=1.)
            )

    def get_pretrained_coeff(self):
        assert self.use_pretrained

        if self.use_pretrained_mode == 'simple':
            return self.sadcl_coeff_pretrained
        elif self.use_pretrained_mode == 'gumbal':
            return self.sadcl_coeff_pretrained()
        elif self.use_pretrained_mode == 'softmax':
            return torch.softmax(self.sadcl_coeff_pretrained / self.tempreture, dim=0)

    def forward(self):
        coeff = self.sadcl_coeff
        mat = self.pretrained_embs + self.sadcl_embs_diff

        if self.use_pretrained:
            # Append the pretrained tokens and their (gated) coefficients.
            coeff = torch.cat(
                (
                    coeff,
                    self.get_pretrained_coeff()
                ), dim=0
            )
            mat = torch.cat(
                (mat, self.pretrained_tokens), dim=0
            )

        if self.softmax:
            assert (not self.use_pretrained), 'This feature is not compatible with use_pretrained'
            coeff = torch.nn.functional.softmax(coeff, dim=0)

        # Weighted sum over embeddings: (n,) @ (n, emb_dim) -> (emb_dim,)
        return coeff @ mat
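# Minimal usage sketch (illustrative only; shapes, mode, and values are assumptions):
#
#   bank = torch.randn(10, 768)                          # 10 source prompt embeddings
#   prompt = SingleCombPrompt(bank, use_pretrained_mode='softmax', tempreture=0.5)
#   prompt.use_pretrained_tokens(torch.randn(4, 768))    # add 4 pretrained tokens
#   token = prompt()                                     # (768,): weighted mix of all 14 rows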