from argparse import ArgumentParser
import random
import torch
import numpy as np
import gym
from gym.envs.registration import register
from policy_mix import Policy_mix
import os
import sys
sys.path.append('..')

# Episode horizon in environment steps.  It is passed to gym as the step limit
# of the registered environment and also bounds the rollout loop in
# play_episodes.
Pendulum_LEN = 300

# Registration keywords for the custom environment, so that
# gym.make("IvPos-v0") can construct it from the given entry point.
_IVPOS_REGISTRATION = {
    "id": "IvPos-v0",
    "entry_point": "inverted_pendulum_pos:InvertedPendulumPosEnv",
    "max_episode_steps": Pendulum_LEN,
    "reward_threshold": None,
    "nondeterministic": False,
}
register(**_IVPOS_REGISTRATION)

parser = ArgumentParser('parameters')

# Command-line options as (flag, add_argument keywords).  They are registered
# in exactly this order, which is also the order shown by --help.  Entries
# without a "help" key deliberately leave argparse's default (no help text).
_CLI_OPTIONS = (
    ("--epochs", {"type": int, "default": 6000, "help": "Number of train epochs"}),
    ("--n_episodes", {"type": int, "default": 30, "help": "Number of episodes per epoch"}),
    ("--batch_size", {"type": int, "default": 256, "help": "Batch size, default: 256"}),
    ("--hidden_size", {"type": int, "default": 128, "help": ""}),
    ("--cvar_lr", {"type": float, "default": 3e-4, "help": ""}),
    ("--iql_policy_lr", {"type": float, "default": 1e-4, "help": ""}),
    ("--iql_value_lr", {"type": float, "default": 1e-4, "help": ""}),
    ("--temperature", {"type": float, "default": 2, "help": ""}),
    ("--expectile", {"type": float, "default": 0.9, "help": ""}),
    ("--tau", {"type": float, "default": 5e-3, "help": ""}),
    ("--cvar_alpha", {"type": float, "default": 0.2, "help": ""}),
    ("--gamma", {"type": float, "default": 0.999}),
    ("--iql_intvl", {"type": int, "default": 50, "help": "iql update frequency"}),
    ("--iql_update", {"type": int, "default": 3, "help": "iql update round"}),
    ("--iql_sample_size", {"type": int, "default": 100000, "help": "iql update sample size"}),
    ("--seed", {"type": int, "default": 1}),
)
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Parsed from the process command line at import time.
args = parser.parse_args()

####################################################
# Build the environment registered under "IvPos-v0" and read the sizes the
# agent needs from its observation and action spaces.
env = gym.make("IvPos-v0")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# Only the upper bound of the first action component is used.
# NOTE(review): presumably the action bounds are the same for every component
# and symmetric around zero -- confirm against the environment definition.
max_action = float(env.action_space.high[0])
print('env action bound', max_action)

# Seed the environment and the numpy / torch / stdlib random generators with
# the same value.  This happens before the agent is constructed below.
seed = args.seed
env.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

# Everything runs on the CPU; `device` is also read by play_episodes.
device = torch.device('cpu')
# Arguments are passed positionally, so their order must match the
# Policy_mix constructor (defined in policy_mix, not visible in this file).
agent = Policy_mix(state_dim,
                    action_dim,
                    max_action,
                    args.hidden_size,
                    args.cvar_lr,
                    args.iql_policy_lr,
                    args.iql_value_lr,
                    args.cvar_alpha,
                    args.tau,
                    args.temperature,
                    args.expectile,
                    args.gamma,
                    device)

# Echo the main hyperparameters so they show up in the run's log output.
print('CVaR', args.cvar_alpha, '| lr', args.cvar_lr, '| hidden', args.hidden_size)
print('IQL temp', args.temperature, ', expectile', args.expectile, ', batch_size', args.batch_size, ', lrp', args.iql_policy_lr, ', lrv', args.iql_value_lr)
###########################################
def play_episodes(env, agent, n_episodes):
    """Roll out ``n_episodes`` episodes with ``agent`` and summarise them.

    Every transition is added to ``agent.replay_buffer`` and each finished
    episode is handed to ``agent.add_traj`` as parallel lists of states,
    actions and rewards.  A step counts as a violation when the
    ``"x_position"`` entry of that step's ``info`` dict exceeds 0.04.

    The module-level ``Pendulum_LEN`` caps the number of steps per episode
    and the module-level ``device`` is where the state tensor is placed
    before it is passed to ``agent.get_action``.

    Args:
        env: Environment with ``reset()`` returning a numpy state and
            ``step(action)`` returning ``(next_state, reward, done, info)``.
        agent: Object providing ``get_action(state_tensor)``,
            ``replay_buffer.add(...)`` and ``add_traj(...)``.
        n_episodes: Number of episodes to run.

    Returns:
        A tuple ``(mean_return, mean_violation_rate, traj_violation_rate)``:
        the mean undiscounted episode return, the mean per-episode fraction
        of violating steps, and the fraction of episodes that contained at
        least one violating step.
    """
    total_r_lst, vio_rate_lst = [], []
    traj_visit_cnt = 0

    for _ in range(n_episodes):
        state = env.reset()
        total_r, ep_len, vio_cnt = 0, 0, 0
        # Must start False for every episode: the flag is only raised by an
        # actual violation.  Starting it at True marked every trajectory as
        # violating and pinned the trajectory-level rate at 1.0.
        visit_noise = False
        state_lst, action_lst, reward_lst = [], [], []

        for _step in range(Pendulum_LEN):
            action = agent.get_action(torch.from_numpy(state).float().to(device))

            next_state, reward, done, info = env.step(action)

            # Record the transition for both consumers on the agent side.
            agent.replay_buffer.add(state, action, reward, next_state, done)
            state_lst.append(state)
            action_lst.append(action)
            reward_lst.append(reward)

            # Update the per-episode statistics.
            total_r += reward
            ep_len += 1
            if info["x_position"] > 0.04:
                vio_cnt += 1
                visit_noise = True
            state = next_state

            if done:
                break

        # Reached both on early termination and when the step cap is hit.
        if visit_noise:
            traj_visit_cnt += 1

        agent.add_traj(state_lst, action_lst, reward_lst)
        total_r_lst.append(total_r)
        vio_rate_lst.append(vio_cnt / ep_len)

    return np.mean(total_r_lst), np.mean(vio_rate_lst), traj_visit_cnt / n_episodes
##########################################################

# Per-epoch training curves: train() appends to them and save() writes them
# to disk.  Each list is a distinct object.
train_ret_lst = []
train_vio_lst = []
train_traj_vio_lst = []

def save(alpha, cvar_lr, iql_lrp, iql_lrv, iql_intvl, iql_update, iql_sample_size, seed):
    """Write the recorded training curves into a hyperparameter-keyed folder.

    The output directory lives under ``./save/`` (relative to the current
    working directory) and has one nested level per argument, each named
    ``<prefix><str(value)>``.  The directory is created if needed and the
    module-level lists ``train_ret_lst``, ``train_vio_lst`` and
    ``train_traj_vio_lst`` are stored there as ``ret.npy``, ``rate.npy`` and
    ``traj_rate.npy``, overwriting any existing files.
    """
    # One directory level per hyperparameter, outermost first.
    levels = (
        ('alpha_', alpha),
        ('cvar_lr_', cvar_lr),
        ('iql_lrp_', iql_lrp),
        ('iql_lrv_', iql_lrv),
        ('iql_intvl_', iql_intvl),
        ('iql_update_', iql_update),
        ('iql_sample_', iql_sample_size),
        ('seed_', seed),
    )
    out_dir = './save/' + ''.join(prefix + str(value) + '/' for prefix, value in levels)
    os.makedirs(out_dir, exist_ok=True)

    # File name -> curve to store under that name.
    curves = (
        ('ret.npy', train_ret_lst),
        ('rate.npy', train_vio_lst),
        ('traj_rate.npy', train_traj_vio_lst),
    )
    for file_name, values in curves:
        with open(out_dir + file_name, 'wb') as handle:
            np.save(handle, values)

def train(config, env, agent):
    """Run the training loop and periodically persist the training curves.

    Each epoch collects ``config.n_episodes`` episodes with
    ``play_episodes``, appends the three returned statistics to the
    module-level ``train_ret_lst`` / ``train_vio_lst`` /
    ``train_traj_vio_lst`` lists and calls ``agent.gcvar()``.  Every
    ``config.iql_intvl`` epochs it additionally calls ``agent.train_iql`` and
    writes the curves to disk via ``save``.  If the last epoch does not fall
    on such a boundary, the curves are written once more after the loop so
    the trailing epochs are not lost.

    Args:
        config: Namespace-like object providing ``epochs``, ``n_episodes``,
            ``iql_intvl``, ``iql_update``, ``iql_sample_size``,
            ``batch_size``, ``cvar_alpha``, ``cvar_lr``, ``iql_policy_lr``
            and ``iql_value_lr``.
        env: Environment forwarded to ``play_episodes``.
        agent: Agent forwarded to ``play_episodes``; must also provide
            ``gcvar()`` and ``train_iql(...)``.
    """

    def _write_curves():
        # The seed comes from the module-level `seed`, i.e. the value the
        # environment and random generators were actually seeded with.
        save(config.cvar_alpha,
             config.cvar_lr,
             config.iql_policy_lr,
             config.iql_value_lr,
             config.iql_intvl,
             config.iql_update,
             config.iql_sample_size,
             seed)

    # True while the curves hold epochs that have not been written to disk.
    unsaved = False

    for ep_i in range(config.epochs):
        ret, vio_rate, traj_rate = play_episodes(env, agent, config.n_episodes)
        train_ret_lst.append(ret)
        train_vio_lst.append(vio_rate)
        train_traj_vio_lst.append(traj_rate)
        agent.gcvar()
        unsaved = True

        if (ep_i + 1) % config.iql_intvl == 0:
            agent.train_iql(config.iql_update, config.iql_sample_size, config.batch_size)
            _write_curves()
            unsaved = False

    # Flush the tail when `epochs` is not a multiple of `iql_intvl`.
    if unsaved:
        _write_curves()

# Script entry point: run the training loop with the parsed command-line
# options, the environment and the agent created above.
train(args, env, agent)

