import os
import gym
import random
import torch
import numpy as np
import pickle
from policy_mix import Policy_mix


import sys
sys.path.append('..')
from lunar_lander_risk import LunarLander
# Maximum number of environment steps per episode: the rollout horizon used
# both when collecting training episodes and when evaluating the IQL policy.
LunarLander_LEN = 500

import argparse
# Command-line hyperparameters. ``%(default)s`` in a help string is expanded by
# argparse, so the advertised default can never drift from the real one.
parser = argparse.ArgumentParser(description='parameters')
parser.add_argument("--epochs", type=int, default=4000, help="Number of train epochs")
parser.add_argument("--n_episodes", type=int, default=30, help="Number of episodes per epoch")
parser.add_argument("--batch_size", type=int, default=128, help="Batch size, default: %(default)s")
parser.add_argument("--hidden_size", type=int, default=128, help="Hidden layer size passed to Policy_mix")
parser.add_argument("--cvar_lr", type=float, default=7e-4, help="CVaR learning rate")
parser.add_argument("--iql_policy_lr", type=float, default=1e-4, help="IQL policy learning rate")
parser.add_argument("--iql_value_lr", type=float, default=1e-4, help="IQL value learning rate")
parser.add_argument("--temperature", type=float, default=1, help="IQL temperature")
parser.add_argument("--expectile", type=float, default=0.8, help="IQL expectile")
parser.add_argument("--tau", type=float, default=5e-3, help="tau hyperparameter passed to Policy_mix")
parser.add_argument("--cvar_alpha", type=float, default=0.1, help="CVaR level alpha")
parser.add_argument('--gamma', type=float, default=0.999, help='gamma')
# NOTE(review): this value is forwarded as the first argument of
# agent.train_iql(...); how often that call happens is fixed by the training
# loop's evaluation interval, not by this flag. Confirm what it controls.
parser.add_argument('--iql_update', type=int, default=3, help='iql update frequency')
parser.add_argument('--iql_sample_size', type=int, default=200000, help='iql update sample size')
parser.add_argument("--seed", type=int, default=1, help="Random seed")
args = parser.parse_args()


################# setting ##################
# Two independent environment instances built with the same noise_scale
# argument: one for collecting training rollouts, one for IQL evaluation.
noise_scale = 100
env = LunarLander(noise_scale)
eval_env = LunarLander(noise_scale)
state_size, action_size = env.observation_space.shape[0], env.action_space.n

# Reproducibility: seed both envs and every RNG library in use. The eval env
# receives a different seed derived from the training seed.
seed = args.seed
env.seed(seed)
eval_env.seed(2**31 - 1 - seed)
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

# Policy_mix takes its hyperparameters positionally, in this exact order.
device = torch.device('cpu')
agent = Policy_mix(
    state_size, action_size, args.hidden_size,
    args.cvar_lr, args.iql_policy_lr, args.iql_value_lr,
    args.cvar_alpha, args.tau, args.temperature, args.expectile, args.gamma,
    device,
)

# Echo the run configuration.
print('CVaR', args.cvar_alpha,
      '| lr', args.cvar_lr,
      '| hidden', args.hidden_size)
print('IQL lr_policy', args.iql_policy_lr,
      ', lr_v', args.iql_value_lr,
      ', temp', args.temperature,
      ', expectile', args.expectile,
      ', batch_size', args.batch_size,
      ', update', args.iql_update,
      ', sample size', args.iql_sample_size)

###########################################
def play_episodes(env, agent, n_episodes):
    """Roll out ``n_episodes`` episodes with ``agent`` and record the experience.

    Every transition is pushed into ``agent.replay_buffer`` and, once an
    episode ends, its states/actions/rewards are handed to ``agent.add_traj``.
    Episodes are cut off after ``LunarLander_LEN`` steps. Observations are
    expected to be numpy arrays (they go through ``torch.from_numpy``).

    Args:
        env: environment with ``reset()`` and a 4-tuple ``step(action)``.
        agent: agent whose ``get_action`` returns ``(action, p)``.
        n_episodes: number of episodes to play.

    Returns:
        ``(mean_return, left_landing_rate, p_per_episode)`` where
        ``left_landing_rate`` is the fraction of episodes whose last step gave
        reward == 100 with x <= 10 (equivalently ``next_state[0] <= 0``), and
        ``p_per_episode`` holds, per episode, the list of ``p`` values returned
        by ``agent.get_action``.
    """
    # Renderer geometry; presumably inverts the env's normalisation of the
    # x observation back to world units (TODO confirm against the env).
    viewport_w = 600
    scale = 30.0
    half_width = viewport_w / scale / 2

    total_r_lst = []
    land_left_cnt = 0
    p_episodes_lst = []

    for _ in range(n_episodes):
        state = env.reset()
        total_r = 0
        state_lst, action_lst, reward_lst, p_lst = [], [], [], []

        for t in range(LunarLander_LEN):
            action, p = agent.get_action(torch.from_numpy(state).float().to(device))
            next_state, reward, done, _info = env.step(action)

            # Record the transition for both the replay buffer and the trajectory.
            agent.replay_buffer.add(state, action, reward, next_state, done)
            state_lst.append(state)
            action_lst.append(action)
            reward_lst.append(reward)
            p_lst.append(p)

            total_r += reward
            state = next_state

            if done or (t + 1) == LunarLander_LEN:
                x = next_state[0] * half_width + half_width
                if x <= 10 and reward == 100:
                    land_left_cnt += 1
                break

        agent.add_traj(state_lst, action_lst, reward_lst)
        total_r_lst.append(total_r)
        p_episodes_lst.append(p_lst)

    return np.mean(total_r_lst), land_left_cnt / n_episodes, p_episodes_lst

def eval_iql(env, agent, n_episodes):
    """Evaluate the IQL policy ``agent.iql_agent`` for ``n_episodes`` episodes.

    Episodes are cut off after ``LunarLander_LEN`` steps. Unlike
    ``play_episodes``, transitions are not added to the replay buffer.

    Returns:
        ``(mean_return, left_landing_count, success_count)``. Both counts are
        raw totals, not rates:

        - ``left_landing_count``: episodes that ended (``done``) with reward
          == 100 at x <= 10 (equivalently ``state[0] <= 0``).
        - ``success_count``: episodes that ended (``done``) with a final
          reward other than -100.

        Episodes truncated by the step limit count towards neither.
    """
    # Renderer geometry used to map the x observation to world units;
    # loop-invariant, so computed once.
    view_w = 600
    scale = 30.0
    half_width = view_w / scale / 2

    total_r_lst = []
    success_land = 0
    land_left_cnt = 0
    for _ in range(n_episodes):
        state = env.reset()
        total_r = 0

        for _step in range(LunarLander_LEN):
            action = agent.iql_agent.get_action(torch.from_numpy(state).float().to(device))
            state, reward, done, _info = env.step(action)
            total_r += reward

            if done:
                if reward != -100:
                    success_land += 1
                x = state[0] * half_width + half_width
                if x <= 10 and reward == 100:
                    land_left_cnt += 1
                break

        total_r_lst.append(total_r)

    return np.mean(total_r_lst), land_left_cnt, success_land

########################################################

# Per-epoch training metrics, appended by train().
train_ret_lst, train_land_lst, train_p_lst = [], [], []
# Per-evaluation IQL metrics, appended by train() at each evaluation.
eval_iql_ret_lst, eval_iql_left_lst, eval_iql_success_lst = [], [], []

def save(alpha, cvar_lr, iql_lrp, iql_lrv, temp, bs, iql_update, iql_sample_size, seed):
    """Dump every metric collected so far into a hyperparameter-keyed directory.

    The target directory is
    ``./save/alpha_<alpha>/cvar_lr_<cvar_lr>/i_lrp_<iql_lrp>/i_lrv_<iql_lrv>/
    temp_<temp>/bs_<bs>/iql_update_<iql_update>/iql_sample_<iql_sample_size>/
    seed_<seed>/`` and is created if missing. Files are opened in ``'wb'``
    mode, so earlier dumps are overwritten. The module-level metric lists are
    written with ``np.save``; ``train_p_lst`` (nested lists) is pickled.
    """
    segments = [
        '.', 'save',
        'alpha_' + str(alpha),
        'cvar_lr_' + str(cvar_lr),
        'i_lrp_' + str(iql_lrp),
        'i_lrv_' + str(iql_lrv),
        'temp_' + str(temp),
        'bs_' + str(bs),
        'iql_update_' + str(iql_update),
        'iql_sample_' + str(iql_sample_size),
        'seed_' + str(seed),
    ]
    out_dir = '/'.join(segments) + '/'
    os.makedirs(out_dir, exist_ok=True)

    npy_outputs = (
        ('ret.npy', train_ret_lst),
        ('land.npy', train_land_lst),
        ('iql_ret.npy', eval_iql_ret_lst),
        ('iql_left.npy', eval_iql_left_lst),
        ('iql_success.npy', eval_iql_success_lst),
    )
    for file_name, values in npy_outputs:
        with open(out_dir + file_name, 'wb') as fh:
            np.save(fh, values)

    with open(out_dir + 'p.pkl', 'wb') as fh:
        pickle.dump(train_p_lst, fh)

def _save_results(config):
    """Persist all metrics collected so far under the run's hyperparameter directory."""
    save(config.cvar_alpha,
         config.cvar_lr,
         config.iql_policy_lr,
         config.iql_value_lr,
         config.temperature,
         config.batch_size,
         config.iql_update,
         config.iql_sample_size,
         config.seed)


def train(config, env, agent, eval_interval=50, n_eval_episodes=30):
    """Run the training loop with periodic IQL training, evaluation and saving.

    Each epoch plays ``config.n_episodes`` episodes via ``play_episodes``,
    records the mean return, left-landing rate and per-episode ``p`` lists in
    the module-level metric lists, then calls ``agent.gcvar()``.

    Every ``eval_interval`` epochs the agent's IQL learner is trained
    (``agent.train_iql``), evaluated for ``n_eval_episodes`` episodes on the
    module-level ``eval_env``, and all metrics are saved to disk. If
    ``config.epochs`` is not a multiple of ``eval_interval``, a final save
    is done after the loop so the trailing epochs are not lost.

    Args:
        config: parsed argparse namespace holding the run hyperparameters.
        env: environment used for training rollouts.
        agent: the ``Policy_mix`` agent being trained.
        eval_interval: epochs between IQL update/eval/save rounds; must be > 0.
        n_eval_episodes: number of episodes per IQL evaluation.
    """
    for ep_i in range(config.epochs):
        ret, land_rate, p_lst = play_episodes(env, agent, config.n_episodes)
        train_ret_lst.append(ret)
        train_land_lst.append(land_rate)
        train_p_lst.append(p_lst)

        agent.gcvar()

        # IQL training, evaluation and checkpointing share a single schedule.
        if (ep_i + 1) % eval_interval == 0:
            agent.train_iql(config.iql_update, config.iql_sample_size, config.batch_size)

            iql_ret, iql_left, iql_success = eval_iql(eval_env, agent, n_eval_episodes)
            eval_iql_ret_lst.append(iql_ret)
            eval_iql_left_lst.append(iql_left)
            eval_iql_success_lst.append(iql_success)

            _save_results(config)

    # Without this, training metrics from epochs after the last interval
    # boundary would never be written to disk.
    if config.epochs % eval_interval != 0:
        _save_results(config)


# Script entry: run the full training loop with the parsed CLI configuration.
train(config=args, env=env, agent=agent)
