import sys
sys.path.append('..')
from lunar_lander_risk import LunarLander
import os
import torch
import random
import numpy as np
from qrdqn import Policy

# Command-line options for a single run: learning rate, alpha (printed as the
# CVaR level in the run banner below) and the random seed.
# NOTE: parse_args() runs at import time, so this file is meant to be executed
# as a script rather than imported.
import argparse
parser = argparse.ArgumentParser(description='lr temp alpha seed')
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--alpha', type=float, default=0.1)
parser.add_argument('--seed', type=int, default=1)
args = parser.parse_args()

############## setting ##############
# Per-episode step cap, applied in both the training and evaluation loops.
LunarLander_LEN = 500

# Training stops at whichever limit is hit first: `max_episodes` episodes or
# `train_steps` environment steps.
max_episodes = 4000*500
train_steps = int(2e6)
# Evaluation cadence: every `eval_step_intvl` environment steps and every
# `eval_episode_intvl` training episodes (tracked in separate histories).
eval_step_intvl = int(1e4)
eval_episode_intvl = 50

# Hyperparameters forwarded to Policy(...) further down.
hidden_dim = 128
quantile_dim = 80
lr = args.lr
discount = 0.999
alpha = args.alpha
seed = args.seed
# Batch size handed to agent.train(...).
buffer_sample_size = 128
# Training rewards are multiplied by this factor before entering the replay
# buffer; evaluation returns are left unscaled.
reward_scale = 0.1

# `noise` is passed straight to the LunarLander constructor; its exact meaning
# is defined in lunar_lander_risk (not visible here).
noise = 100
env = LunarLander(noise)
eval_env = LunarLander(noise)
action_num = env.action_space.n
state_dim = env.observation_space.shape[0]
# Seed every RNG source in play. The evaluation env gets a seed derived from,
# but different to, the training seed.
env.seed(seed)
eval_env.seed(2**31-1-seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


############# interaction ############
def eval_policy(env, agent, n_episodes, max_steps=None):
    """Roll out the agent for several episodes and summarise the results.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``, where
            ``step`` returns ``(next_state, reward, done, info)`` and states
            are NumPy arrays.
        agent: Object whose ``get_action`` accepts the state as a float
            tensor and returns an action accepted by ``env.step``.
        n_episodes: Number of evaluation episodes to run.
        max_steps: Per-episode step cap. When omitted, the module-level
            ``LunarLander_LEN`` is used.

    Returns:
        A ``(mean_return, land_left)`` tuple: the mean of the (unscaled)
        episode returns over all ``n_episodes`` episodes, and the number of
        episodes that ended with a final reward of exactly 100 at a
        horizontal position ``x <= 10``.
    """
    if max_steps is None:
        max_steps = LunarLander_LEN

    # Half the viewport width in world units. NOTE(review): presumably this
    # undoes the environment's normalisation of state[0] -- confirm against
    # lunar_lander_risk.
    VIEWPORT_W = 600
    SCALE = 30.0
    half_width = VIEWPORT_W / SCALE / 2

    ep_return = []
    land_left = 0
    for _ in range(n_episodes):
        ep_ret = 0
        state = env.reset()

        for t in range(max_steps):
            # choose action
            action = agent.get_action(torch.from_numpy(state).float())

            # step
            next_state, reward, done, _info = env.step(action)

            # update status
            state = next_state
            ep_ret += reward

            # Reaching the step cap ends the episode just like a terminal
            # state reported by the environment.
            if done or t + 1 == max_steps:
                # check if land at left: final reward of exactly 100 with the
                # lander in the left half of the viewport
                x = next_state[0] * half_width + half_width
                if x <= 10 and reward == 100:
                    land_left += 1
                break

        ep_return.append(ep_ret)

    # Average over every episode's return (the list), not only the last one.
    return np.mean(ep_return), land_left

#####################################################
# Print the run configuration once, then build the agent from it.
print('CVaR', alpha, '| lr', lr, '| hidden', hidden_dim, '| quantile', quantile_dim, '| batch_size', buffer_sample_size, '| seed', seed)
agent = Policy(state_dim, action_num, hidden_dim, quantile_dim, lr, discount, alpha)

# Evaluation histories, read by save() as module globals:
#   *_step    -> one entry per step-interval evaluation
#   *_episode -> one entry per episode-interval evaluation
# "ret" holds mean returns, "land" holds left-landing counts.
eval_ret_step, eval_land_step = [], []
eval_ret_episode, eval_land_episode = [], []

def save(alpha, lr, hidden_dim, quantile_dim, seed):
    """Dump the four evaluation histories into a per-configuration folder.

    The target directory is derived from the hyperparameters
    (``./save/alpha_*/lr_*/h_dim_*/qt_dim_*/seed_*/``) and created on demand.
    The histories themselves are read from the module-level lists
    ``eval_ret_step``, ``eval_land_step``, ``eval_ret_episode`` and
    ``eval_land_episode``; existing files are overwritten.
    """
    path_parts = [
        './save/alpha_' + str(alpha),
        '/lr_' + str(lr),
        '/h_dim_' + str(hidden_dim),
        '/qt_dim_' + str(quantile_dim),
        '/seed_' + str(seed),
        '/',
    ]
    root = ''.join(path_parts)
    os.makedirs(root, exist_ok=True)

    # (file name, history) pairs, written in this order.
    outputs = (
        ('ret_step.npy', eval_ret_step),
        ('land_left_step.npy', eval_land_step),
        ('ret_episode.npy', eval_ret_episode),
        ('land_left_episode.npy', eval_land_episode),
    )
    for file_name, history in outputs:
        with open(root + file_name, 'wb') as handle:
            np.save(handle, history)

# Evaluate the untrained policy once (30 episodes) so both histories start
# with a common baseline entry before any training step is taken.
eval_ret, eval_lf = eval_policy(eval_env, agent, 30)
eval_ret_step.append(eval_ret)
eval_land_step.append(eval_lf)
eval_ret_episode.append(eval_ret)
eval_land_episode.append(eval_lf)

################# train ##################

# Exploration strategy: linearly decaying epsilon-greedy
def use_epsilon_greedy(t):
    """Decide whether the step at global index ``t`` should explore.

    Epsilon is interpolated linearly from 1.0 at ``t == 0`` towards 0.02 at
    ``t == 10**5`` and is held at 0.02 for every later step. Exactly one
    sample is drawn from ``np.random.rand()`` per call.

    Returns:
        ``True`` if a random action should be taken, ``False`` otherwise.
    """
    eps_begin, eps_floor = 1.0, 0.02
    anneal_steps = 10**5

    threshold = (
        eps_floor
        if t > anneal_steps
        else eps_begin + (eps_floor - eps_begin) * (t / anneal_steps)
    )

    return bool(np.random.rand() < threshold)

# Main training loop. `total_step` counts environment steps across all
# episodes; it drives the epsilon schedule, the step-interval evaluation and
# the overall stopping condition.
total_step = 0
for episode_i in range(max_episodes):
    state = env.reset()

    for t in range(LunarLander_LEN):
        # choose action: random with probability epsilon, otherwise ask the agent
        if use_epsilon_greedy(total_step):
            action = env.action_space.sample()
        else:
            action = agent.get_action(torch.from_numpy(state).float())

        # step; only the training reward is scaled (evaluation uses raw rewards)
        next_state, reward, done, info = env.step(action)
        reward = reward * reward_scale

        # add buffer: `done` here is the environment's own flag; the step-cap
        # truncation applied further down is deliberately not stored
        agent.replay_buffer.add(state, action, reward, next_state, done)

        # update status
        state = next_state
        total_step += 1

        # train if have enough sample (one update per environment step)
        if agent.replay_buffer.size > 1000:
            agent.train(buffer_sample_size)

        # eval every `eval_step_intvl` steps and checkpoint all histories
        if total_step % eval_step_intvl == 0:
            eval_ret, eval_lf = eval_policy(eval_env, agent, 30)
            eval_ret_step.append(eval_ret)
            eval_land_step.append(eval_lf)
            save(alpha, lr, hidden_dim, quantile_dim, seed)

        # end the episode on the step cap, on termination, or when the
        # global step budget is exhausted
        if t+1 == LunarLander_LEN:
            done = True
        if done:
            break
        if total_step >= train_steps:
            break

    # eval every `eval_episode_intvl` episodes; these results reach disk only
    # via the next step-interval save or the final save below
    if (episode_i+1) % eval_episode_intvl == 0:
        eval_ret, eval_lf = eval_policy(eval_env, agent, 30)
        eval_ret_episode.append(eval_ret)
        eval_land_episode.append(eval_lf)

    if total_step >= train_steps:
        break


# final checkpoint of all evaluation histories
save(alpha, lr,hidden_dim, quantile_dim, seed)
