Browse Source

storing different losses added

master
Faeze 3 months ago
parent
commit
4643219174
4 changed files with 27 additions and 13 deletions
  1. 2
    2
      data/weibo/config.py
  2. 20
    7
      learner.py
  3. 2
    2
      main.py
  4. 3
    2
      model.py

+ 2
- 2
data/weibo/config.py View File

@@ -34,9 +34,9 @@ class WeiboConfig(Config):

image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
image_embedding = 768
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/chinese-xlnet-base"
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/chinese-xlnet-base"
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_embedding = 768
max_length = 200

+ 20
- 7
learner.py View File

@@ -22,6 +22,9 @@ def batch_constructor(config, batch):

def train_epoch(config, model, train_loader, optimizer, scalar):
loss_meter = AvgMeter('train')
c_loss_meter = AvgMeter('train')
s_loss_meter = AvgMeter('train')

# tqdm_object = tqdm(train_loader, total=len(train_loader))

targets = []
@@ -32,7 +35,7 @@ def train_epoch(config, model, train_loader, optimizer, scalar):

with torch.cuda.amp.autocast():
output, score = model(batch)
loss = calculate_loss(model, score, batch['label'])
loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
scalar.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

@@ -44,6 +47,8 @@ def train_epoch(config, model, train_loader, optimizer, scalar):

count = batch["id"].size(0)
loss_meter.update(loss.detach(), count)
c_loss_meter.update(c_loss.detach(), count)
s_loss_meter.update(s_loss.detach(), count)

# tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))

@@ -53,11 +58,15 @@ def train_epoch(config, model, train_loader, optimizer, scalar):
target = batch['label'].detach()
targets.append(target)

return loss_meter, targets, predictions
losses = (loss_meter, s_loss_meter, c_loss_meter)

return losses, targets, predictions


def validation_epoch(config, model, validation_loader):
loss_meter = AvgMeter('validation')
c_loss_meter = AvgMeter('validation')
s_loss_meter = AvgMeter('validation')

targets = []
predictions = []
@@ -66,10 +75,12 @@ def validation_epoch(config, model, validation_loader):
batch = batch_constructor(config, batch)
with torch.no_grad():
output, score = model(batch)
loss = calculate_loss(model, score, batch['label'])
loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])

count = batch["id"].size(0)
loss_meter.update(loss.detach(), count)
c_loss_meter.update(c_loss.detach(), count)
s_loss_meter.update(s_loss.detach(), count)

# tqdm_object.set_postfix(validation_loss=loss_meter.avg)

@@ -79,7 +90,9 @@ def validation_epoch(config, model, validation_loader):
target = batch['label'].detach()
targets.append(target)

return loss_meter, targets, predictions
losses = (loss_meter, s_loss_meter, c_loss_meter)

return losses, targets, predictions


def supervised_train(config, train_loader, validation_loader, trial=None):
@@ -131,13 +144,13 @@ def supervised_train(config, train_loader, validation_loader, trial=None):
train_accuracy = multiclass_acc(train_truth, train_pred)
validation_accuracy = multiclass_acc(validation_truth, validation_pred)
print_lr(optimizer)
print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy)
print('Training Loss:', train_loss[0], 'Training Accuracy:', train_accuracy)
print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy)

if lr_scheduler:
lr_scheduler.step(validation_loss.avg)
lr_scheduler.step(validation_loss[0].avg)
if early_stopping:
early_stopping(validation_loss.avg, model)
early_stopping(validation_loss[0].avg, model)
if early_stopping.early_stop:
print("Early stopping")
break

+ 2
- 2
main.py View File

@@ -50,7 +50,7 @@ if __name__ == '__main__':

make_directory(config.output_path)

config.writer = SummaryWriter(config.output_path)
# config.writer = SummaryWriter(config.output_path)

if args.use_optuna and use_optuna:
optuna_main(config, args.use_optuna)
@@ -61,4 +61,4 @@ if __name__ == '__main__':
else:
torch_main(config)

config.writer.close()
# config.writer.close()

+ 3
- 2
model.py View File

@@ -95,9 +95,10 @@ def calculate_loss(model, score, label):
# fake = (label == 1).nonzero()
# real = (label == 0).nonzero()
# s_loss = 0 * similarity[fake].mean() + similarity[real].mean()
s_loss = similarity.mean()
c_loss = model.classifier_loss_function(score, label)
loss = model.config.loss_weight * c_loss + similarity.mean()
return loss
loss = model.config.loss_weight * c_loss + s_loss
return loss, c_loss, s_loss


def calculate_similarity_loss(config, image_embeddings, text_embeddings):

Loading…
Cancel
Save