from pathlib import Path
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn

from .multi_prompt import MultiPrompt
from .attempt import Attempt
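
# EmbeddingWrapper swaps out a model's input embedding layer: the first n_tokens
# positions of every sequence are filled with trainable soft prompts produced by a
# MultiPrompt or Attempt module (selected via kwargs['kind']), while the remaining
# positions keep the embeddings of the wrapped (frozen) layer.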


def _prompts_joiner(prompts, input_embedding):
    """Replace the first n_tokens embeddings of every sequence with the prompts."""
    batch_size = input_embedding.size(0)
    if len(prompts.shape) == 3:
        prompts_batched = prompts  # already batched: (batch_size, n_tokens, emb_dim)
    else:
        prompts_batched = prompts.repeat(batch_size, 1, 1)  # (batch_size, n_tokens, emb_dim)
    n_tokens = prompts_batched.size(1)
    return torch.cat([prompts_batched, input_embedding[:, n_tokens:]], dim=1)
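# Shape walk-through with illustrative numbers (batch of 4, sequence length 20,
# n_tokens = 8; the concrete values are not from the module):
#   prompts:          (8, emb_dim)  or  (4, 8, emb_dim)
#   input_embedding:  (4, 20, emb_dim)
#   output:           (4, 20, emb_dim), positions 0..7 replaced by the prompts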


class EmbeddingWrapper(nn.Module):
    def __init__(
        self,
        emb_layer: nn.Embedding,
        n_tokens: int,
        n_comb_tokens: Optional[int] = None,
        random_init: bool = False,
        pretrained_paths: Optional[List[str]] = None,
        pad_token_id: int = 0,  # todo!
        **kwargs
    ):
        super().__init__()
        self.emb_layer = emb_layer
        self.kind = kwargs['kind']
        self.pad_token_id = pad_token_id

        # How many embeddings are drawn to initialize the prompt module.
        if self.kind == 'combine':
            selected_tokens_size = (n_comb_tokens,)
        elif self.kind in ['residual', 'simple', 'spot', 'attempt']:
            selected_tokens_size = (n_tokens,)
        else:
            raise NotImplementedError()

        selected_embs = self._generate_embs(selected_tokens_size, random_init)
        pretrained = self._generate_pretrained(pretrained_paths)

        if self.kind in ['combine', 'residual', 'simple', 'spot']:
            self.soft_prompts = MultiPrompt(
                n_tokens=n_tokens,
                selected_embs=selected_embs,
                pretrained=pretrained,
                **kwargs
            )
        elif self.kind == 'attempt':
            self.soft_prompts = Attempt(
                selected_embs=selected_embs,
                pretrained=pretrained,
                **kwargs
            )
        else:
            raise NotImplementedError()

    def _generate_pretrained(self, pretrained_paths):
        if pretrained_paths is None or len(pretrained_paths) == 0:
            return None

        # Load the final prompt embedding saved in each directory and stack them.
        pretrained = torch.stack([
            MultiPrompt.get_saved_final_emb(
                config_path=Path(path) / 'config.json',
                weights_path=Path(path) / 'best.pt'
            ) for path in pretrained_paths
        ], dim=0)
        return pretrained

    def _generate_embs(self, size, random_init):
        if random_init:
            # Draw new prompt embeddings from a normal distribution that matches
            # the statistics of the existing embedding matrix.
            size = size + (self.emb_layer.embedding_dim,)
            weights = self.emb_layer.weight.detach().cpu().numpy().ravel()
            return torch.FloatTensor(*size).normal_(mean=float(weights.mean()), std=float(weights.std()))
            # return torch.FloatTensor(*size).uniform_(-1, 1)
        else:
            # Initialize from the embeddings of randomly chosen vocabulary tokens.
            selected_tokens = torch.from_numpy(
                np.random.choice(
                    self.emb_layer.num_embeddings,
                    size=size,
                    replace=False
                )
            )
            return self.emb_layer(selected_tokens)

    def forward(self, tokens):
        input_embedding = self.emb_layer(tokens)
        if self.kind == 'attempt':
            # The 'attempt' prompt module conditions on the input itself, so it
            # receives the token embeddings together with a padding mask.
            prompts = self.soft_prompts(
                x_inp=input_embedding,
                prompt_mask=(tokens == self.pad_token_id)
            )
        else:
            prompts = self.soft_prompts()
        return _prompts_joiner(prompts, input_embedding)

    def peft_state_dict(self):
        return self.soft_prompts.state_dict()

    def peft_config(self):
        return self.soft_prompts._constructed_configs

    def load_peft(self, config, state_dict):
        self.soft_prompts = MultiPrompt.from_config(config)
        self.soft_prompts.load_state_dict(state_dict)
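

# Usage sketch (illustrative, not part of the module). The embedding sizes and the
# kind='simple' keyword are assumptions, and MultiPrompt (from .multi_prompt) may
# require further keyword arguments not shown here. Inside a full model, the wrapper
# would replace the input embedding layer (e.g. via set_input_embeddings on a
# HuggingFace model).
#
#   emb = nn.Embedding(num_embeddings=32128, embedding_dim=512)
#   wrapper = EmbeddingWrapper(emb_layer=emb, n_tokens=8, pad_token_id=0, kind='simple')
#
#   # The caller reserves the first n_tokens positions of every sequence
#   # (e.g. by prepending pad tokens); forward() fills them with the prompts.
#   tokens = torch.randint(low=1, high=32128, size=(2, 20))
#   tokens[:, :8] = 0
#   embeddings = wrapper(tokens)   # (2, 20, 512), positions 0..7 hold the soft prompts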