from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn

from .attempt import Attempt
from .multi_prompt import MultiPrompt


def _prompts_joiner(prompts, input_embedding):
    """Splice soft-prompt embeddings over the leading positions of the input.

    The first ``n_tokens`` positions of ``input_embedding`` are replaced by the
    prompt embeddings, so the caller is expected to have reserved that many
    leading placeholder tokens in the tokenized input.

    Args:
        prompts: either a shared ``(n_tokens, emb_dim)`` prompt or per-example
            ``(batch_size, n_tokens, emb_dim)`` prompts.
        input_embedding: ``(batch_size, seq_len, emb_dim)`` token embeddings.

    Returns:
        ``(batch_size, seq_len, emb_dim)`` tensor with prompts in front.
    """
    batch_size = input_embedding.size(0)
    if prompts.dim() == 3:
        # Already batched (e.g. per-example prompts from Attempt).
        prompts_batched = prompts
    else:
        # Broadcast the shared prompt across the batch:
        # (batch_size, n_tokens, emb_dim)
        prompts_batched = prompts.repeat(batch_size, 1, 1)
    n_tokens = prompts_batched.size(1)
    return torch.cat([prompts_batched, input_embedding[:, n_tokens:]], dim=1)


class EmbeddingWrapper(nn.Module):
    """Wraps a frozen embedding layer and prepends learnable soft prompts.

    Depending on ``kwargs['kind']`` the soft prompts are managed by
    ``MultiPrompt`` ('combine', 'residual', 'simple', 'spot') or by
    ``Attempt`` ('attempt').
    """

    def __init__(
        self,
        emb_layer: nn.Embedding,
        n_tokens: int,
        n_comb_tokens: Optional[int] = None,
        # NOTE(review): misspelling of "random_init" kept for backward
        # compatibility with existing keyword callers.
        radnom_init: bool = False,
        pretrained_paths: Optional[List[str]] = None,
        pad_token_id: int = 0,  # todo!
        **kwargs
    ):
        super().__init__()
        self.emb_layer = emb_layer
        # 'kind' is a required kwarg; raises KeyError if absent.
        self.kind = kwargs['kind']
        self.pad_token_id = pad_token_id

        # How many embedding rows to draw for prompt initialization.
        if self.kind == 'combine':
            selected_tokens_size = (n_comb_tokens,)
        elif self.kind in ('residual', 'simple', 'spot', 'attempt'):
            selected_tokens_size = (n_tokens,)
        else:
            raise NotImplementedError()

        selected_embs = self._generate_embs(selected_tokens_size, radnom_init)
        pretrained = self._generate_pretrained(pretrained_paths)

        if self.kind in ('combine', 'residual', 'simple', 'spot'):
            self.soft_prompts = MultiPrompt(
                n_tokens=n_tokens,
                selected_embs=selected_embs,
                pretrained=pretrained,
                **kwargs
            )
        elif self.kind == 'attempt':
            self.soft_prompts = Attempt(
                selected_embs=selected_embs,
                pretrained=pretrained,
                **kwargs
            )
        else:
            raise NotImplementedError()

    def _generate_pretrained(self, pretrained_paths):
        """Load and stack previously-trained prompt embeddings.

        Each path must contain ``config.json`` and ``best.pt`` as produced by
        ``MultiPrompt``. Returns a stacked tensor, or ``None`` when no paths
        are given.
        """
        if pretrained_paths is None or len(pretrained_paths) == 0:
            return None
        pretrained = torch.stack([
            MultiPrompt.get_saved_final_emb(
                config_path=Path(path) / 'config.json',
                weights_path=Path(path) / 'best.pt'
            ) for path in pretrained_paths
        ], dim=0)
        return pretrained

    def _generate_embs(self, size, radnom_init):
        """Produce initial prompt embeddings of shape ``size + (emb_dim,)``.

        With ``radnom_init`` the values are drawn from a normal distribution
        matching the embedding table's statistics; otherwise distinct rows of
        the embedding table are sampled.
        """
        if radnom_init:
            size = size + (self.emb_layer.embedding_dim,)
            # Stay in torch (works on any device); hoist the detached view
            # instead of materializing it twice. unbiased=False matches the
            # population std (ddof=0) that numpy's .std() would compute.
            weights = self.emb_layer.weight.detach()
            mean = weights.mean().item()
            std_dev = weights.std(unbiased=False).item()
            return torch.empty(size).normal_(mean=mean, std=std_dev)
            # return torch.FloatTensor(*size).uniform_(-1, 1)
        else:
            # Sample *distinct* vocabulary ids, then embed them.
            selected_tokens = torch.from_numpy(
                np.random.choice(
                    self.emb_layer.num_embeddings,
                    size=size,
                    replace=False
                )
            )
            return self.emb_layer(selected_tokens)

    def forward(self, tokens):
        """Embed ``tokens`` and splice the soft prompts over the prefix."""
        input_embedding = self.emb_layer(tokens)
        if self.kind == 'attempt':
            # Attempt computes per-example prompts from the input itself;
            # padding positions are masked out.
            prompts = self.soft_prompts(
                x_inp=input_embedding,
                prompt_mask=(tokens == self.pad_token_id)
            )
        else:
            prompts = self.soft_prompts()
        return _prompts_joiner(prompts, input_embedding)

    def peft_state_dict(self):
        """Return only the trainable soft-prompt parameters."""
        return self.soft_prompts.state_dict()

    def peft_config(self):
        """Return the config(s) needed to reconstruct the soft prompts."""
        return self.soft_prompts._constructed_configs

    def load_peft(self, config, state_dict):
        """Rebuild the soft prompts from a saved config and weights.

        NOTE(review): this always reconstructs a ``MultiPrompt`` — for
        ``kind == 'attempt'`` this looks inconsistent with ``__init__``;
        confirm whether Attempt checkpoints are ever loaded through here.
        """
        self.soft_prompts = MultiPrompt.from_config(config)
        self.soft_prompts.load_state_dict(state_dict)