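"""Multi-token prompt embeddings for parameter-efficient fine-tuning.

MultiPrompt holds one Single*Prompt module per prompt token (all of them
SingleCombPrompt modules for the 'combine' kind) and stacks their outputs
into a single (n_tokens, emb_dim) prompt matrix.
"""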
import json
from pathlib import Path
from typing import Optional, List

import numpy as np

import torch
import torch.nn as nn

from _trainer.loss_hooks import add_to_loss_hooks
from .single_prompt import SingleCombPrompt, SingleResidualPrompt, SingleSimplePrompt, SingleSuperSimplePrompt

class MultiPrompt(nn.Module):
    def __init__(self, n_tokens, selected_embs, kind: str, shared_weights: bool = False, pretrained: Optional[torch.Tensor] = None, **kwargs):
        # kind in ['simple', 'spot', 'residual']:
        #     selected_embs.shape == (n_tokens, emb_dim)
        #     pretrained.shape    == (1, n_tokens, emb_dim)
        # kind == 'combine':
        #     selected_embs.shape == (super_pos_m, emb_dim)
        #     pretrained.shape    == (pretrained_task_count, n_tokens, emb_dim)
        super().__init__()

        self._constructed_configs = {
            'n_tokens': n_tokens,
            'selected_embs.shape': selected_embs.shape,
            'kind': kind,
            'shared_weights': shared_weights,
            **kwargs
        }

        self.n_tokens = n_tokens
        self.emb_dim = selected_embs.size(1)

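        # Each kind dispatches to a different Single*Prompt variant: the
        # per-token kinds receive their own row of selected_embs, while
        # 'combine' gives every module the full selected_embs matrix to mix.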
        prompt_constructor = {
            'simple': lambda idx, selected_embs: SingleSimplePrompt(selected_embs[idx], **kwargs),
            'spot': lambda idx, selected_embs: SingleSuperSimplePrompt(selected_embs[idx], **kwargs),
            'residual': lambda idx, selected_embs: SingleResidualPrompt(selected_embs[idx], **kwargs),
            'combine': lambda _, selected_embs: SingleCombPrompt(selected_embs, **kwargs),
        }[kind]

        self.prompts = nn.ModuleList([
            prompt_constructor(idx, selected_embs) for idx in range(n_tokens)
        ])

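        # Optionally tie the kind-specific trainable parts across all token
        # prompts, so every position reuses the first prompt's parameters.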
        if shared_weights:
            if kind == 'combine':
                for module in self.prompts:
                    module.sadcl_embs_diff = self.prompts[0].sadcl_embs_diff
            elif kind == 'residual':
                for module in self.prompts:
                    module.sadcl_mlp = self.prompts[0].sadcl_mlp
            else:
                raise NotImplementedError()

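        # When pretrained prompt tokens are given, hand each token prompt its
        # slice; for 'combine', the coefficients over the pretrained prompts
        # are shared across all token positions.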
        if pretrained is not None:
            self._constructed_configs['pretrained.shape'] = pretrained.shape
            assert pretrained.shape[1:] == (self.n_tokens, self.emb_dim)
            for idx, module in enumerate(self.prompts):
                module.use_pretrained_tokens(pretrained[:, idx, :])
            if kind == 'combine':
                for prompt in self.prompts[1:]:
                    prompt.sadcl_coeff_pretrained = self.prompts[0].sadcl_coeff_pretrained
                # l1 loss
                # add_to_loss_hooks(self.prompts[0].loss_hook_coeff_pretrained)

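    # Rebuild an equivalent (untrained) MultiPrompt from a recorded config so
    # that load_state_dict() can restore trained parameters afterwards.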
    @classmethod
    def from_config(cls, config):
        selected_embs = torch.zeros(*config.pop('selected_embs.shape'))
        pretrained = None
        if 'pretrained.shape' in config:
            pretrained = torch.zeros(*config.pop('pretrained.shape'))
        return cls(selected_embs=selected_embs, pretrained=pretrained, **config)

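    # Load a saved config (JSON) plus a state-dict checkpoint and return the
    # final prompt embeddings produced by the restored module.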
    @classmethod
    def get_saved_final_emb(cls, config_path, weights_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
        temp_multi_prompt = cls.from_config(config['peft_config'])
        temp_multi_prompt.load_state_dict(torch.load(weights_path, map_location='cpu'))
        with torch.no_grad():
            embs = temp_multi_prompt().detach()
        # embs.shape == (n_tokens, emb_dim)
        return embs

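    # Evaluate every per-token prompt module and stack the results into a
    # single (n_tokens, emb_dim) soft-prompt matrix.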
    def forward(self):
        out = torch.stack([
            prompt() for prompt in self.prompts
        ], dim=0)
        assert out.shape == (self.n_tokens, self.emb_dim)
        return out
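
# Usage sketch (illustrative only; `embedding` and any extra kwargs are
# hypothetical and depend on the Single*Prompt implementations):
#
#     selected_embs = embedding.weight[:10].clone()        # 10 prompt tokens
#     multi_prompt = MultiPrompt(n_tokens=10, selected_embs=selected_embs, kind='simple')
#     soft_prompt = multi_prompt()                          # shape: (10, emb_dim)
#     # prepend soft_prompt to the input embeddings before the model's forward pass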