
multi_prompt.py

import json
from pathlib import Path
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn

from _trainer.loss_hooks import add_to_loss_hooks
from .single_prompt import SingleCombPrompt, SingleResidualPrompt, SingleSimplePrompt, SingleSuperSimplePrompt

class MultiPrompt(nn.Module):
    def __init__(self, n_tokens, selected_embs, kind: str, shared_weights: bool = False,
                 pretrained: Optional[torch.Tensor] = None, **kwargs):
        ####### kind in [simple, spot, residual]
        # selected_embs.shape == (n_tokens, emb_dim)
        # pretrained.shape == (1, n_tokens, emb_dim)
        ####### kind == combine
        # selected_embs.shape == (super_pos_m, emb_dim)
        # pretrained.shape == (pretrained_task_count, n_tokens, emb_dim)
        super().__init__()
        self._constructed_configs = {
            'n_tokens': n_tokens,
            'selected_embs.shape': selected_embs.shape,
            'kind': kind,
            'shared_weights': shared_weights,
            **kwargs
        }
        self.n_tokens = n_tokens
        self.emb_dim = selected_embs.size(1)

        # One Single*Prompt module per soft-prompt token; for 'combine' every
        # token receives the full selected_embs matrix instead of a single row.
        prompt_constructor = {
            'simple': lambda idx, selected_embs: SingleSimplePrompt(selected_embs[idx], **kwargs),
            'spot': lambda idx, selected_embs: SingleSuperSimplePrompt(selected_embs[idx], **kwargs),
            'residual': lambda idx, selected_embs: SingleResidualPrompt(selected_embs[idx], **kwargs),
            'combine': lambda _, selected_embs: SingleCombPrompt(selected_embs, **kwargs),
        }[kind]

        self.prompts = nn.ModuleList([
            prompt_constructor(idx, selected_embs) for idx in range(n_tokens)
        ])

        if shared_weights:
            # Tie the kind-specific parameters of every token to those of the first prompt.
            if kind == 'combine':
                for module in self.prompts:
                    module.sadcl_embs_diff = self.prompts[0].sadcl_embs_diff
            elif kind == 'residual':
                for module in self.prompts:
                    module.sadcl_mlp = self.prompts[0].sadcl_mlp
            else:
                raise NotImplementedError()

        if pretrained is not None:
            self._constructed_configs['pretrained.shape'] = pretrained.shape
            assert pretrained.shape[1:] == (self.n_tokens, self.emb_dim)
            for idx, module in enumerate(self.prompts):
                self.prompts[idx].use_pretrained_tokens(pretrained[:, idx, :])
            if kind == 'combine':
                # Share the pretrained-token coefficients across all prompts.
                for prompt in self.prompts[1:]:
                    prompt.sadcl_coeff_pretrained = self.prompts[0].sadcl_coeff_pretrained
            # l1 loss
            # add_to_loss_hooks(self.prompts[0].loss_hook_coeff_pretrained)
    @classmethod
    def from_config(cls, config):
        selected_embs = torch.zeros(*config.pop('selected_embs.shape'))
        pretrained = None
        if 'pretrained.shape' in config:
            pretrained = torch.zeros(*config.pop('pretrained.shape'))
        return cls(selected_embs=selected_embs, pretrained=pretrained, **config)

    @classmethod
    def get_saved_final_emb(cls, config_path, weights_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
        temp_multi_prompt = cls.from_config(config['peft_config'])
        temp_multi_prompt.load_state_dict(torch.load(weights_path, map_location='cpu'))
        with torch.no_grad():
            embs = temp_multi_prompt().detach()
        # embs.shape == (n_tokens, emb_dim)
        return embs

    def forward(self):
        out = torch.stack([
            prompt() for prompt in self.prompts
        ], dim=0)
        assert out.shape == (self.n_tokens, self.emb_dim)
        return out
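

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of building a MultiPrompt, reading out the soft-prompt
# matrix, and recovering it from disk via get_saved_final_emb. It assumes the
# Single*Prompt constructors in .single_prompt need only an embedding tensor
# (plus optional kwargs), as they are called above; the file names and sizes
# below are arbitrary.
if __name__ == '__main__':
    n_tokens, emb_dim = 10, 768
    # e.g. token embeddings sampled from the backbone's vocabulary
    selected_embs = torch.randn(n_tokens, emb_dim)

    prompt = MultiPrompt(n_tokens=n_tokens, selected_embs=selected_embs, kind='simple')
    soft_prompt = prompt()  # shape: (n_tokens, emb_dim)

    # Persist the construction config and weights in the layout that
    # get_saved_final_emb expects: a JSON file whose 'peft_config' entry holds
    # the kwargs recorded in _constructed_configs, plus a state_dict checkpoint.
    with open('multi_prompt_config.json', 'w') as f:
        json.dump({'peft_config': {**prompt._constructed_configs,
                                   'selected_embs.shape': list(selected_embs.shape)}}, f)
    torch.save(prompt.state_dict(), 'multi_prompt.bin')

    embs = MultiPrompt.get_saved_final_emb('multi_prompt_config.json', 'multi_prompt.bin')
    assert embs.shape == (n_tokens, emb_dim)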