
auto_mutate.py

from .emb_wrapper import EmbeddingWrapper
from .mutate_forward import mutate_remove_dropout


def _mutate_comb_prompt(emb_layer, **kwargs):
    return EmbeddingWrapper(emb_layer=emb_layer, **kwargs)


def auto_mutate(model, tokenizer, peft_params, remove_dropout: bool):
    # Wrap the model's input embedding layer so that the soft-prompt delta
    # module sits in front of it; for seq2seq models only the encoder side
    # is wrapped. `tokenizer` is accepted for API symmetry but unused here.
    if model._is_seq2seq:
        delta_module = _mutate_comb_prompt(model.get_encoder().get_input_embeddings(), **peft_params)
        model.get_encoder().set_input_embeddings(delta_module)
    else:
        delta_module = _mutate_comb_prompt(model.get_input_embeddings(), **peft_params)
        model.set_input_embeddings(delta_module)
    # mutate_forward(model, peft_params.get('n_tokens'), just_place_holder=False)
    if remove_dropout:
        mutate_remove_dropout(model)
    # Keep a handle to the delta module on the model for later access.
    model._delta_module = delta_module
    return delta_module


# Leftover debugging snippet: loads a trained MultiCombPrompt checkpoint and
# copies its embeddings and coefficients into a freshly created delta module.
# temp = MultiCombPrompt(
#     n_tokens=config.peft_params.n_tokens,
#     selected_embs=torch.zeros(128, 768),
#     shared_diff=False
# )
# state_dict = torch.load('/disks/ssd/trained_extensive_test_l2.01_for_real/base_10_128/best.pt')
# state_dict = {key.replace('comb_prompts.comb_prompts', 'comb_prompts'): val for (key, val) in state_dict.items()}
# temp.load_state_dict(state_dict)
# embs = temp()
# print(embs.shape)
# for idx, module in enumerate(delta_module.soft_prompts.comb_prompts.comb_prompts):
#     module.sadcl_coeff.data[0] = 1
#     module.pretrained_embs.data[0] = embs[idx]
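
Below is a minimal, self-contained sketch of what the seq2seq branch of auto_mutate amounts to. The EmbeddingWrapper defined here is a hypothetical stand-in (the real class lives in .emb_wrapper and is not shown); it follows the common soft-prompt-tuning pattern of prepending trainable prompt vectors to the token embeddings. The n_tokens value and the t5-small checkpoint are assumptions for illustration, not taken from this repository.

    import torch
    import torch.nn as nn
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


    class EmbeddingWrapper(nn.Module):
        # Hypothetical stand-in for .emb_wrapper.EmbeddingWrapper: wraps the
        # frozen input embedding and prepends n_tokens trainable soft prompts.
        def __init__(self, emb_layer, n_tokens=10, **kwargs):
            super().__init__()
            self.emb_layer = emb_layer
            self.n_tokens = n_tokens
            # Initialize prompts from the first n_tokens rows of the vocab table.
            self.soft_prompts = nn.Parameter(emb_layer.weight[:n_tokens].detach().clone())

        def forward(self, input_ids):
            embs = self.emb_layer(input_ids)  # (batch, seq_len, dim)
            prompts = self.soft_prompts.unsqueeze(0).expand(input_ids.size(0), -1, -1)
            # Note: lengthening the sequence also requires patching the attention
            # mask, which is presumably what the commented-out mutate_forward did.
            return torch.cat([prompts, embs], dim=1)


    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # Mirror the seq2seq branch of auto_mutate: wrap the encoder embeddings.
    emb = model.get_encoder().get_input_embeddings()
    wrapped = EmbeddingWrapper(emb_layer=emb, n_tokens=10)
    model.get_encoder().set_input_embeddings(wrapped)

    ids = tokenizer("translate English to German: hello", return_tensors="pt").input_ids
    print(wrapped(ids).shape)  # (1, 10 + seq_len, d_model)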