import os
import sys
sys.path.append('..')
from policy import Policy
import numpy as np
import random

import gym
from gym.envs.registration import register
import matplotlib.pyplot as plt

import argparse
# Command-line options for one training run. The description lists the
# options that actually exist (--alpha, --lr_p, --seed).
parser = argparse.ArgumentParser(
    description='CVaR alpha, policy learning rate, seed')
parser.add_argument('--alpha', type=float, help='CVaR alpha', default=0.1)
parser.add_argument('--lr_p', type=float, help='lr of policy', default=0.01)
parser.add_argument('--seed', type=int, help='seed', default=1)
args = parser.parse_args()

############### setting #################
# Run configuration. The first group comes from the command line (except
# gamma, which is fixed here); the second group sizes the training loop.
seed = args.seed  # seeds env, eval_env, numpy and random below
gamma = 0.999  # passed to agent.gcvar() at every update
lr_policy = args.lr_p  # passed to agent.gcvar() at every update
alpha = args.alpha  # 'CVaR alpha'; handed to the Policy constructor

train_epochs = 1000  # number of outer training iterations
episodes_num = 50  # episodes collected per epoch (also given to Policy)
eval_intvl = 50  # run one evaluation episode every eval_intvl epochs


# Register the maze environment under a Gym id so gym.make() can build it.
# The entry point resolves to class GuardedMaze in module Maze_Discrete; the
# kwargs below are forwarded to its constructor.
register(
    id='GuardedMazeEnv-v0',
    entry_point='Maze_Discrete:GuardedMaze',
    kwargs=dict(
        mode=1,
        max_steps=100,
        guard_prob=1.0,
        goal_reward=10.,
        stochastic_trans=False,
    )
)
# Separate instances for training and evaluation, each with its own seed.
# NOTE(review): env.seed() and the 4-tuple step() return used in this file
# look like the legacy (pre-0.26) Gym API — confirm the pinned gym version.
env = gym.make('GuardedMazeEnv-v0')
eval_env = gym.make('GuardedMazeEnv-v0')
env.seed(seed=seed)
# Evaluation seed is derived from the training seed (2**31 - 1 - seed).
eval_env.seed(seed=2**31-1-seed)
# Seed the global numpy and stdlib RNGs as well.
np.random.seed(seed)
random.seed(seed)


####################### interaction ################
def play_episode(env, model):
    """Roll out one episode with the model's sampling action rule.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` returning a
            4-tuple ``(next_state, reward, done, info)``, and a
            ``state_traj`` collection of visited cells.
        model: object exposing ``get_action(state)``.

    Returns:
        A 4-tuple ``(traj, long, opt_long, ret)``:
        ``traj`` is a list of ``[state, action, reward]`` entries, one per
        step; ``long`` is True when cell (2, 6) was visited and cell (6, 5)
        was not; ``opt_long`` is True when ``long`` holds and the episode
        took at most 14 steps; ``ret`` is the undiscounted sum of rewards.
    """
    transitions = []
    total_reward = 0

    state = env.reset()
    finished = False
    while not finished:
        action = model.get_action(state)
        next_state, reward, finished, _info = env.step(action)
        transitions.append([state, action, reward])
        total_reward += reward
        state = next_state

    # Classify the route from the cells the environment recorded.
    # NOTE(review): what cells (6, 5) and (2, 6) represent is defined by the
    # maze layout in Maze_Discrete — confirm there.
    visited = env.state_traj
    took_long_route = (6, 5) not in visited and (2, 6) in visited
    within_step_budget = took_long_route and len(transitions) <= 14

    return transitions, took_long_route, within_step_budget, total_reward


def eval_policy(env, model):
    """Run one episode with the model's evaluation action rule.

    Args:
        env: environment exposing ``reset()`` and ``step(action)`` returning
            a 4-tuple ``(next_state, reward, done, info)``.
        model: object exposing ``eval_action(state)``.

    Returns:
        ``(ret, traj_len)``: the undiscounted sum of rewards and the number
        of steps taken before the environment reported ``done``.
    """
    total_reward = 0
    step_count = 0

    state = env.reset()
    finished = False
    while not finished:
        action = model.eval_action(state)
        state, reward, finished, _info = env.step(action)
        step_count += 1
        total_reward += reward

    return total_reward, step_count

print('CVaR', alpha, '| lr', lr_policy, '| seed', seed)
#############################################
# Training driver: for each epoch, collect episodes_num episodes with the
# current policy, hand them to the agent, run one agent.gcvar() update, and
# record per-epoch statistics.
agent = Policy(episodes_num, alpha)
# Per-epoch histories; these module-level lists are what save() writes out.
train_ret_lst, eval_ret_lst = [], [] 
train_long_lst, train_optlong_lst = [], []

for ep_i in range(train_epochs):
    epoch_ret_lst = []  # undiscounted return of each episode in this epoch
    long_cnt = 0  # episodes flagged `long` by play_episode
    opt_long_cnt = 0  # episodes flagged `opt_long` by play_episode
    for _ in range(episodes_num):
        traj, long, opt_long, traj_ret = play_episode(env, agent)
        if long:
            long_cnt += 1
        if opt_long:
            opt_long_cnt += 1
        # Buffer the trajectory inside the agent for the upcoming update.
        agent.put_data(traj)
        epoch_ret_lst.append(traj_ret)

    print('ep', ep_i+1, ' train ret', np.mean(epoch_ret_lst))

    # NOTE(review): presumably the CVaR policy-gradient step over the
    # buffered trajectories — confirm in policy.py.
    agent.gcvar(lr_policy, gamma)

    # Store the epoch's mean return and the fraction of flagged episodes.
    train_ret_lst.append(np.mean(epoch_ret_lst))
    train_long_lst.append(long_cnt / episodes_num)
    train_optlong_lst.append(opt_long_cnt / episodes_num)

    # Every eval_intvl epochs, run a single evaluation episode on eval_env
    # using eval_action; the episode length is discarded.
    if (ep_i+1) % eval_intvl == 0:
        eval_ret, _ = eval_policy(eval_env, agent)
        eval_ret_lst.append(eval_ret)

def save(alpha, lr, seed):
    """Write the training histories to ``./save/alpha_*/lr_*/seed_*/``.

    The output directory is created if missing. Four ``.npy`` files are
    written from the module-level history lists: ``train.npy`` (mean train
    return per epoch), ``eval.npy`` (evaluation returns), ``train_long.npy``
    and ``train_long14.npy`` (per-epoch fractions of ``long`` and
    ``opt_long`` episodes).

    Args:
        alpha: value used in the ``alpha_`` directory component.
        lr: value used in the ``lr_`` directory component.
        seed: value used in the ``seed_`` directory component.
    """
    out_dir = './save/alpha_{}/lr_{}/seed_{}/'.format(alpha, lr, seed)
    os.makedirs(out_dir, exist_ok=True)

    # (file name, history list) pairs, written in this order.
    outputs = (
        ('train.npy', train_ret_lst),
        ('eval.npy', eval_ret_lst),
        ('train_long.npy', train_long_lst),
        ('train_long14.npy', train_optlong_lst),
    )
    for file_name, history in outputs:
        with open(out_dir + file_name, 'wb') as handle:
            np.save(handle, history)

def save_model(alpha, lr, seed):
    """Write the global agent's ``theta`` to ``theta.npy``.

    The file goes into ``./save/alpha_*/lr_*/seed_*/``, which is created if
    it does not exist.

    Args:
        alpha: value used in the ``alpha_`` directory component.
        lr: value used in the ``lr_`` directory component.
        seed: value used in the ``seed_`` directory component.
    """
    out_dir = './save/alpha_{}/lr_{}/seed_{}/'.format(alpha, lr, seed)
    os.makedirs(out_dir, exist_ok=True)

    theta_path = out_dir + 'theta.npy'
    with open(theta_path, 'wb') as handle:
        np.save(handle, agent.theta)
    

# Training is finished: write the recorded histories and the agent's theta
# under the directory keyed by (alpha, lr_policy, seed).
save(alpha, lr_policy, seed)
save_model(alpha, lr_policy, seed)

def plot_train_path(env, pi):
    """Run one episode with ``pi.get_action`` and display the path taken.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` returning a
            4-tuple, and ``show(show_traj=...)``.
        pi: policy object exposing ``get_action(state)``.
    """
    state = env.reset()
    finished = False
    while not finished:
        action = pi.get_action(state)
        state, _reward, finished, _info = env.step(action)

    # Let the environment draw its recorded trajectory, then show the figure.
    env.show(show_traj=True)
    plt.show()

# plot_train_path(eval_env, agent)

def plot_eval_path(env, pi):
    """Run one episode with ``pi.eval_action`` and display the path taken.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` returning a
            4-tuple, and ``show(show_traj=...)``.
        pi: policy object exposing ``eval_action(state)``.
    """
    state = env.reset()
    finished = False
    while not finished:
        action = pi.eval_action(state)
        state, _reward, finished, _info = env.step(action)

    # Let the environment draw its recorded trajectory, then show the figure.
    env.show(show_traj=True)
    plt.show()

# plot_eval_path(eval_env, agent)
    
