In [1]:
from types import SimpleNamespace
from typing import Optional

import torch
import torch.nn as nn

In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_NAME = 'bert-base-uncased'
# Substring marking the soft-prompt parameters (named `sadcl_learned_embedding`);
# used later to decide which weights stay trainable.
NAMESPACE = 'sadcl'

# Number of soft-prompt positions used by every SoftEmbedding in this notebook.
NTOKENS = 10
# Where the prompt positions go in the sequence: 'post' (end) or 'pre' (start).
PROMPT_PLACE = 'post'

# The soft prompts and the classifier head are randomly initialised; seed torch
# so that Restart & Run All reproduces the same initial weights.
SEED = 42
torch.manual_seed(SEED);

In [6]:
def initialize_embedding(
    emb_dim: int,
    n_tokens: int,
    random_range: float,
    initialize_from: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Build the initial (n_tokens, emb_dim) weight matrix for a soft prompt.

    Args:
        emb_dim: Width of each prompt vector (number of columns).
        n_tokens: Number of prompt positions (number of rows).
        random_range: Half-width of the uniform range used when
            ``initialize_from`` is None; values are drawn from
            U(-random_range, random_range).
        initialize_from: Optional 1-D tensor of shape (n_tokens,). Row ``i`` of
            the result is ``initialize_from[i]`` repeated ``emb_dim`` times.
            Non-floating tensors are converted to float32.

    Returns:
        A new float tensor of shape (n_tokens, emb_dim), detached from any graph.

    Raises:
        AssertionError: if ``initialize_from`` does not have shape (n_tokens,).
    """
    if initialize_from is None:
        return torch.empty(n_tokens, emb_dim, dtype=torch.float32).uniform_(-random_range, random_range)

    assert initialize_from.shape == (n_tokens, ), (
        f"initialize_from must have shape ({n_tokens},), got {tuple(initialize_from.shape)}"
    )

    values = initialize_from.detach().clone()
    if not values.is_floating_point():
        # nn.Parameter can only require gradients for floating-point tensors.
        values = values.float()
    # Add a trailing axis first: tiling the 1-D tensor directly with (1, emb_dim)
    # treats it as (1, n_tokens) and produces (1, n_tokens * emb_dim).
    return values.unsqueeze(-1).tile(1, emb_dim)

class SoftEmbedding(nn.Module):
    """Learnable soft prompt spliced into a (batch_size, seq_len, emb_dim) tensor.

    The prompt is an (n_tokens, emb_dim) parameter named
    ``sadcl_learned_embedding``. It occupies the first ``n_tokens`` sequence
    positions when ``prompt_place == 'pre'`` and the last ``n_tokens`` when
    ``prompt_place == 'post'``:

    * ``mode='cat'``: the input's values at those positions are dropped and the
      prompt is concatenated in their place (output length == input length).
    * ``mode='add'``: the prompt is added to the input's values at those positions.

    ``post_tokenizer`` extends tokenizer outputs by ``n_tokens`` placeholder
    positions at the same end, so that in ``'cat'`` mode the prompt replaces
    placeholders rather than real tokens.
    """

    def __init__(
        self,
        emb_dim: int,
        n_tokens: int,
        random_range: float = 0.5,
        prompt_place: str = 'post',
        mode: str = 'cat',
        initialize_from: Optional[torch.Tensor] = None
    ):
        """
        Args:
            emb_dim: Width of each prompt vector.
            n_tokens: Number of prompt positions; must be positive.
            random_range: Half-width of the uniform initialisation range.
            prompt_place: ``'pre'`` or ``'post'``.
            mode: ``'cat'`` or ``'add'``.
            initialize_from: Optional initial values, forwarded to
                ``initialize_embedding``.
        """
        super().__init__()
        assert mode in ['cat', 'add']
        assert prompt_place in ['pre', 'post']
        # With n_tokens == 0, slice(-0, None) would select the whole sequence.
        assert n_tokens > 0, f"n_tokens must be positive, got {n_tokens}"

        # Fill value used by post_tokenizer for each tokenizer output.
        # NOTE(review): input id 0 is presumably the tokenizer's pad id — verify
        # when using a vocabulary other than BERT's.
        self.post_tokenizer_map = {
            'input_ids': 0,
            'attention_mask': 1,
            'token_type_ids': 0
        }
        self.n_tokens = n_tokens
        self.mode = mode
        self.prompt_place = prompt_place

        self.sadcl_learned_embedding = nn.Parameter(
            initialize_embedding(
                emb_dim,
                n_tokens,
                random_range,
                initialize_from
            )
        )

        assert self.sadcl_learned_embedding.shape == (n_tokens, emb_dim)

    def forward(self, input_embedding: torch.Tensor) -> torch.Tensor:
        """Splice the learned prompt into ``input_embedding``.

        Args:
            input_embedding: (batch_size, seq_len, emb_dim) tensor, expected to
                have at least ``n_tokens`` positions. It is not modified.

        Returns:
            A new (batch_size, seq_len, emb_dim) tensor.
        """
        batch_size = input_embedding.size(0)
        if self.mode == 'cat':
            # expand() is a view; torch.cat below makes the only copy.
            learned_embedding = self.sadcl_learned_embedding.expand(batch_size, -1, -1)
            return self.concat_batch(input_embedding[self.get_slice_for_cat()], learned_embedding)
        # mode == 'add': operate on a copy so a tensor the caller still holds a
        # reference to is not silently changed.
        output = input_embedding.clone()
        output[self.get_slice_for_add()] += self.sadcl_learned_embedding[None, :, :]
        return output

    def get_weights(self):
        """Return a detached copy of the prompt weights."""
        return self.sadcl_learned_embedding.detach().clone()

    def set_weights(self, new_weights: torch.Tensor):
        """Overwrite the prompt weights in place with the values of ``new_weights``.

        Values are copied into the existing parameter, so its device, dtype and
        identity are preserved.

        Raises:
            ValueError: if ``new_weights`` does not match the parameter's shape.
        """
        if new_weights.shape != self.sadcl_learned_embedding.shape:
            raise ValueError(
                f"expected weights of shape {tuple(self.sadcl_learned_embedding.shape)}, "
                f"got {tuple(new_weights.shape)}"
            )
        with torch.no_grad():
            self.sadcl_learned_embedding.copy_(new_weights)

    def get_slice_for_add(self):
        """Index selecting the prompt positions of a (batch, seq, dim) tensor."""
        if self.prompt_place == 'pre':
            return slice(None), slice(None, self.n_tokens), slice(None)
        else:  # prompt_place == post
            return slice(None), slice(-self.n_tokens, None), slice(None)

    def get_slice_for_cat(self):
        """Index selecting the non-prompt positions that 'cat' mode keeps."""
        if self.prompt_place == 'pre':
            return slice(None), slice(self.n_tokens, None), slice(None)
        else:  # prompt_place == post
            return slice(None), slice(None, -self.n_tokens), slice(None)

    def concat_batch(self, orig_vals, new_vals):
        """Concatenate ``new_vals`` at the prompt end of ``orig_vals`` along dim 1."""
        if self.prompt_place == 'pre':
            return torch.cat([new_vals, orig_vals], dim=1)
        else:  # prompt_place == post
            return torch.cat([orig_vals, new_vals], dim=1)

    def post_tokenizer(self, **kwargs):
        """Add ``n_tokens`` placeholder positions to tokenizer outputs.

        Every key of ``post_tokenizer_map`` present in ``kwargs`` is extended
        along dim 1 with that key's fill value; other keys (e.g. ``labels``)
        pass through unchanged.

        Returns:
            A dict of the (possibly extended) arguments. The caller's tensors
            are not modified.
        """
        for special_key, pad_val in self.post_tokenizer_map.items():
            if special_key not in kwargs:
                continue
            orig_tokens = kwargs[special_key]
            new_vals = torch.full(
                size=(orig_tokens.size(0), self.n_tokens),
                fill_value=pad_val,
                dtype=orig_tokens.dtype,
                device=orig_tokens.device
            )
            # Rebind in the local dict; assigning to `.data` resized the caller's
            # tensor in place, so repeated calls kept growing it.
            kwargs[special_key] = self.concat_batch(orig_tokens, new_vals)
        return kwargs

class TransformerInjector(nn.Module):
    """Wrap a transformer layer so an additive soft prompt is applied to its input.

    Before delegating to the wrapped module, the incoming ``hidden_states`` go
    through a ``SoftEmbedding`` in ``'add'`` mode, which adds a learned
    (n_tokens, emb_dim) prompt at the prompt positions of the sequence.
    """

    def __init__(self, module, n_tokens: Optional[int] = None, prompt_place: Optional[str] = None):
        """
        Args:
            module: Layer to wrap. Its width is read from
                ``module.output.dense.out_features`` (the attribute path of the
                BERT layers used in this notebook).
            n_tokens: Prompt length; defaults to the global ``NTOKENS``.
            prompt_place: ``'pre'`` or ``'post'``; defaults to the global
                ``PROMPT_PLACE``.
        """
        super().__init__()
        self.original_module = module
        self.add_prompt = SoftEmbedding(
            emb_dim=module.output.dense.out_features,
            n_tokens=NTOKENS if n_tokens is None else n_tokens,
            prompt_place=PROMPT_PLACE if prompt_place is None else prompt_place,
            mode='add'
        )

    def forward(self, hidden_states, *args, **kwargs):
        """Add the soft prompt to ``hidden_states``, then run the wrapped module."""
        hidden_states = self.add_prompt(hidden_states)
        return self.original_module(hidden_states, *args, **kwargs)

    @classmethod
    def muatate_list(cls, module_list, **injector_kwargs):
        """Wrap every entry of ``module_list`` in place and return the same list.

        ``injector_kwargs`` (``n_tokens``, ``prompt_place``) are forwarded to
        each wrapper's constructor.
        """
        for idx, module in enumerate(module_list):
            module_list[idx] = cls(module, **injector_kwargs)
        return module_list

    # Correctly spelled alias; `muatate_list` is kept for existing callers.
    mutate_list = muatate_list
    
class NewEmbeddingLayer(nn.Module):
    """Input-embedding replacement that splices a learned soft prompt into its output.

    Wraps a model's original input-embedding module and passes its output
    through a ``SoftEmbedding`` in its default ``'cat'`` mode.
    """

    def __init__(self, emb_layer: nn.Embedding):
        """
        Args:
            emb_layer: The embedding module to wrap. ``emb_layer.weight.size(1)``
                is used as the prompt width. (The previous default was the
                ``nn.Embedding`` class itself, which has no ``weight`` and so
                could never work.)
        """
        super().__init__()
        self.emb_layer = emb_layer
        self.soft_prompt = SoftEmbedding(
            emb_dim=emb_layer.weight.size(1),
            n_tokens=NTOKENS,
            prompt_place=PROMPT_PLACE
        )

    def forward(self, tokens):
        """Embed ``tokens`` with the wrapped layer, then splice in the soft prompt."""
        out = self.emb_layer(tokens)
        out = self.soft_prompt(out)
        return out

    def get_weights(self):
        """Return a copy of the soft-prompt weights (delegates to ``soft_prompt``)."""
        return self.soft_prompt.get_weights()

    def set_weights(self, new_weights):
        """Replace the soft-prompt weights (delegates to ``soft_prompt``)."""
        self.soft_prompt.set_weights(new_weights)

    def post_tokenizer(self, **kwargs):
        """Add prompt placeholder positions to tokenizer outputs (delegates to ``soft_prompt``)."""
        return self.soft_prompt.post_tokenizer(**kwargs)

    @classmethod
    def mutate(cls, model):
        """Install a NewEmbeddingLayer into ``model`` and patch ``model.forward``.

        The model's input embeddings are replaced by a wrapper around the current
        ones. ``model.forward`` is replaced on the instance by a wrapper that first
        passes the arguments through ``soft_prompt.post_tokenizer`` (adding the
        prompt placeholder positions) and then calls the original forward.

        Returns:
            The installed NewEmbeddingLayer.
        """
        import functools
        import inspect

        emb_layer = model.get_input_embeddings()
        new_emb_layer = cls(emb_layer)
        model.set_input_embeddings(new_emb_layer)

        orig_forward = model.forward
        forward_signature = inspect.signature(orig_forward)

        # functools.wraps sets __wrapped__, so inspect.signature(model.forward)
        # keeps reporting the original parameters to code that introspects it.
        @functools.wraps(orig_forward)
        def new_forward(*args, **kwargs):
            if args:
                # Name positional arguments so post_tokenizer can see e.g. input_ids.
                bound = forward_signature.bind_partial(*args, **kwargs).arguments
                kwargs = {}
                for name, value in bound.items():
                    kind = forward_signature.parameters[name].kind
                    if kind is inspect.Parameter.VAR_KEYWORD:
                        kwargs.update(value)
                    elif kind is inspect.Parameter.VAR_POSITIONAL:
                        if value:
                            raise TypeError(
                                "extra positional arguments are not supported by the soft-prompt forward wrapper"
                            )
                    else:
                        kwargs[name] = value
            new_kwargs = new_emb_layer.soft_prompt.post_tokenizer(**kwargs)
            return orig_forward(**new_kwargs)

        model.forward = new_forward
        return new_emb_layer

In [7]:
from transformers import BertForSequenceClassification, BertTokenizerFast

# Pretrained backbone (the classification head is freshly initialised) and its tokenizer.
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# Inject the soft prompts: one spliced into the input embeddings, plus an additive
# prompt in front of every encoder layer. Both helpers modify `model` in place.
encoder_layers = model.bert.encoder.layer
peft_module = NewEmbeddingLayer.mutate(model)
peft_bert_layers = TransformerInjector.muatate_list(encoder_layers)

model = model.to(DEVICE)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
# Detached copy of the input soft-prompt weights, taken so they can be
# compared against `new_w` (read the same way) further down.
old_w = peft_module.get_weights()
old_w

tensor([[-0.2546, -0.0352, -0.4110,  ...,  0.0189,  0.4121,  0.2206],
        [ 0.0670,  0.0600,  0.4493,  ..., -0.4346,  0.4130, -0.3507],
        [ 0.0827,  0.3569,  0.0943,  ..., -0.3451, -0.1879,  0.0831],
        ...,
        [-0.0489, -0.2570, -0.3328,  ..., -0.4109,  0.0884, -0.0290],
        [-0.2705, -0.3854,  0.4559,  ..., -0.0480, -0.4039,  0.4245],
        [-0.1941,  0.2237,  0.3494,  ..., -0.1199, -0.3030, -0.1530]],
       device='cuda:0')

In [24]:
# tokens = tokenizer("Hi bye", return_tensors='pt').to(DEVICE)

# model.eval()
# with torch.no_grad():
#     out = model(**tokens)
# out

In [3]:
from _datasets import AutoLoad
# Project-local dataset helper; its `get_and_map` (used below) returns the
# tokenized GLUE splits together with an output/label description.
autoload = AutoLoad()

Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

In [26]:
# Freeze everything except the classification head and the soft-prompt
# parameters (names containing NAMESPACE); print what stays trainable.
for name, param in model.named_parameters():
    trainable = 'classifier' in name or NAMESPACE in name
    param.requires_grad = trainable
    if trainable:
        print(name)

bert.embeddings.word_embeddings.soft_prompt.sadcl_learned_embedding
bert.encoder.layer.0.add_prompt.sadcl_learned_embedding
bert.encoder.layer.1.add_prompt.sadcl_learned_embedding
bert.encoder.layer.2.add_prompt.sadcl_learned_embedding
bert.encoder.layer.3.add_prompt.sadcl_learned_embedding
bert.encoder.layer.4.add_prompt.sadcl_learned_embedding
bert.encoder.layer.5.add_prompt.sadcl_learned_embedding
bert.encoder.layer.6.add_prompt.sadcl_learned_embedding
bert.encoder.layer.7.add_prompt.sadcl_learned_embedding
bert.encoder.layer.8.add_prompt.sadcl_learned_embedding
bert.encoder.layer.9.add_prompt.sadcl_learned_embedding
bert.encoder.layer.10.add_prompt.sadcl_learned_embedding
bert.encoder.layer.11.add_prompt.sadcl_learned_embedding
classifier.weight
classifier.bias


In [8]:
# Tokenize GLUE CoLA with the BERT tokenizer (the result is displayed in the next cell).
loader_out = autoload.get_and_map(tokenizer, "glue:cola")


Map:   0%|          | 0/8551 [00:00<?, ? examples/s]

Map:   0%|          | 0/1043 [00:00<?, ? examples/s]

Map:   0%|          | 0/1063 [00:00<?, ? examples/s]

In [9]:
# 'train' / 'valid' are tokenized Datasets; 'output' describes the label space.
loader_out

{'train': Dataset({
     features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 8551
 }),
 'valid': Dataset({
     features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 1043
 }),
 'output': {'kind': 'classification', 'range': {0, 1}}}

In [28]:
from config import load_config
# Project-local config; its `hf_trainer_params` are splatted into TrainingArguments below.
config = load_config('config.yaml')

In [29]:
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding
from sklearn.metrics import classification_report


def compute_metrics(pred):
    """Trainer metric hook: accuracy, F1 of class 1, and macro F1.

    Args:
        pred: Evaluation output exposing ``label_ids`` (gold labels) and
            ``predictions`` (per-class scores; argmax over the last axis gives
            the predicted label).

    Returns:
        Dict with keys ``'accuracy'``, ``'f1-score-1'`` and ``'f1-score-ma'``.
    """
    true_labels = pred.label_ids.ravel()
    pred_labels = pred.predictions.argmax(-1).ravel()
    # zero_division=0 reports the same 0.0 scores for never-predicted classes,
    # without the `_warn_prf` warning spam emitted early in training.
    report = classification_report(
        true_labels, pred_labels, output_dict=True, zero_division=0
    )
    # Class '1' is missing from the report when it occurs in neither array.
    positive_class = report.get('1', {})
    return {
        'accuracy': report['accuracy'],
        'f1-score-1': positive_class.get('f1-score', 0.0),
        'f1-score-ma': report['macro avg']['f1-score']
    }


# def train_model(input_model, task_name, train_dataset, eval_dataset, col_fn):
#     training_args = TrainingArguments(
#         evaluation_strategy="epoch",
#         save_strategy="epoch",
#         # The next 2 lines are important to ensure the dataset labels are properly passed to the model
#         remove_unused_columns=False,
#         **config.hf_trainer_params.to_dict()
#     )

#     trainer = Trainer(
#         model=input_model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         data_collator=col_fn,
#         compute_metrics=compute_metrics
#     )
#     trainer.train()

col_fn = DataCollatorWithPadding(
    tokenizer, return_tensors='pt', padding='longest'
)

loader_out = autoload.get_and_map(tokenizer, "glue:cola")
num_labels = len(loader_out['output']['range'])
# The model was loaded without an explicit num_labels, so check that its
# classification head matches the task's label space before training.
assert model.config.num_labels == num_labels, (
    f"classifier has {model.config.num_labels} outputs, task has {num_labels} labels"
)

training_args = TrainingArguments(
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Keep every dataset column in the batches (meant to make sure the labels
    # reach the model). NOTE(review): presumably also needed because
    # `model.forward` is patched by NewEmbeddingLayer.mutate — confirm before changing.
    remove_unused_columns=False,
    **config.hf_trainer_params.to_dict()
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=loader_out['train'],
    eval_dataset=loader_out['valid'],
    data_collator=col_fn,
    compute_metrics=compute_metrics
)
trainer.train()

Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-41a6799222324b5f.arrow
Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fc7d7deaf3161a2.arrow
Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-0eb862d54758b38d.arrow
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Accuracy,F1-score-1,F1-score-ma
1,No log,0.655867,0.691275,0.81746,0.40873
2,0.577800,0.639771,0.763183,0.84893,0.650629
3,0.577800,0.507809,0.766059,0.849197,0.663915
4,0.528700,0.52382,0.770853,0.852195,0.6713
5,0.528700,0.480276,0.794823,0.861757,0.731994
6,0.499800,0.506056,0.776606,0.855906,0.679552
7,0.499800,0.475724,0.795781,0.863198,0.730276
8,0.482900,0.494971,0.790988,0.860614,0.721495
9,0.482900,0.478771,0.786194,0.858592,0.710239
10,0.465700,0.502414,0.780441,0.858903,0.682151


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TrainOutput(global_step=21440, training_loss=0.37007259682043275, metrics={'train_runtime': 421.6464, 'train_samples_per_second': 1622.402, 'train_steps_per_second': 50.848, 'total_flos': 8141300538608160.0, 'train_loss': 0.37007259682043275, 'epoch': 80.0})

In [72]:
# Inspect the first encoder layer; after TransformerInjector.muatate_list it
# should be a TransformerInjector wrapping the original layer.
# NOTE(review): the saved output below shows a bare BertLayer, so it was
# probably produced before the layers were wrapped — re-run to confirm.
model.bert.encoder.layer[0]

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [None]:
# Raw tokenizer output for a short probe, before any soft-prompt placeholders are added.
tokenizer("hi bye", return_tensors='pt')

In [None]:
# List the parameters that will receive gradients.
trainable_names = [name for name, param in model.named_parameters() if param.requires_grad]
for name in trainable_names:
    print(name)

In [None]:
# Compare the trained input soft prompt with the `old_w` snapshot.
new_w = peft_module.get_weights()
# Sum of absolute differences: a signed sum can cancel to ~0 even when the
# weights did move.
(new_w - old_w).abs().sum()

In [None]:
# Display the input soft-prompt weights read after training.
new_w

In [None]:
# Untouched pretrained model, used as a reference for which weights changed.
new_model = BertForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
nm_sd = new_model.state_dict()

In [None]:
# State dict of the soft-prompted model, compared against `nm_sd` below.
old_sd = model.state_dict()

In [None]:
# Keys present only in the soft-prompted model (e.g. the injected prompt
# parameters and the parameters of the wrapped submodules).
added_keys = [key for key in old_sd if key not in nm_sd]
for key in added_keys:
    print(key)

In [None]:
# Confirm the input soft prompt is still trainable after the freezing loop.
model.bert.embeddings.word_embeddings.soft_prompt.sadcl_learned_embedding.requires_grad

In [None]:
# The wrapped word embeddings were frozen, so they should match a fresh
# pretrained copy exactly; abs() keeps positive and negative drift from cancelling.
(old_sd['bert.embeddings.word_embeddings.emb_layer.weight'] - nm_sd['bert.embeddings.word_embeddings.weight']).abs().sum()

In [None]:
probe_text = "hi bye"
model.eval()
with torch.no_grad():
    tokens = tokenizer(probe_text, return_tensors='pt').to(DEVICE)
    # post_tokenizer is defined on the SoftEmbedding (`soft_prompt`); it adds the
    # NTOKENS placeholder ids that the learned prompt will occupy.
    padded_ids = peft_module.soft_prompt.post_tokenizer(input_ids=tokens['input_ids'])['input_ids']
    o = model.bert.embeddings.word_embeddings(padded_ids)

In [None]:
# Expected (1, probe_len + NTOKENS, emb_dim): post_tokenizer adds NTOKENS ids and
# 'cat' mode keeps the sequence length.
# NOTE(review): the saved output below (10 positions) looks stale relative to
# the current code — re-run to confirm.
o.shape

torch.Size([1, 10, 768])

In [None]:
# Embedding output for the probe; with PROMPT_PLACE == 'post' its last NTOKENS
# positions are the learned prompt.
o

In [None]:
# `out` was previously only defined by a commented-out cell, and hidden states
# are only returned on request, so run the probe forward pass here.
probe_tokens = tokenizer("hi bye", return_tensors='pt').to(DEVICE)
model.eval()
with torch.no_grad():
    out = model(**probe_tokens, output_hidden_states=True)
len(out.hidden_states)

In [None]:
# Scratch check: setattr on a class rebinds the attribute for every instance.
# Uses its own throwaway class so the cell runs on a fresh kernel and does not
# clobber the method of the class `c` defined in the next cell.
class AttrDemo:
    a = 1

setattr(AttrDemo, 'a', 3)
AttrDemo().a

In [None]:
class c:
    """Scratch demo of wrapping a bound method on a single instance."""

    def a(self, i):
        """Print ``i + 1``."""
        print(i + 1)

    def b(self):
        """Shadow ``a`` on this instance with a wrapper that increments its argument first."""
        original_a = self.a  # bound method captured before it is shadowed
        self.a = lambda i: original_a(i + 1)

# Named `demo` (not `o`) so it does not overwrite the embedding output `o` used above.
demo = c()
demo.a(3)  # prints 4
demo.b()
demo.a(3)  # prints 5: the wrapper passes 4 to the original method