
adapter.py

import torch
import torch.nn as nn

from transformers.models.t5.modeling_t5 import T5LayerFF


class AdapterLayer(nn.Module):
    """Bottleneck adapter: down-projection, non-linearity, up-projection, with a residual connection."""

    def __init__(
        self,
        emb_dim: int,
        bottleneck_size: int
    ):
        super().__init__()
        self.sadcl_adapter = nn.Sequential(
            nn.Linear(emb_dim, bottleneck_size),
            nn.ReLU(),
            nn.Linear(bottleneck_size, emb_dim)
        )

    def forward(self, x: torch.Tensor):
        # Residual connection around the bottleneck projection.
        return x + self.sadcl_adapter(x)


class FeedForwardAdapterWrapper(nn.Module):
    """Wraps a T5LayerFF sub-layer and applies the adapter to its output."""

    def __init__(
        self,
        original_module: T5LayerFF,
        bottleneck_size: int
    ):
        super().__init__()
        assert isinstance(original_module, T5LayerFF)
        self.original_module = original_module
        # Infer the model's hidden size from the wrapped feed-forward layer.
        emb_dim = original_module.DenseReluDense.wi.in_features
        self.adapter = AdapterLayer(emb_dim, bottleneck_size)

    def forward(self, x: torch.Tensor):
        output = self.original_module(x)
        output = self.adapter(output)
        return output
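
Below is a minimal usage sketch, not part of adapter.py, showing one way the wrapper could be applied to a Hugging Face T5 model: the base weights are frozen and each feed-forward sub-layer is replaced by a FeedForwardAdapterWrapper. The checkpoint name "t5-small" and bottleneck_size=32 are illustrative assumptions, not values from the original file.

from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Freeze the pretrained weights so that only the adapter parameters remain trainable.
for param in model.parameters():
    param.requires_grad = False

# In Hugging Face T5, the feed-forward sub-layer (T5LayerFF) is the last entry
# of each block's layer list, for both the encoder and the decoder.
for block in list(model.encoder.block) + list(model.decoder.block):
    ff_index = len(block.layer) - 1
    block.layer[ff_index] = FeedForwardAdapterWrapper(
        block.layer[ff_index],
        bottleneck_size=32,  # assumed value for illustration
    )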