import torch
import torch.nn as nn
from transformers.models.t5.modeling_t5 import T5LayerFF


class AdapterLayer(nn.Module):
    """Bottleneck adapter: down-project, non-linearity, up-project, plus a residual."""

    def __init__(self, emb_dim: int, bottleneck_size: int):
        super().__init__()
        self.sadcl_adapter = nn.Sequential(
            nn.Linear(emb_dim, bottleneck_size),
            nn.ReLU(),
            nn.Linear(bottleneck_size, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection: at initialization the adapter only perturbs the
        # wrapped layer's output, so it learns a small task-specific correction.
        return x + self.sadcl_adapter(x)

class FeedForwardAdapterWrapper(nn.Module):
    """Wraps a T5 feed-forward sub-layer and applies an adapter to its output."""

    def __init__(self, original_module: T5LayerFF, bottleneck_size: int):
        super().__init__()
        assert isinstance(original_module, T5LayerFF)

        self.original_module = original_module

        # Read the model dimension from the first FF projection. This assumes the
        # non-gated T5 feed-forward (T5DenseActDense); gated checkpoints expose
        # wi_0/wi_1 instead of wi.
        emb_dim = original_module.DenseReluDense.wi.in_features

        self.adapter = AdapterLayer(emb_dim, bottleneck_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.original_module(x)
        output = self.adapter(output)
        return output
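
To make the wrapper concrete, here is a minimal usage sketch (an illustration, not part of the original code): it swaps every T5LayerFF in a pretrained checkpoint for its adapter-wrapped version and then freezes everything except the adapter weights. The checkpoint name "t5-small" and bottleneck_size=64 are assumed, illustrative choices.

from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Replace each feed-forward sub-layer (the last entry of every T5 block's
# layer list) with its adapter-wrapped version.
for block in list(model.encoder.block) + list(model.decoder.block):
    for i, layer in enumerate(block.layer):
        if isinstance(layer, T5LayerFF):
            block.layer[i] = FeedForwardAdapterWrapper(layer, bottleneck_size=64)

# Train only the adapter parameters; the pretrained backbone stays frozen.
for name, param in model.named_parameters():
    param.requires_grad = "sadcl_adapter" in name

Because the adapter parameters all live under the sadcl_adapter submodule, a simple substring match on parameter names is enough to separate trainable adapter weights from the frozen backbone.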