import sys
sys.path.append('..')
from lunar_lander_risk import LunarLander
import os
import torch
import random
import numpy as np
from policy import Policy

import argparse
# Command-line hyperparameters, all mandatory:
#   --lr_p  : policy (actor) learning rate
#   --temp  : policy temperature passed to Policy
#   --alpha : CVaR level
#   --seed  : seed for the environments and all RNGs
parser = argparse.ArgumentParser(description='lr temp alpha seed')
for _flag, _flag_type in (
    ('--lr_p', float),
    ('--temp', float),
    ('--alpha', float),
    ('--seed', int),
):
    parser.add_argument(_flag, type=_flag_type, required=True)
args = parser.parse_args()

########### setting ###########
# Run configuration. Values taken from the command line come from `args` (see the parser above);
# the rest are fixed for this experiment.
# hyperparameters
epochs = 4000                       # number of policy updates; each one first collects n_episodes episodes
hidden_size = 128                   # number of units in NN hidden layers
actor_lr = args.lr_p                # learning rate for actor
pi_temperature = args.temp          # policy temperature, forwarded to Policy (exact use defined in policy.py)

discount = 0.999                    # discount factor
alpha = args.alpha                  # CVaR_alpha
n_episodes = 30                     # episodes collected per update (also forwarded to Policy)

eval_intvl = 20                     # in the main loop this is only used as the checkpoint (save) interval, in epochs
eval_episodes = 5                   # NOTE(review): not referenced by the visible main loop (evaluation is disabled)

seed = args.seed

# Parameter of the risk-augmented LunarLander; its meaning is defined in lunar_lander_risk
# (presumably the magnitude of the risky reward noise -- confirm there).
noise = 100
env = LunarLander(noise)
eval_env = LunarLander(noise)       # separate env so evaluation rollouts would not disturb the training env
action_size = env.action_space.n
state_size = env.observation_space.shape[0]
# set seed
# Seeding happens after the envs are constructed; keep this order to reproduce earlier runs.
env.seed(seed)
# Mirror the seed into the top of the signed 32-bit range so eval episodes differ from training ones.
eval_env.seed(2**31-1-seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

############ interaction ############
def play_episode(env, agent, max_episode_length=1000):
    """Roll out one episode with the agent's sampling policy and hand it to the agent's buffer.

    Args:
        env: environment with gym-style ``reset()`` and ``step(action)`` returning
            ``(next_state, reward, done, info)``.
        agent: policy exposing ``select_action(state) -> (action, action_prob)`` and
            ``put_data(states, actions, rewards, dones)``.
        max_episode_length: step cap; the episode is truncated once this many steps
            have been taken (default 1000, the original hard-coded horizon).

    Returns:
        ``(total_reward, land_left)``: the undiscounted episode return, and whether the
        episode ended with a reward of exactly 100 at a de-normalised x position <= 10.

    The stored ``state_list`` holds one more entry than the other lists (the final
    ``next_state``). ``done_list`` stores the environment's own ``done`` flag, so a
    time-limit truncation is NOT recorded as a terminal transition.
    """
    # Constants used to map next_state[0] back to a world x coordinate; presumably
    # mirrors LunarLander's observation normalisation (VIEWPORT_W=600, SCALE=30).
    viewport_w = 600
    scale = 30.0
    half_width = viewport_w / scale / 2

    state = env.reset()
    state_list, action_list, reward_list, done_list = [], [], [], []
    episode_length = 0
    total_reward = 0
    land_left = False
    while True:
        action, _action_prob = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)
        episode_length += 1
        total_reward += reward

        # store agent's trajectory
        state_list.append(state)
        action_list.append(action)
        reward_list.append(reward)
        # Record the env's terminal flag before any truncation is applied.
        done_list.append(done)

        # End on a real terminal state, or truncate at the step cap (>= so a
        # non-positive cap cannot loop forever).
        if done or episode_length >= max_episode_length:
            # append the last state so the buffer can bootstrap/score it
            state_list.append(next_state)

            # check if the lander came to rest (reward == 100) on the left side
            x = next_state[0] * half_width + half_width
            if x <= 10 and reward == 100:
                land_left = True
            break

        state = next_state

    # store to buffer
    agent.put_data(state_list, action_list, reward_list, done_list)
    return total_reward, land_left

def eval_policy_st(env, agent, n_episodes=5):
    """Average undiscounted return of `n_episodes` rollouts using ``agent.select_action``.

    This is the same action-selection call used during training (presumably the
    stochastic policy, hence ``_st``). Each rollout is capped at 1000 steps. The
    agent's buffer is not touched.
    """
    episode_returns = []
    for _episode in range(n_episodes):
        obs = env.reset()
        episode_return = 0
        for _step in range(1000):  # hard cap of 1000 env steps per episode
            act, _ = agent.select_action(obs)
            obs, rew, done, _info = env.step(act)
            episode_return += rew
            if done:
                break
        episode_returns.append(episode_return)

    return np.mean(episode_returns)

def eval_policy_dt(env, agent, n_episodes=5):
    """Average undiscounted return of `n_episodes` rollouts using ``agent.eval_action``.

    ``eval_action`` is presumably the deterministic (greedy) policy, hence ``_dt``;
    see Policy to confirm. Each rollout is capped at 1000 steps.
    """
    def rollout():
        # Run one episode until the env reports done or the step budget runs out.
        obs = env.reset()
        episode_return = 0
        steps_left = 1000
        done = False
        while not done and steps_left > 0:
            act, _ = agent.eval_action(obs)
            obs, rew, done, _info = env.step(act)
            episode_return += rew
            steps_left -= 1
        return episode_return

    return np.mean([rollout() for _ in range(n_episodes)])

############ main ##########

# Build the agent. Arguments are positional, so their order must match the
# Policy constructor in policy.py.
agent = Policy(state_size, action_size, hidden_size, pi_temperature, alpha, actor_lr, discount, n_episodes)

print('method:', agent.name, 'lr_p:', actor_lr, 'pi_temp:', pi_temperature, 'seed:', seed)

# record
# Learning curves with one entry per epoch; the main loop appends to them and save() writes them out.
train_return = []      # mean undiscounted return over the epoch's n_episodes training episodes
train_land_left = []   # how many of the epoch's episodes were flagged as landing on the left
# Evaluation curves are disabled (the main loop runs no eval rollouts):
# eval_return = []
# eval_dt_return = []

def save(alpha, temp, lr_policy, seed, returns=None, land_left=None, root_dir='./save/GCVaR'):
    """Write the training curves to ``<root_dir>/alpha_<alpha>/temp_<temp>/lr_p_<lr_policy>/seed_<seed>/``.

    The directory is created if needed. ``train.npy`` and ``train_land_left.npy``
    are (over)written there.

    Args:
        alpha, temp, lr_policy, seed: run hyperparameters. They are only used to
            build the output path.
        returns: per-epoch mean training returns. Defaults to the module-level
            ``train_return`` list (the original behaviour).
        land_left: per-epoch left-landing counts. Defaults to the module-level
            ``train_land_left`` list.
        root_dir: base output directory.

    Returns:
        The directory the files were written to.
    """
    # Fall back to the script's global learning curves so existing callers keep working.
    if returns is None:
        returns = train_return
    if land_left is None:
        land_left = train_land_left

    out_dir = os.path.join(
        root_dir,
        'alpha_' + str(alpha),
        'temp_' + str(temp),
        'lr_p_' + str(lr_policy),
        'seed_' + str(seed),
    )
    os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(out_dir, 'train.npy'), 'wb') as f:
        np.save(f, returns)
    with open(os.path.join(out_dir, 'train_land_left.npy'), 'wb') as f:
        np.save(f, land_left)
    return out_dir

# main
# Each epoch: collect a batch of on-policy episodes, do one GCVaR update, and log
# the batch statistics. Checkpoint the curves every eval_intvl epochs.
for epoch in range(epochs):
    # collect on-policy data with the current policy
    batch_returns = []
    n_left = 0
    for _ in range(n_episodes):
        episode_return, landed_left = play_episode(env, agent)
        batch_returns.append(episode_return)
        if landed_left:
            n_left += 1

    # update the policy from the episodes just stored in the agent's buffer
    agent.GCVaR()

    train_return.append(np.mean(batch_returns))
    train_land_left.append(n_left)

    # Periodic checkpoint. No evaluation rollouts are run here to save time;
    # the training returns are already on-policy.
    if (epoch + 1) % eval_intvl == 0:
        save(alpha, pi_temperature, actor_lr, seed)

# Final save, so trailing epochs are written even if epochs is not a multiple of eval_intvl.
save(alpha, pi_temperature, actor_lr, seed)
