emb_wrapper.py
from pathlib import Path
from typing import Optional, List

import torch
import torch.nn as nn
import numpy as np

from .multi_prompt import MultiPrompt
from .attempt import Attempt


def _prompts_joiner(prompts, input_embedding):
    batch_size = input_embedding.size(0)
    if len(prompts.shape) == 3:
        prompts_batched = prompts
    else:
        prompts_batched = prompts.repeat(batch_size, 1, 1)  # (batch_size, n_tokens, emb_dim)
    n_tokens = prompts_batched.size(1)
    # Overwrite the first n_tokens positions of the input with the soft prompts.
    return torch.cat([prompts_batched, input_embedding[:, n_tokens:]], dim=1)


class EmbeddingWrapper(nn.Module):
    def __init__(
        self,
        emb_layer: nn.Embedding,
        n_tokens: int,
        n_comb_tokens: Optional[int] = None,
        random_init: bool = False,
        pretrained_paths: Optional[List[str]] = None,
        pad_token_id: int = 0,  # todo!
        **kwargs
    ):
        super().__init__()
        self.emb_layer = emb_layer
        self.kind = kwargs['kind']
        self.pad_token_id = pad_token_id

        if self.kind == 'combine':
            selected_tokens_size = (n_comb_tokens,)
        elif self.kind in ['residual', 'simple', 'spot', 'attempt']:
            selected_tokens_size = (n_tokens,)
        else:
            raise NotImplementedError()

        selected_embs = self._generate_embs(selected_tokens_size, random_init)
        pretrained = self._generate_pretrained(pretrained_paths)

        if self.kind in ['combine', 'residual', 'simple', 'spot']:
            self.soft_prompts = MultiPrompt(
                n_tokens=n_tokens,
                selected_embs=selected_embs,
                pretrained=pretrained,
                **kwargs
            )
        elif self.kind == 'attempt':
            self.soft_prompts = Attempt(
                selected_embs=selected_embs,
                pretrained=pretrained,
                **kwargs
            )
        else:
            raise NotImplementedError()

    def _generate_pretrained(self, pretrained_paths):
        if pretrained_paths is None or len(pretrained_paths) == 0:
            return None
        # Stack the final prompt embeddings saved by earlier runs.
        pretrained = torch.stack([
            MultiPrompt.get_saved_final_emb(
                config_path=Path(path) / 'config.json',
                weights_path=Path(path) / 'best.pt'
            ) for path in pretrained_paths
        ], dim=0)
        return pretrained

    def _generate_embs(self, size, random_init):
        if random_init:
            # Sample initial prompt embeddings from a normal distribution matching
            # the mean and std of the model's embedding matrix.
            size = size + (self.emb_layer.embedding_dim,)
            mean = self.emb_layer.weight.ravel().detach().numpy().mean()
            std_dev = self.emb_layer.weight.ravel().detach().numpy().std()
            return torch.FloatTensor(*size).normal_(mean=mean, std=std_dev)
            # return torch.FloatTensor(*size).uniform_(-1, 1)
        else:
            # Initialise from the embeddings of randomly chosen vocabulary tokens.
            selected_tokens = torch.from_numpy(
                np.random.choice(
                    self.emb_layer.num_embeddings,
                    size=size,
                    replace=False
                )
            )
            return self.emb_layer(selected_tokens)

    def forward(self, tokens):
        input_embedding = self.emb_layer(tokens)
        if self.kind == 'attempt':
            prompts = self.soft_prompts(
                x_inp=input_embedding,
                prompt_mask=(tokens == self.pad_token_id)
            )
        else:
            prompts = self.soft_prompts()
        return _prompts_joiner(prompts, input_embedding)

    def peft_state_dict(self):
        return self.soft_prompts.state_dict()

    def peft_config(self):
        return self.soft_prompts._constructed_configs

    def load_peft(self, config, state_dict):
        self.soft_prompts = MultiPrompt.from_config(config)
        self.soft_prompts.load_state_dict(state_dict)
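For context, the forward pass replaces the first n_tokens positions of the embedded input with the soft prompts, so callers are expected to reserve n_tokens placeholder ids at the start of each sequence. The sketch below is a minimal, self-contained illustration of that splicing pattern; it is not this module's API. A plain nn.Parameter stands in for MultiPrompt/Attempt, and the ToySoftPromptEmbedding class, vocabulary size, and sequence length are invented for the example.

# Minimal sketch of the prompt-splicing pattern used by EmbeddingWrapper.
# ToySoftPromptEmbedding is a hypothetical stand-in; all sizes are illustrative.
import torch
import torch.nn as nn


class ToySoftPromptEmbedding(nn.Module):
    def __init__(self, emb_layer: nn.Embedding, n_tokens: int):
        super().__init__()
        self.emb_layer = emb_layer
        self.n_tokens = n_tokens
        # Initialise prompts from random vocabulary embeddings, analogous to the
        # random_init=False branch of _generate_embs above.
        token_ids = torch.randperm(emb_layer.num_embeddings)[:n_tokens]
        self.soft_prompts = nn.Parameter(emb_layer(token_ids).detach().clone())

    def forward(self, tokens):
        input_embedding = self.emb_layer(tokens)          # (batch, seq_len, emb_dim)
        batch_size = input_embedding.size(0)
        prompts = self.soft_prompts.repeat(batch_size, 1, 1)
        # Overwrite the first n_tokens positions with the trainable prompts,
        # exactly like _prompts_joiner.
        return torch.cat([prompts, input_embedding[:, self.n_tokens:]], dim=1)


if __name__ == '__main__':
    vocab_size, emb_dim, n_tokens = 100, 16, 4
    emb = nn.Embedding(vocab_size, emb_dim)
    wrapper = ToySoftPromptEmbedding(emb, n_tokens)

    # Reserve n_tokens placeholder ids so the output keeps the input's length.
    batch = torch.randint(0, vocab_size, (2, 10))
    batch[:, :n_tokens] = 0  # placeholders that the prompts will replace
    out = wrapper(batch)
    assert out.shape == (2, 10, emb_dim)

    # In prompt tuning, only the prompt parameters are trained; the base
    # embedding stays frozen.
    emb.weight.requires_grad_(False)
    trainable = [p for p in wrapper.parameters() if p.requires_grad]
    assert len(trainable) == 1 and trainable[0] is wrapper.soft_prompts

The same pattern is what EmbeddingWrapper applies on top of its MultiPrompt or Attempt prompt generators, with peft_state_dict/load_peft handling persistence of the prompt parameters only.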