"""Ray Tune trainable that runs a single MetaTL hyperparameter-search trial."""
import os
import random

import numpy as np
import torch
from ray import tune

from sampler import *
from trainer import Trainer
from utils import *


def train_metatl(conf, checkpoint_dir=None):

    # Seed every RNG source (PyTorch CPU/GPU, NumPy, Python) so trials are
    # reproducible; cudnn.deterministic trades some speed for repeatability.
    SEED = conf["params"]['seed']
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    np.random.seed(SEED)
    random.seed(SEED)

    # Merge the Tune-sampled values into the base parameter set.
    params = conf['params']
    params['batch_size'] = conf['batch_size']
    params['number_of_neg'] = conf['number_of_neg']

    # Load the K-shot dataset and build the train/valid/test batch samplers.
    user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(params['dataset'], params['K'])
    sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=params['batch_size'], maxlen=params['K'], n_workers=1, params=params)
    sampler_test = DataLoader(user_input_test, user_test, itemnum, params)
    sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)

    ps = {
        "batch_size":    conf["batch_size"],
        "learning_rate": conf["learning_rate"],
        "epoch":         conf["params"]["epoch"],
        "beta":          conf["beta"],
        "embed_dim":     conf["embed_dim"],
        "margin":        conf["margin"],
        "K":             conf["params"]["K"],
        "number_of_neg": conf["number_of_neg"],
        "loss_function": conf["loss_function"],
        "eval_epoch":    conf["eval_epoch"],
        "device":        params["device"],
    }

    trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps)

    # Restore model and optimizer state when Ray Tune resumes this trial from
    # a checkpoint. (torch.load may need map_location=params['device'] when
    # restoring a GPU checkpoint on a CPU-only worker.)
    if checkpoint_dir:
        print("===================== using checkpoint =====================")
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        trainer.MetaTL.load_state_dict(model_state)
        trainer.optimizer.load_state_dict(optimizer_state)

    # Train in chunks of eval_epoch batches, evaluating, checkpointing, and
    # reporting to Tune after each chunk.
    for epoch in range(ps['epoch'] // ps['eval_epoch']):
        for e in range(ps['eval_epoch']):
            # Sample one batch of tasks and take a single meta-training step.
            train_task, curr_rel = trainer.train_data_loader.next_batch()
            loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel)

        # Evaluate on the validation split.
        valid_data = trainer.eval(istest=False, epoch=-1)

        # Persist model and optimizer state so Ray Tune can pause and resume
        # this trial.
        with tune.checkpoint_dir(epoch) as ckpt_dir:
            path = os.path.join(ckpt_dir, "checkpoint")
            torch.save((trainer.MetaTL.state_dict(), trainer.optimizer.state_dict()), path)

        # Report validation metrics; (epoch + 1) * eval_epoch is the number of
        # training batches completed so far.
        tune.report(
            MRR=valid_data["MRR"], NDCG10=valid_data["NDCG@10"], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"],
            Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"],
            training_iteration=(epoch + 1) * ps['eval_epoch']
        )
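

# A minimal launch sketch for this trainable, assuming Ray Tune's function API
# (tune.run with a function trainable, matching the checkpoint_dir signature
# above). The search-space ranges, dataset name, and base_params contents are
# illustrative assumptions, not values taken from this repo.
if __name__ == "__main__":
    base_params = {
        "dataset": "electronics",  # hypothetical dataset name
        "K": 3,
        "epoch": 1000,
        "seed": 42,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    }
    search_space = {
        "params": base_params,
        "itemnum": 10000,  # placeholder; normally derived from the dataset
        "batch_size": tune.choice([128, 256, 512]),
        "learning_rate": tune.loguniform(1e-4, 1e-2),
        "beta": tune.uniform(0.1, 5.0),
        "embed_dim": tune.choice([50, 100]),
        "margin": tune.uniform(0.5, 2.0),
        "number_of_neg": tune.choice([1, 5, 10]),
        "loss_function": "margin",  # placeholder loss name
        "eval_epoch": 100,
    }
    analysis = tune.run(
        train_metatl,
        config=search_space,
        metric="MRR",
        mode="max",
        num_samples=20,
    )
    print("Best config:", analysis.best_config)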