import torch
import gym
from gym.envs.registration import register
import numpy as np
import random
import os
from policy import Policy
import sys
sys.path.append('..')

# ---------- environment registration ----------
# Episode horizon in steps; handed to gym as max_episode_steps below.
Cheetah_LEN = 500

# Keyword arguments for gym's registry entry of the custom environment id.
_hcpos_spec = dict(
    id="HCPos-v0",
    entry_point="half_cheetah_pos:HalfCheetahPosEnv",
    max_episode_steps=Cheetah_LEN,
    reward_threshold=None,
    nondeterministic=False,
)
register(**_hcpos_spec)
# ----------------------------------------------

import argparse

# Command-line options as (flag, value type, default) triples.
_CLI_OPTIONS = (
    ('--env_name', str, "HCPos-v0"),
    ('--lr', float, 3e-4),
    ('--alpha', float, 0.2),
    ('--seed', int, 1),
)

parser = argparse.ArgumentParser(description='lr temp alpha seed')
for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)
args = parser.parse_args()

############## setting ##############
# Schedule: one environment step per training-loop iteration.
train_steps = int(1e6)
eval_intvl = int(1e4)        # run an evaluation every this many steps
# Sizes passed to the Policy constructor.
hidden_dim = 256
quantile_dim = 80
# Learning rates: the critic reuses the actor learning rate.
actor_lr = args.lr
qf_lr = actor_lr
discount = 0.99
# alpha selects the quantile index int(alpha * quantile_dim) in the training loop.
alpha = args.alpha
seed = args.seed
buffer_sample_size = 128     # batch size handed to agent.train()

# Separate environments for training and evaluation.
env = gym.make(args.env_name)
eval_env = gym.make(args.env_name)
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]
max_action = float(env.action_space.high[0])
print(args.env_name, ' action bound', max_action)

# Seed every random source so a given --seed reproduces a run.
env.seed(seed)
eval_env.seed(2**31-1-seed)  # evaluation uses a different seed than training
# The action space keeps its own RNG that env.seed() does not touch; the
# random-exploration phase draws from env.action_space.sample(), so it must be
# seeded explicitly.
env.action_space.seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

############### interaction ############
def eval_policy(env, agent, n_episodes, max_ep_length=None):
    """Run evaluation episodes and summarise return and x-position statistics.

    Actions are taken exactly as returned by ``agent.get_action``; no
    exploration noise is added here. A step counts as a violation when the
    reported ``x_position`` is below -3.

    Args:
        env: gym-style environment whose ``reset()`` returns a numpy state and
            whose ``step(action)`` returns ``(next_state, reward, done, info)``
            with ``info['x_position']`` present.
        agent: object with ``get_action(state_tensor)`` returning a tensor
            that supports ``.numpy()``.
        n_episodes: number of episodes to run; must be at least 1.
        max_ep_length: maximum number of steps per episode. Defaults to the
            module-level ``Cheetah_LEN`` when omitted.

    Returns:
        A tuple ``(mean_return, mean_vio_rate, traj_vio_rate, final_pos_lst,
        max_pos_lst, min_pos_lst)`` where

        * ``mean_return`` is the mean episode return over all episodes,
        * ``mean_vio_rate`` is the mean per-episode fraction of violating steps,
        * ``traj_vio_rate`` is the fraction of episodes with at least one
          violating step,
        * the three lists hold, per episode, the final, maximum and minimum
          ``x_position``.

    Raises:
        ValueError: if ``n_episodes`` is smaller than 1.
    """
    if n_episodes < 1:
        raise ValueError("n_episodes must be at least 1")
    if max_ep_length is None:
        max_ep_length = Cheetah_LEN

    ep_return, ep_vio_rate = [], []
    final_pos_lst, max_pos_lst, min_pos_lst = [], [], []
    traj_visit_cnt = 0

    for _ in range(n_episodes):
        ep_ret, ep_len, xpos_vio = 0, 0, 0
        pos_lst = []
        state = env.reset()
        done = False

        while not done:
            with torch.no_grad():
                mu = agent.get_action(torch.from_numpy(state).float())
            state, reward, done, info = env.step(mu.numpy())

            ep_ret += reward
            ep_len += 1
            xpos = info['x_position']
            pos_lst.append(xpos)
            if xpos < -3:
                xpos_vio += 1

            # Enforce the step cap even if the environment does not.
            if ep_len >= max_ep_length:
                done = True

        # xpos still holds the x_position reported by the final step.
        final_pos_lst.append(xpos)
        max_pos_lst.append(np.max(pos_lst))
        min_pos_lst.append(np.min(pos_lst))
        if xpos_vio > 0:
            traj_visit_cnt += 1

        ep_return.append(ep_ret)
        ep_vio_rate.append(1. * xpos_vio / ep_len)

    # Average over ALL episodes (ep_return), not just the last one (ep_ret).
    return (
        np.mean(ep_return),
        np.mean(ep_vio_rate),
        traj_visit_cnt / n_episodes,
        final_pos_lst,
        max_pos_lst,
        min_pos_lst,
    )

########################################################
# Echo the run configuration once at start-up.
print(
    'CVaR', alpha,
    'lr policy', actor_lr,
    'lr q', qf_lr,
    'hidden_dim', hidden_dim,
    'sample size', buffer_sample_size,
    'seed', seed,
)

agent = Policy(
    state_dim,
    action_dim,
    hidden_dim,
    quantile_dim,
    actor_lr,
    qf_lr,
    discount,
)

# Evaluation history: one entry is appended per evaluation round.
eval_return = []
eval_vio_rate = []
eval_traj_vio_rate = []
eval_finalpos_lst = []
eval_maxpos_lst = []
eval_minpos_lst = []

def save(alpha, lr, hidden_dim, seed):
    """Dump the accumulated evaluation records as ``.npy`` files.

    Files are written under
    ``./save/alpha_<alpha>/lr_<lr>/h_dim_<hidden_dim>/seed_<seed>/``; the
    directory is created if it does not exist. The data is read from the
    module-level ``eval_*`` lists, and existing files are overwritten.
    """
    out_dir = ''.join((
        './save/alpha_', str(alpha),
        '/lr_', str(lr),
        '/h_dim_', str(hidden_dim),
        '/seed_', str(seed),
        '/',
    ))
    os.makedirs(out_dir, exist_ok=True)

    # (file name, record) pairs, written in this order.
    records = (
        ('ret.npy', eval_return),
        ('rate.npy', eval_vio_rate),
        ('traj_rate.npy', eval_traj_vio_rate),
        ('final_pos.npy', eval_finalpos_lst),
        ('max_pos.npy', eval_maxpos_lst),
        ('min_pos.npy', eval_minpos_lst),
    )
    for file_name, record in records:
        with open(out_dir + file_name, 'wb') as handle:
            np.save(handle, record)


# Exploration schedule: for the first `random_timesteps` steps actions are
# sampled from the action space; afterwards the agent's action plus Gaussian
# noise (std 0.1), clipped to the action bound, is used.
random_timesteps = int(25e3)


state, done = env.reset(), False
# k: per-episode running value initialised from the agent's quantiles at the
#    first state and updated after every step (presumably the CVaR threshold
#    carried along the trajectory — TODO confirm against Policy).
# pre_reward: reward received on the previous step of the current episode.
# Both are None at the start of an episode.
k, pre_reward = None, None
ep_len = 0

for t_step in range(train_steps):

    # choose action
    if t_step < random_timesteps:
        action = env.action_space.sample()
    else:
        with torch.no_grad():
            mu = agent.get_action(torch.from_numpy(state).float())
            mu = mu.numpy()
        action = (mu + np.random.normal(0, 0.1, size=action_dim)).clip(-max_action, max_action)

    # track k
    if k is None:               # the initial state

        with torch.no_grad():
            # NOTE(review): presumably returns quantile_dim (= 80) values; an
            # earlier "[64]" shape note here did not match — TODO confirm.
            quantiles = agent.get_quantiles(torch.from_numpy(state).float(), torch.from_numpy(action).float())
            # Pick the quantile at fraction alpha of the quantile range.
            idx = int(alpha * quantile_dim)
            k = quantiles[idx]

    else:                       # other state

        # Subtract the reward just collected and undo one step of discounting.
        k = (k - pre_reward) / discount

    # step
    next_state, reward, done, info = env.step(action)

    # add to buffer
    # The transition is stored with the k value that was in effect for `state`.
    k_ = k.numpy()
    agent.replay_buffer.add(state, k_, action, reward, next_state, done)

    # update_status
    ep_len += 1
    pre_reward = reward
    state = next_state

    # train if has enough samples
    if agent.replay_buffer.size > 1000:
        agent.train(buffer_sample_size)

    # eval every "eval_intvl"
    # Also evaluates once right after the very first step (t_step == 0).
    if t_step == 0 or (t_step + 1) % eval_intvl == 0:
        eval_ret, eval_vio, traj_vio, fpos, maxpos, minpos = eval_policy(eval_env, agent, 20)
        eval_return.append(eval_ret)
        eval_vio_rate.append(eval_vio)
        eval_traj_vio_rate.append(traj_vio)
        eval_finalpos_lst.append(fpos)
        eval_maxpos_lst.append(maxpos)
        eval_minpos_lst.append(minpos)

        # Results are rewritten to disk after every evaluation round.
        save(alpha, actor_lr, hidden_dim, seed)

    # Force an episode end at the step cap; this happens after the transition
    # was stored, so the buffer keeps the `done` flag reported by the env.
    if ep_len == Cheetah_LEN:
        done = True

    if done:
        # reset episode
        state, done = env.reset(), False
        k, pre_reward = None, None
        ep_len = 0

        
