import torch
import torch.nn as nn

from .gumbal_switch import GumbalSwitch


class SingleSuperSimplePrompt(nn.Module):
    """A single prompt token learned directly, initialised from a pretrained embedding."""

    def __init__(self, pretrained_emb):
        super().__init__()
        self.sadcl_prompt = nn.parameter.Parameter(
            pretrained_emb.detach().clone()
        )

    def forward(self):
        return self.sadcl_prompt

    def use_pretrained_tokens(self, new_tokens):
        # Overwrite the prompt in place with a single pretrained token of shape (1, emb_dim).
        assert new_tokens.shape[0] == 1
        assert new_tokens.shape[1] == self.sadcl_prompt.data.shape[0]
        self.sadcl_prompt.data = new_tokens[0].detach().clone()


class SingleSimplePrompt(nn.Module):
    """Prompt token parameterised as (pretrained embedding + learned offset)."""

    def __init__(self, pretrained_emb):
        super().__init__()
        self.pretrained_emb = nn.parameter.Parameter(
            pretrained_emb.detach().clone()
        )
        self.sadcl_emb_diff = nn.parameter.Parameter(
            torch.zeros_like(pretrained_emb)
        )

    def forward(self):
        return self.pretrained_emb + self.sadcl_emb_diff


class SingleResidualPrompt(nn.Module):
    """Like SingleSimplePrompt, but the token is additionally refined by a residual
    bottleneck MLP (Linear -> ReLU -> Linear -> LayerNorm)."""

    def __init__(self, pretrained_emb, mlp_size):
        super().__init__()
        self.pretrained_emb = nn.parameter.Parameter(
            pretrained_emb.detach().clone()
        )
        self.sadcl_emb_diff = nn.parameter.Parameter(
            torch.zeros_like(pretrained_emb)
        )
        self.sadcl_mlp = nn.Sequential(
            nn.Linear(pretrained_emb.size(0), mlp_size),
            nn.ReLU(),
            nn.Linear(mlp_size, pretrained_emb.size(0)),
            nn.LayerNorm(pretrained_emb.size(0))
        )

    def forward(self):
        input_prompt = self.pretrained_emb + self.sadcl_emb_diff
        return input_prompt + self.sadcl_mlp(input_prompt)


class SingleCombPrompt(nn.Module):
    """Prompt token built as a learned linear combination of several source embeddings.

    Optionally, `use_pretrained_tokens` appends extra pretrained prompt tokens whose
    mixing coefficients are produced in one of three modes: 'simple' (free scalars),
    'gumbal' (a GumbalSwitch module), or 'softmax' (temperature-scaled softmax).
    """

    def __init__(self, pretrained_embs, softmax=False, use_pretrained_mode='simple', tempreture=1.0):
        super().__init__()
        self.sadcl_coeff = nn.parameter.Parameter(
            torch.FloatTensor(pretrained_embs.size(0)).uniform_(-0.5, 0.5)  # maybe another init
        )
        self.pretrained_embs = nn.parameter.Parameter(
            pretrained_embs.detach().clone()
        )
        self.sadcl_embs_diff = nn.parameter.Parameter(
            torch.zeros_like(pretrained_embs)
        )
        self.use_pretrained = False
        self.softmax = softmax
        assert use_pretrained_mode in ['simple', 'gumbal', 'softmax']
        self.use_pretrained_mode = use_pretrained_mode
        self.tempreture = tempreture

    def use_pretrained_tokens(self, new_tokens):
        # new_tokens: (num_tokens, emb_dim) extra prompt tokens to mix in.
        assert new_tokens.shape[1] == self.pretrained_embs.data.shape[1]
        self.use_pretrained = True
        self.pretrained_tokens = nn.parameter.Parameter(
            new_tokens.detach().clone()
        )
        if self.use_pretrained_mode == 'simple':
            self.sadcl_coeff_pretrained = nn.parameter.Parameter(
                torch.full(size=(new_tokens.size(0),), fill_value=0.5)
            )
        elif self.use_pretrained_mode == 'gumbal':
            self.sadcl_coeff_pretrained = GumbalSwitch(new_tokens.shape[0])
        elif self.use_pretrained_mode == 'softmax':
            self.sadcl_coeff_pretrained = nn.parameter.Parameter(
                torch.full(size=(new_tokens.size(0),), fill_value=1.)
            )

    def get_pretrained_coeff(self):
        assert self.use_pretrained
        if self.use_pretrained_mode == 'simple':
            return self.sadcl_coeff_pretrained
        elif self.use_pretrained_mode == 'gumbal':
            return self.sadcl_coeff_pretrained()
        elif self.use_pretrained_mode == 'softmax':
            return torch.softmax(self.sadcl_coeff_pretrained / self.tempreture, dim=0)

    def forward(self):
        coeff = self.sadcl_coeff
        mat = self.pretrained_embs + self.sadcl_embs_diff
        if self.use_pretrained:
            coeff = torch.cat(
                (coeff, self.get_pretrained_coeff()), dim=0
            )
            mat = torch.cat(
                (mat, self.pretrained_tokens), dim=0
            )
        if self.softmax:
            assert (not self.use_pretrained), 'This feature is not compatible with use_pretrained'
            coeff = torch.nn.functional.softmax(coeff, dim=0)
        # (num_embs,) @ (num_embs, emb_dim) -> (emb_dim,)
        return coeff @ mat
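

# Usage sketch: a minimal example of how these single-token prompt modules might be
# exercised in isolation. The embedding size, number of source prompts, and the random
# tensors standing in for real pretrained embeddings are illustrative assumptions, not
# values from this repository. Because of the relative import above, this file would
# need to be run as a module within its package (or with the import adjusted).
if __name__ == '__main__':
    emb_dim, n_source_prompts = 16, 4  # assumed sizes for the demo

    # Residual prompt: pretrained embedding + learned offset, refined by a bottleneck MLP.
    residual = SingleResidualPrompt(torch.randn(emb_dim), mlp_size=32)
    print(residual().shape)  # torch.Size([16])

    # Combined prompt: softmax-weighted mixture of several source embeddings.
    comb = SingleCombPrompt(torch.randn(n_source_prompts, emb_dim), softmax=True)
    print(comb().shape)  # torch.Size([16])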