from .emb_wrapper import EmbeddingWrapper
from .mutate_forward import mutate_remove_dropout


def _mutate_comb_prompt(emb_layer, **kwargs):
    # Wrap the model's input-embedding layer so soft-prompt embeddings can be
    # injected alongside the original token embeddings.
    return EmbeddingWrapper(emb_layer=emb_layer, **kwargs)


def auto_mutate(model, tokenizer, peft_params, remove_dropout: bool):
    # Replace the model's input embeddings (the encoder's embeddings for
    # seq2seq models) with an EmbeddingWrapper configured by `peft_params`.
    if model._is_seq2seq:
        delta_module = _mutate_comb_prompt(model.get_encoder().get_input_embeddings(), **peft_params)
        model.get_encoder().set_input_embeddings(delta_module)
    else:
        delta_module = _mutate_comb_prompt(model.get_input_embeddings(), **peft_params)
        model.set_input_embeddings(delta_module)

    # mutate_forward(model, peft_params.get('n_tokens'), just_place_holder=False)

    if remove_dropout:
        mutate_remove_dropout(model)

    # Keep a handle to the wrapper on the model and return it to the caller.
    model._delta_module = delta_module
    return delta_module


# temp = MultiCombPrompt(
#     n_tokens=config.peft_params.n_tokens,
#     selected_embs=torch.zeros(128, 768),
#     shared_diff=False
# )
# state_dict = torch.load('/disks/ssd/trained_extensive_test_l2.01_for_real/base_10_128/best.pt')
# state_dict = {key.replace('comb_prompts.comb_prompts', 'comb_prompts'): val for (key, val) in state_dict.items()}
# temp.load_state_dict(state_dict)
# embs = temp()
# print(embs.shape)
# for idx, module in enumerate(delta_module.soft_prompts.comb_prompts.comb_prompts):
#     module.sadcl_coeff.data[0] = 1
#     module.pretrained_embs.data[0] = embs[idx]
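

# Hypothetical usage sketch (not part of this module), kept commented out like
# the experimental block above: how auto_mutate might be wired to a Hugging Face
# seq2seq model. The `peft_params` contents and the `_is_seq2seq` flag are
# assumptions about how the rest of this codebase configures EmbeddingWrapper.
#
# from transformers import T5ForConditionalGeneration, T5TokenizerFast
#
# model = T5ForConditionalGeneration.from_pretrained('t5-small')
# tokenizer = T5TokenizerFast.from_pretrained('t5-small')
# model._is_seq2seq = True  # assumed to be set elsewhere in the codebase
#
# delta_module = auto_mutate(
#     model,
#     tokenizer,
#     peft_params={'n_tokens': 10},  # assumed keyword accepted by EmbeddingWrapper
#     remove_dropout=True,
# )
#
# # Typically only the delta module's parameters would be optimized:
# trainable = [p for p in delta_module.parameters() if p.requires_grad]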