import json
from pathlib import Path
from typing import Optional, List

import numpy as np

import torch
import torch.nn as nn

from .single_prompt import SingleCombPrompt, SingleResidualPrompt, SingleSimplePrompt


class AttemptAttention(nn.Module):
    """Computes instance-wise attention weights over the pooled prompt embeddings."""

    def __init__(self, emb_dim, g_bottleneck, temperature):
        super().__init__()

        self.g_network = nn.Sequential(
            nn.Linear(emb_dim, g_bottleneck, bias=False),
            nn.SiLU(),
            nn.Linear(g_bottleneck, emb_dim, bias=False),
            nn.LayerNorm(emb_dim)
        )
        self.temperature = temperature

    def forward(self, x_hat, p_hats):
        # x_hat.shape == batch_size, emb_dim
        # p_hats.shape == (pretrained_tasks + 1), emb_dim
        batch_size = x_hat.shape[0]
        p_hats_batched = p_hats.repeat(batch_size, 1, 1)
        # p_hats_batched.shape == batch_size, (pretrained_tasks + 1), emb_dim

        # Project the pooled input, then score it against every pooled prompt.
        h_out = self.g_network(x_hat)
        powers = torch.bmm(p_hats_batched, h_out[:, :, None]) / self.temperature
        # powers.shape == batch_size, (pretrained_tasks + 1), 1
        attention_weights = torch.softmax(powers[:, :, 0], dim=1)
        # attention_weights.shape == batch_size, (pretrained_tasks + 1)
        return attention_weights
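

# Illustrative usage sketch (not part of the original module): the function name,
# sizes, and random tensors below are hypothetical and only demonstrate the shapes
# flowing through AttemptAttention.
def _demo_attention_weights():
    batch_size, pretrained_tasks, emb_dim = 4, 6, 768
    attention = AttemptAttention(
        emb_dim=emb_dim,
        g_bottleneck=100,
        temperature=emb_dim * 2.71828  # mirrors the temperature used in Attempt below
    )
    x_hat = torch.randn(batch_size, emb_dim)             # pooled input representations
    p_hats = torch.randn(pretrained_tasks + 1, emb_dim)  # pooled prompt representations
    attention_weights = attention(x_hat=x_hat, p_hats=p_hats)
    # attention_weights.shape == batch_size, (pretrained_tasks + 1)
    return attention_weights
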
class Attempt(nn.Module):
    """Builds an instance-wise mixture of the target prompt and the pretrained task prompts."""

    def __init__(self, selected_embs, pretrained, g_bottleneck, kind):
        # selected_embs.shape == n_tokens, emb_dim
        # pretrained.shape == pretrained_tasks, n_tokens, emb_dim

        super().__init__()

        assert selected_embs.shape == pretrained.shape[1:]

        self._constructed_configs = {
            'kind': kind,
            'selected_embs.shape': selected_embs.shape,
            'pretrained.shape': pretrained.shape,
            'g_bottleneck': g_bottleneck
        }

        self.sadcl_p_target = nn.parameter.Parameter(
            selected_embs.detach().clone()
        )
        self.pretrained_tasks = nn.parameter.Parameter(
            pretrained.detach().clone()
        )
        self.sadcl_attention_score = AttemptAttention(
            emb_dim=selected_embs.shape[1],
            g_bottleneck=g_bottleneck,
            temperature=selected_embs.shape[1] * 2.71828  # Euler's number e
        )

    def forward(self, x_inp, prompt_mask):
        # x_inp.shape == batch_size, seq_len, emb_dim
        # prompt_mask.shape == batch_size, seq_len ------- 1 when the token is a prompt token, otherwise 0

        # Replace prompt positions with -inf so max-pooling only sees the input tokens.
        prompt_mask = torch.zeros_like(prompt_mask, dtype=torch.float).masked_fill_(prompt_mask.bool(), float('-inf'))
        x_inp = x_inp + prompt_mask[:, :, None]
        x_hat = x_inp.max(dim=1).values
        # x_hat.shape == batch_size, emb_dim

        # Stack the pretrained task prompts with the target prompt and max-pool each one.
        all_prompts = torch.cat((
            self.pretrained_tasks,
            self.sadcl_p_target[None, :, :]
        ), dim=0)
        # all_prompts.shape == (pretrained_tasks + 1), n_tokens, emb_dim
        p_hats = all_prompts.max(dim=1).values
        # p_hats.shape == (pretrained_tasks + 1), emb_dim

        attention_weights = self.sadcl_attention_score(x_hat=x_hat, p_hats=p_hats)
        # attention_weights.shape == batch_size, (pretrained_tasks + 1)

        # Weight every prompt by its attention score and sum into one prompt per instance.
        all_prompts_weighted = all_prompts[None, :, :, :] * attention_weights[:, :, None, None]
        # all_prompts_weighted.shape == batch_size, (pretrained_tasks + 1), n_tokens, emb_dim

        prompts = all_prompts_weighted.sum(dim=1)
        # prompts.shape == batch_size, n_tokens, emb_dim
        return prompts
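

# Minimal smoke-test sketch (not part of the original module): all sizes, the 'attempt'
# kind label, and the random tensors below are hypothetical. Because of the relative
# import above, run this as a module (python -m <package>.<module>) rather than as a script.
if __name__ == '__main__':
    pretrained_tasks, n_tokens, emb_dim = 6, 10, 768
    batch_size, seq_len = 4, 32

    selected_embs = torch.randn(n_tokens, emb_dim)
    pretrained = torch.randn(pretrained_tasks, n_tokens, emb_dim)
    model = Attempt(
        selected_embs=selected_embs,
        pretrained=pretrained,
        g_bottleneck=100,
        kind='attempt'
    )

    x_inp = torch.randn(batch_size, seq_len, emb_dim)
    prompt_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    prompt_mask[:, :n_tokens] = True  # assume the prompt occupies the first n_tokens positions
    prompts = model(x_inp=x_inp, prompt_mask=prompt_mask)
    print(prompts.shape)  # expected: torch.Size([4, 10, 768])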