Browse Source

storing different losses added

master
Faeze 1 year ago
parent
commit
4643219174
4 changed files with 27 additions and 13 deletions
  1. 2
    2
      data/weibo/config.py
  2. 20
    7
      learner.py
  3. 2
    2
      main.py
  4. 3
    2
      model.py

+ 2
- 2
data/weibo/config.py View File



image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
image_embedding = 768 image_embedding = 768
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/chinese-xlnet-base"
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/chinese-xlnet-base"
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_embedding = 768 text_embedding = 768
max_length = 200 max_length = 200

+ 20
- 7
learner.py View File



def train_epoch(config, model, train_loader, optimizer, scalar): def train_epoch(config, model, train_loader, optimizer, scalar):
loss_meter = AvgMeter('train') loss_meter = AvgMeter('train')
c_loss_meter = AvgMeter('train')
s_loss_meter = AvgMeter('train')

# tqdm_object = tqdm(train_loader, total=len(train_loader)) # tqdm_object = tqdm(train_loader, total=len(train_loader))


targets = [] targets = []


with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
output, score = model(batch) output, score = model(batch)
loss = calculate_loss(model, score, batch['label'])
loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
scalar.scale(loss).backward() scalar.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)




count = batch["id"].size(0) count = batch["id"].size(0)
loss_meter.update(loss.detach(), count) loss_meter.update(loss.detach(), count)
c_loss_meter.update(c_loss.detach(), count)
s_loss_meter.update(s_loss.detach(), count)


# tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer)) # tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))


target = batch['label'].detach() target = batch['label'].detach()
targets.append(target) targets.append(target)


return loss_meter, targets, predictions
losses = (loss_meter, s_loss_meter, c_loss_meter)

return losses, targets, predictions




def validation_epoch(config, model, validation_loader): def validation_epoch(config, model, validation_loader):
loss_meter = AvgMeter('validation') loss_meter = AvgMeter('validation')
c_loss_meter = AvgMeter('validation')
s_loss_meter = AvgMeter('validation')


targets = [] targets = []
predictions = [] predictions = []
batch = batch_constructor(config, batch) batch = batch_constructor(config, batch)
with torch.no_grad(): with torch.no_grad():
output, score = model(batch) output, score = model(batch)
loss = calculate_loss(model, score, batch['label'])
loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])


count = batch["id"].size(0) count = batch["id"].size(0)
loss_meter.update(loss.detach(), count) loss_meter.update(loss.detach(), count)
c_loss_meter.update(c_loss.detach(), count)
s_loss_meter.update(s_loss.detach(), count)


# tqdm_object.set_postfix(validation_loss=loss_meter.avg) # tqdm_object.set_postfix(validation_loss=loss_meter.avg)


target = batch['label'].detach() target = batch['label'].detach()
targets.append(target) targets.append(target)


return loss_meter, targets, predictions
losses = (loss_meter, s_loss_meter, c_loss_meter)

return losses, targets, predictions




def supervised_train(config, train_loader, validation_loader, trial=None): def supervised_train(config, train_loader, validation_loader, trial=None):
train_accuracy = multiclass_acc(train_truth, train_pred) train_accuracy = multiclass_acc(train_truth, train_pred)
validation_accuracy = multiclass_acc(validation_truth, validation_pred) validation_accuracy = multiclass_acc(validation_truth, validation_pred)
print_lr(optimizer) print_lr(optimizer)
print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy)
print('Training Loss:', train_loss[0], 'Training Accuracy:', train_accuracy)
print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy) print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy)


if lr_scheduler: if lr_scheduler:
lr_scheduler.step(validation_loss.avg)
lr_scheduler.step(validation_loss[0].avg)
if early_stopping: if early_stopping:
early_stopping(validation_loss.avg, model)
early_stopping(validation_loss[0].avg, model)
if early_stopping.early_stop: if early_stopping.early_stop:
print("Early stopping") print("Early stopping")
break break

+ 2
- 2
main.py View File



make_directory(config.output_path) make_directory(config.output_path)


config.writer = SummaryWriter(config.output_path)
# config.writer = SummaryWriter(config.output_path)


if args.use_optuna and use_optuna: if args.use_optuna and use_optuna:
optuna_main(config, args.use_optuna) optuna_main(config, args.use_optuna)
else: else:
torch_main(config) torch_main(config)


config.writer.close()
# config.writer.close()

+ 3
- 2
model.py View File

# fake = (label == 1).nonzero() # fake = (label == 1).nonzero()
# real = (label == 0).nonzero() # real = (label == 0).nonzero()
# s_loss = 0 * similarity[fake].mean() + similarity[real].mean() # s_loss = 0 * similarity[fake].mean() + similarity[real].mean()
s_loss = similarity.mean()
c_loss = model.classifier_loss_function(score, label) c_loss = model.classifier_loss_function(score, label)
loss = model.config.loss_weight * c_loss + similarity.mean()
return loss
loss = model.config.loss_weight * c_loss + s_loss
return loss, c_loss, s_loss




def calculate_similarity_loss(config, image_embeddings, text_embeddings): def calculate_similarity_loss(config, image_embeddings, text_embeddings):

Loading…
Cancel
Save