import json
from pathlib import Path
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn

from _trainer.loss_hooks import add_to_loss_hooks
from .single_prompt import SingleCombPrompt, SingleResidualPrompt, SingleSimplePrompt, SingleSuperSimplePrompt


class MultiPrompt(nn.Module):
    def __init__(self, n_tokens, selected_embs, kind: str, shared_weights: bool = False,
                 pretrained: Optional[torch.Tensor] = None, **kwargs):
        ####### kind in ['simple', 'spot', 'residual']
        # selected_embs.shape == (n_tokens, emb_dim)
        # pretrained.shape == (1, n_tokens, emb_dim)
        ####### kind == 'combine'
        # selected_embs.shape == (super_pos_m, emb_dim)
        # pretrained.shape == (pretrained_task_count, n_tokens, emb_dim)
        super().__init__()

        self._constructed_configs = {
            'n_tokens': n_tokens,
            'selected_embs.shape': selected_embs.shape,
            'kind': kind,
            'shared_weights': shared_weights,
            **kwargs
        }

        self.n_tokens = n_tokens
        self.emb_dim = selected_embs.size(1)

        # One constructor per prompt kind; 'spot' maps to SingleSuperSimplePrompt.
        prompt_constructor = {
            'simple': lambda idx, selected_embs: SingleSimplePrompt(selected_embs[idx], **kwargs),
            'spot': lambda idx, selected_embs: SingleSuperSimplePrompt(selected_embs[idx], **kwargs),
            'residual': lambda idx, selected_embs: SingleResidualPrompt(selected_embs[idx], **kwargs),
            'combine': lambda _, selected_embs: SingleCombPrompt(selected_embs, **kwargs),
        }[kind]

        # One single-token prompt module per soft-prompt token.
        self.prompts = nn.ModuleList([
            prompt_constructor(idx, selected_embs) for idx in range(n_tokens)
        ])

        if shared_weights:
            # Tie the shared sub-module of every prompt to that of the first prompt.
            if kind == 'combine':
                for module in self.prompts:
                    module.sadcl_embs_diff = self.prompts[0].sadcl_embs_diff
            elif kind == 'residual':
                for module in self.prompts:
                    module.sadcl_mlp = self.prompts[0].sadcl_mlp
            else:
                raise NotImplementedError()

        if pretrained is not None:
            self._constructed_configs['pretrained.shape'] = pretrained.shape
            assert pretrained.shape[1:] == (self.n_tokens, self.emb_dim)

            # Initialise each per-token prompt from the matching slice of the pretrained prompts.
            for idx, module in enumerate(self.prompts):
                self.prompts[idx].use_pretrained_tokens(pretrained[:, idx, :])

            if kind == 'combine':
                # Share the pretrained-prompt coefficients across all tokens.
                for prompt in self.prompts[1:]:
                    prompt.sadcl_coeff_pretrained = self.prompts[0].sadcl_coeff_pretrained
                # l1 loss
                # add_to_loss_hooks(self.prompts[0].loss_hook_coeff_pretrained)

    @classmethod
    def from_config(cls, config):
        # Rebuild the module skeleton from a saved config; the zero tensors only fix
        # the shapes and are overwritten when a state dict is loaded.
        selected_embs = torch.zeros(*config.pop('selected_embs.shape'))
        pretrained = None
        if 'pretrained.shape' in config:
            pretrained = torch.zeros(*config.pop('pretrained.shape'))
        return cls(selected_embs=selected_embs, pretrained=pretrained, **config)

    @classmethod
    def get_saved_final_emb(cls, config_path, weights_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
        temp_multi_prompt = cls.from_config(config['peft_config'])
        temp_multi_prompt.load_state_dict(torch.load(weights_path, map_location='cpu'))
        with torch.no_grad():
            embs = temp_multi_prompt().detach()
        # embs.shape == (n_tokens, emb_dim)
        return embs

    def forward(self):
        # Stack the per-token prompt vectors into a single soft-prompt matrix.
        out = torch.stack([
            prompt() for prompt in self.prompts
        ], dim=0)
        assert out.shape == (self.n_tokens, self.emb_dim)
        return out
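

# Minimal usage sketch (illustrative only). It assumes that SingleSimplePrompt accepts
# a single embedding vector as its only required constructor argument, which may not
# match the actual signature in .single_prompt, and that this module is executed as
# part of its package (e.g. `python -m <package>.multi_prompt`) so the relative
# import above resolves.
if __name__ == '__main__':
    n_tokens, emb_dim = 10, 768
    selected_embs = torch.randn(n_tokens, emb_dim)  # one initial embedding per token

    multi_prompt = MultiPrompt(n_tokens=n_tokens, selected_embs=selected_embs, kind='simple')
    soft_prompt = multi_prompt()  # expected shape: (n_tokens, emb_dim)
    print(soft_prompt.shape)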