1234567891011121314151617181920212223242526272829303132333435 |
- import os
- import torch
- import pickle
- import random
-
- from MeLU import MeLU
- from options import config, states
-
-
def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_filename=None):
    """Meta-train ``melu`` over shuffled mini-batches of tasks, then optionally save it.

    Each element of ``total_dataset`` is unpacked as a 4-item sequence in the
    order (support_x, support_y, query_x, query_y). The element types are
    presumably tensors produced by the data-generation step (TODO confirm
    against caller). For every batch, the four columns are passed as lists to
    ``melu.global_update`` together with ``config['inner']``.

    Batching drops the last partial batch: only ``len(total_dataset) //
    batch_size`` batches run per epoch. The dataset is reshuffled every epoch,
    so different tail samples are skipped each time. A dataset with fewer
    than ``batch_size`` items (including an empty one) performs no updates.

    Args:
        melu: Model exposing ``cuda()``, ``train()``, ``global_update(...)``
            and ``state_dict()``.
        total_dataset: Sequence of (support_x, support_y, query_x, query_y)
            items. It is copied before shuffling, so the caller's sequence is
            not reordered.
        batch_size: Number of tasks per meta-update; must be positive.
        num_epoch: Number of passes over the dataset.
        model_save: If true, write ``melu.state_dict()`` to ``model_filename``
            with ``torch.save`` after training.
        model_filename: Destination path; required when ``model_save`` is true.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``model_save`` is
            true but no ``model_filename`` was given. Both checks happen
            before any training so no compute is wasted.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got {!r}".format(batch_size))
    if model_save and model_filename is None:
        raise ValueError("model_filename is required when model_save is True")

    if config['use_cuda']:
        melu.cuda()

    print("mode: " + str(config['use_cuda']))

    # Shuffle a private copy: the caller's list keeps its order, and tuples or
    # other sequences are accepted too.
    dataset = list(total_dataset)
    num_batch = len(dataset) // batch_size  # trailing partial batch is dropped
    melu.train()
    for epoch in range(num_epoch):
        random.shuffle(dataset)
        for i in range(num_batch):
            print("training - epoch:{} - batch:{}".format(epoch, i))
            batch = dataset[batch_size * i:batch_size * (i + 1)]
            # Transpose the batch of 4-tuples into four parallel lists.
            supp_xs, supp_ys, query_xs, query_ys = (list(column) for column in zip(*batch))
            melu.global_update(supp_xs, supp_ys, query_xs, query_ys, config['inner'])
            # Release batch references before the next iteration.
            del supp_xs, supp_ys, query_xs, query_ys

    if model_save:
        torch.save(melu.state_dict(), model_filename)
|