import torch
import torch.nn as nn

from transformers.models.t5.modeling_t5 import T5LayerFF


class AdapterLayer(nn.Module):
    """Bottleneck adapter: down-project, non-linearity, up-project, with a residual connection."""

    def __init__(self, emb_dim: int, bottleneck_size: int):
        super().__init__()
        self.sadcl_adapter = nn.Sequential(
            nn.Linear(emb_dim, bottleneck_size),
            nn.ReLU(),
            nn.Linear(bottleneck_size, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection around the bottleneck adapter.
        return x + self.sadcl_adapter(x)


class FeedForwardAdapterWrapper(nn.Module):
    """Wraps a T5 feed-forward sub-layer and applies an adapter to its output."""

    def __init__(self, original_module: T5LayerFF, bottleneck_size: int):
        super().__init__()
        assert isinstance(original_module, T5LayerFF)
        self.original_module = original_module
        # Hidden size of the T5 feed-forward block (input dim of its first linear layer).
        emb_dim = original_module.DenseReluDense.wi.in_features
        self.adapter = AdapterLayer(emb_dim, bottleneck_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the original feed-forward sub-layer first, then the adapter.
        output = self.original_module(x)
        output = self.adapter(output)
        return output
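
# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumptions: the "t5-small" Hugging Face checkpoint and a bottleneck size of 64
# are arbitrary example choices. The sketch freezes the backbone and replaces each
# T5 feed-forward sub-layer (the last entry of every block's `layer` list) with a
# FeedForwardAdapterWrapper, so only adapter parameters remain trainable.
if __name__ == "__main__":
    from transformers import T5ForConditionalGeneration

    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # Freeze the pretrained backbone; only the adapters added below receive gradients.
    for param in model.parameters():
        param.requires_grad = False

    # Wrap every feed-forward sub-layer in both the encoder and decoder stacks.
    for stack in (model.encoder, model.decoder):
        for block in stack.block:
            ff = block.layer[-1]
            if isinstance(ff, T5LayerFF):
                block.layer[-1] = FeedForwardAdapterWrapper(ff, bottleneck_size=64)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable adapter parameters: {trainable}")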