
attempt.py 3.5KB

import json
from pathlib import Path
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn

from .single_prompt import SingleCombPrompt, SingleResidualPrompt, SingleSimplePrompt


class AttemptAttention(nn.Module):
    """Scores each prompt (pretrained tasks + target) against a pooled input representation."""

    def __init__(self, emb_dim, g_bottleneck, temperature):
        super().__init__()
        # Bottleneck projection ("G network") that maps the pooled input into prompt space.
        self.g_network = nn.Sequential(
            nn.Linear(emb_dim, g_bottleneck, bias=False),
            nn.SiLU(),
            nn.Linear(g_bottleneck, emb_dim, bias=False),
            nn.LayerNorm(emb_dim)
        )
        self.temperature = temperature

    def forward(self, x_hat, p_hats):
        # x_hat.shape == batch_size, emb_dim
        # p_hats.shape == (pretrained_tasks + 1), emb_dim
        batch_size = x_hat.shape[0]
        p_hats_batched = p_hats.repeat(batch_size, 1, 1)
        # p_hats_batched.shape == batch_size, (pretrained_tasks + 1), emb_dim
        h_out = self.g_network(x_hat)
        powers = torch.bmm(p_hats_batched, h_out[:, :, None]) / self.temperature
        # powers.shape == batch_size, (pretrained_tasks + 1), 1
        attention_weights = torch.softmax(powers[:, :, 0], dim=1)
        # attention_weights.shape == batch_size, (pretrained_tasks + 1)
        return attention_weights


class Attempt(nn.Module):
    """Instance-wise attentional mixture of pretrained task prompts and a trainable target prompt."""

    def __init__(self, selected_embs, pretrained, g_bottleneck, kind):
        # selected_embs.shape == n_tokens, emb_dim
        # pretrained.shape == pretrained_tasks, n_tokens, emb_dim
        super().__init__()
        assert selected_embs.shape == pretrained.shape[1:]

        self._constructed_configs = {
            'kind': kind,
            'selected_embs.shape': selected_embs.shape,
            'pretrained.shape': pretrained.shape,
            'g_bottleneck': g_bottleneck
        }

        # Trainable target-task prompt, initialized from the selected embeddings.
        self.sadcl_p_target = nn.parameter.Parameter(
            selected_embs.detach().clone()
        )
        # Prompts of the pretrained source tasks.
        self.pretrained_tasks = nn.parameter.Parameter(
            pretrained.detach().clone()
        )
        self.sadcl_attention_score = AttemptAttention(
            emb_dim=selected_embs.shape[1],
            g_bottleneck=g_bottleneck,
            temperature=selected_embs.shape[1] * 2.71828  # e (Euler's number)
        )

    def forward(self, x_inp, prompt_mask):
        # x_inp.shape == batch_size, seq_len, emb_dim
        # prompt_mask.shape == batch_size, seq_len; 1 where the token is a prompt token, 0 otherwise
        # Exclude prompt positions from the max-pooling below by pushing them to -inf.
        prompt_mask = torch.zeros_like(prompt_mask, dtype=torch.float).masked_fill_(prompt_mask, float('-Inf'))
        x_inp = x_inp + prompt_mask[:, :, None]
        x_hat = x_inp.max(axis=1).values
        # x_hat.shape == batch_size, emb_dim
        all_prompts = torch.cat((
            self.pretrained_tasks,
            self.sadcl_p_target[None, :, :]
        ), dim=0)
        # all_prompts.shape == (pretrained_tasks + 1), n_tokens, emb_dim
        p_hats = all_prompts.max(axis=1).values
        # p_hats.shape == (pretrained_tasks + 1), emb_dim
        attention_weights = self.sadcl_attention_score(x_hat=x_hat, p_hats=p_hats)
        # attention_weights.shape == batch_size, (pretrained_tasks + 1)
        all_prompts_weighted = all_prompts[None, :, :, :] * attention_weights[:, :, None, None]
        # all_prompts_weighted.shape == batch_size, (pretrained_tasks + 1), n_tokens, emb_dim
        prompts = all_prompts_weighted.sum(axis=1)
        # prompts.shape == batch_size, n_tokens, emb_dim
        return prompts
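
A minimal usage sketch for the module above. This is illustrative only: the sizes (n_tokens, emb_dim, pretrained_tasks), the g_bottleneck value, the kind string, and the random tensors are placeholder assumptions, not values taken from this repository, and it assumes the file's relative import resolves in its package context.

    # Illustrative sketch; all sizes and tensors below are made-up placeholders.
    import torch

    n_tokens, emb_dim, pretrained_tasks = 10, 768, 6
    batch_size, seq_len = 4, 32

    selected_embs = torch.randn(n_tokens, emb_dim)                 # init for the target prompt
    pretrained = torch.randn(pretrained_tasks, n_tokens, emb_dim)  # source-task prompts

    model = Attempt(
        selected_embs=selected_embs,
        pretrained=pretrained,
        g_bottleneck=100,   # placeholder bottleneck size
        kind='attempt'      # placeholder label; only stored in the config dict
    )

    x_inp = torch.randn(batch_size, seq_len, emb_dim)                  # token embeddings
    prompt_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)   # True at prompt positions
    prompt_mask[:, :n_tokens] = True                                   # e.g. prompts prepended

    prompts = model(x_inp=x_inp, prompt_mask=prompt_mask)
    print(prompts.shape)  # torch.Size([4, 10, 768]) == batch_size, n_tokens, emb_dim

The output is one mixed prompt per input example: attention weights over the (pretrained_tasks + 1) candidate prompts are computed from the max-pooled non-prompt token embeddings, and the prompts are summed with those weights.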