import json
from pathlib import Path
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn

from .single_prompt import SingleCombPrompt, SingleResidualPrompt, SingleSimplePrompt


class AttemptAttention(nn.Module):
    """Scores each source prompt against a pooled input representation.

    The pooled input is projected through a bottlenecked g-network, dotted
    with the pooled prompt representations, and turned into attention
    weights by a temperature-scaled softmax.
    """

    def __init__(self, emb_dim, g_bottleneck, temperature):
        super().__init__()
        self.g_network = nn.Sequential(
            nn.Linear(emb_dim, g_bottleneck, bias=False),
            nn.SiLU(),
            nn.Linear(g_bottleneck, emb_dim, bias=False),
            nn.LayerNorm(emb_dim)
        )
        self.temperature = temperature

    def forward(self, x_hat, p_hats):
        # x_hat.shape == (batch_size, emb_dim)
        # p_hats.shape == (pretrained_tasks + 1, emb_dim)
        batch_size = x_hat.shape[0]

        p_hats_batched = p_hats.repeat(batch_size, 1, 1)
        # p_hats_batched.shape == (batch_size, pretrained_tasks + 1, emb_dim)

        h_out = self.g_network(x_hat)
        powers = torch.bmm(p_hats_batched, h_out[:, :, None]) / self.temperature
        # powers.shape == (batch_size, pretrained_tasks + 1, 1)

        attention_weights = torch.softmax(powers[:, :, 0], dim=1)
        # attention_weights.shape == (batch_size, pretrained_tasks + 1)
        return attention_weights


class Attempt(nn.Module):
    """Builds instance-dependent prompts as an attention-weighted mixture of
    the pretrained task prompts and a trainable target prompt."""

    def __init__(self, selected_embs, pretrained, g_bottleneck, kind):
        # selected_embs.shape == (n_tokens, emb_dim)
        # pretrained.shape == (pretrained_tasks, n_tokens, emb_dim)
        super().__init__()
        assert selected_embs.shape == pretrained.shape[1:]

        self._constructed_configs = {
            'kind': kind,
            'selected_embs.shape': selected_embs.shape,
            'pretrained.shape': pretrained.shape,
            'g_bottleneck': g_bottleneck
        }

        self.sadcl_p_target = nn.Parameter(
            selected_embs.detach().clone()
        )
        self.pretrained_tasks = nn.Parameter(
            pretrained.detach().clone()
        )
        self.sadcl_attention_score = AttemptAttention(
            emb_dim=selected_embs.shape[1],
            g_bottleneck=g_bottleneck,
            temperature=selected_embs.shape[1] * 2.71828  # emb_dim * e (Euler's number)
        )

    def forward(self, x_inp, prompt_mask):
        # x_inp.shape == (batch_size, seq_len, emb_dim)
        # prompt_mask.shape == (batch_size, seq_len); 1 where the token is a prompt, 0 otherwise

        # Exclude prompt positions from the max-pooling by pushing them to -inf.
        prompt_mask = torch.zeros_like(prompt_mask, dtype=torch.float).masked_fill_(
            prompt_mask.bool(), float('-inf')
        )
        x_inp = x_inp + prompt_mask[:, :, None]
        x_hat = x_inp.max(dim=1).values
        # x_hat.shape == (batch_size, emb_dim)

        all_prompts = torch.cat((
            self.pretrained_tasks,
            self.sadcl_p_target[None, :, :]
        ), dim=0)
        # all_prompts.shape == (pretrained_tasks + 1, n_tokens, emb_dim)

        p_hats = all_prompts.max(dim=1).values
        # p_hats.shape == (pretrained_tasks + 1, emb_dim)

        attention_weights = self.sadcl_attention_score(x_hat=x_hat, p_hats=p_hats)
        # attention_weights.shape == (batch_size, pretrained_tasks + 1)

        all_prompts_weighted = all_prompts[None, :, :, :] * attention_weights[:, :, None, None]
        # all_prompts_weighted.shape == (batch_size, pretrained_tasks + 1, n_tokens, emb_dim)

        prompts = all_prompts_weighted.sum(dim=1)
        # prompts.shape == (batch_size, n_tokens, emb_dim)
        return prompts
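

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module): a minimal sketch
# of how Attempt might be wired up. The numbers of tasks and tokens, the
# embedding width, the g_bottleneck value, and the kind='attempt' label are
# all assumptions chosen for demonstration only. Because this file uses a
# relative import, run it as a module, e.g. `python -m <package>.attempt`
# (package and module names assumed).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    n_tasks, n_tokens, emb_dim = 3, 10, 768
    batch_size, seq_len = 4, 32

    module = Attempt(
        selected_embs=torch.randn(n_tokens, emb_dim),
        pretrained=torch.randn(n_tasks, n_tokens, emb_dim),
        g_bottleneck=100,
        kind='attempt',
    )

    x_inp = torch.randn(batch_size, seq_len, emb_dim)
    # Mark the first n_tokens positions of each sequence as prompt tokens;
    # these are masked out of the max-pooled input representation.
    prompt_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    prompt_mask[:, :n_tokens] = True

    prompts = module(x_inp=x_inp, prompt_mask=prompt_mask)
    assert prompts.shape == (batch_size, n_tokens, emb_dim)
    print('instance-dependent prompts:', tuple(prompts.shape))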