from configparser import ConfigParser
from argparse import ArgumentParser

import torch
import gym
from gym.envs.registration import register
import numpy as np
import os
from sac import SAC
from utils import make_transition, Dict
import sys
sys.path.append('..')

# Episode horizon in environment steps. Used both as the registered
# max_episode_steps below and as the rollout cut-off inside eval_model.
Cheetah_LEN = 500

# Registration settings for the custom "HCPos-v0" environment, whose class
# is HalfCheetahPosEnv in the half_cheetah_pos module.
_hcpos_registration = dict(
    id="HCPos-v0",
    entry_point="half_cheetah_pos:HalfCheetahPosEnv",
    max_episode_steps=Cheetah_LEN,
    reward_threshold=None,
    nondeterministic=False,
)
register(**_hcpos_registration)

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` calls ``bool()`` on the raw string, so every non-empty
    token -- including "False" and "0" -- parses as True. This helper
    accepts the usual spellings of true/false instead and rejects anything
    else with a clear argparse error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    from argparse import ArgumentTypeError
    raise ArgumentTypeError('expected a boolean (true/false), got %r' % (value,))


# Command-line options of the training script.
parser = ArgumentParser('parameters')
parser.add_argument("--algo", type=str, default='sac', help='algorithm to train')
parser.add_argument('--epochs', type=int, default=3000, help='number of epochs, (default: 3000)')
parser.add_argument('--tensorboard', type=_str2bool, default=False, help='use_tensorboard, (default: False)')
# Help text now matches the actual default (False).
parser.add_argument("--use_cuda", type=_str2bool, default=False, help='cuda usage(default : False)')
parser.add_argument("--reward_scaling", type=float, default=1.0, help='reward scaling(default : 1.0)')
parser.add_argument("--seed", type=int, default=1)
# Help text now matches the actual default (30).
parser.add_argument("--eval_interval", type=int, default=30, help='eval interval(default: 30)')
args = parser.parse_args()

# Hyper-parameters of the selected algorithm come from the matching section
# of config.ini. NOTE: ``parser`` is rebound here from the ArgumentParser to
# the ConfigParser; the parsed CLI options live on in ``args``.
parser = ConfigParser()
parser.read('config.ini')
agent_args = Dict(parser, args.algo)

# Run on the GPU only when it is both requested and actually available.
use_gpu = bool(args.use_cuda) and torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'

# TensorBoard is imported lazily so it is only needed when requested.
if args.tensorboard:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()
else:
    writer = None

# Separate environment instances for training and for evaluation.
env = gym.make("HCPos-v0")
eval_env = gym.make("HCPos-v0")
action_dim = env.action_space.shape[0]
state_dim = env.observation_space.shape[0]

# Seed every source of randomness; the evaluation environment gets a
# different seed derived from the training seed.
seed = args.seed
env.seed(seed)
eval_env.seed(2**31-1-seed)
np.random.seed(seed)
torch.manual_seed(seed)

print('env noise_scale', env.noise_scale)

if args.algo == 'sac':
    agent = SAC(writer, device, state_dim, action_dim, agent_args)
else:
    # Fail here with a clear message instead of a NameError on ``agent``
    # further down the script.
    raise ValueError("unsupported --algo %r (only 'sac' is implemented)" % (args.algo,))

if use_gpu:
    agent = agent.cuda()

############ eval function ###########
def eval_model(env, agent, n_episodes, max_episode_length=None,
               torch_device=None, xpos_limit=-3):
    """Roll out ``agent`` in ``env`` and summarise return and position stats.

    Args:
        env: environment with the classic gym API (``reset()`` returning an
            observation array, ``step(a)`` returning
            ``(obs, reward, done, info)``) whose ``info`` dict carries an
            ``'x_position'`` entry.
        agent: object whose ``get_action(state_tensor)`` returns a pair; the
            first element is the action tensor, the second is ignored here.
        n_episodes: number of evaluation episodes; must be positive.
        max_episode_length: hard cap on steps per episode. Defaults to the
            module-level ``Cheetah_LEN``.
        torch_device: device the observation tensor is moved to before it is
            handed to the agent. Defaults to the module-level ``device``.
        xpos_limit: a step counts as a violation when ``x_position`` is
            strictly below this value.

    Returns:
        A 6-tuple ``(mean_return, mean_violation_rate, final_positions,
        max_positions, min_positions, violating_episode_fraction)``. The
        violation rate of one episode is its number of violating steps
        divided by its length; the three position entries are lists with one
        value per episode.

    Raises:
        ValueError: if ``n_episodes`` is not positive.
    """
    if n_episodes <= 0:
        raise ValueError('n_episodes must be positive, got %r' % (n_episodes,))
    # Resolve the module-level fallbacks only when no override was given.
    if max_episode_length is None:
        max_episode_length = Cheetah_LEN
    if torch_device is None:
        torch_device = device

    return_lst, ep_len_lst, xpos_vio_lst = [], [], []
    final_pos_lst, max_pos_lst, min_pos_lst = [], [], []
    traj_vio_cnt = 0

    for _ in range(n_episodes):
        s = env.reset()
        pos_lst = []
        ep_r, total_step, xpos_vio = 0, 0, 0
        done = False
        while not done:
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                a, _ = agent.get_action(torch.from_numpy(s).float().to(torch_device))
                a = a.cpu().numpy()

            s, r, done, info = env.step(a)

            xpos = info['x_position']
            pos_lst.append(xpos)
            if xpos < xpos_limit:
                xpos_vio += 1
            ep_r += r
            total_step += 1

            # Enforce the step cap even if the environment did not terminate.
            if total_step == max_episode_length:
                done = True

        # An episode counts as violating as soon as a single step violated.
        if xpos_vio > 0:
            traj_vio_cnt += 1
        final_pos_lst.append(pos_lst[-1])
        max_pos_lst.append(np.max(pos_lst))
        min_pos_lst.append(np.min(pos_lst))

        return_lst.append(ep_r)
        ep_len_lst.append(total_step)
        xpos_vio_lst.append(xpos_vio)

    # Per-episode violation rate, computed once after all rollouts.
    rate = np.array(xpos_vio_lst) / np.array(ep_len_lst)
    return (np.mean(return_lst), rate.mean(), final_pos_lst, max_pos_lst,
            min_pos_lst, traj_vio_cnt / n_episodes)

#######################################
# Histories of the evaluation metrics; every evaluation round appends one
# entry to each of them, and save() writes them to disk.
eval_ret_lst = []
eval_vio_lst = []
eval_finalpos_lst = []
eval_maxpos_lst = []
eval_minpos_lst = []
eval_traj_vio_lst = []

# Evaluate the initial policy (before any training step) over 20 episodes.
eval_ret, eval_vio, fpos, maxpos, minpos, traj_vio = eval_model(eval_env, agent, 20)
for history, metric in (
    (eval_ret_lst, eval_ret),
    (eval_vio_lst, eval_vio),
    (eval_finalpos_lst, fpos),
    (eval_maxpos_lst, maxpos),
    (eval_minpos_lst, minpos),
    (eval_traj_vio_lst, traj_vio),
):
    history.append(metric)

def save(seed):
    """Write every evaluation history to ``./save_noise50/seed_<seed>/``.

    The target directory is created when missing. Each module-level history
    list is stored in its own ``.npy`` file and overwrites the previous
    version of that file.

    Args:
        seed: run seed; only used to build the output directory name.
    """
    out_dir = './save_noise50/seed_' + str(seed) + '/'
    os.makedirs(out_dir, exist_ok=True)

    # (file name, history) pairs, written in this order.
    histories = (
        ('ret.npy', eval_ret_lst),
        ('vio.npy', eval_vio_lst),
        ('final_pos.npy', eval_finalpos_lst),
        ('max_pos.npy', eval_maxpos_lst),
        ('min_pos.npy', eval_minpos_lst),
        ('traj_vio.npy', eval_traj_vio_lst),
    )
    for file_name, history in histories:
        with open(out_dir + file_name, 'wb') as handle:
            np.save(handle, history)

# NOTE: not read anywhere in the visible script; kept so the module-level
# name stays available.
best_eval_score = -9999999

# Epoch count (1-based) -> file that receives a snapshot of the agent weights.
checkpoint_paths = {
    500: './save_noise50/agent_5b',
    1000: './save_noise50/agent_1k',
}

# Main training loop: one environment episode per epoch.
for ep_i in range(args.epochs):
    score = 0.
    state = env.reset()
    done = False
    while not done:
        # Action selection is pure inference, so skip autograd bookkeeping
        # (same as in eval_model).
        with torch.no_grad():
            action, _ = agent.get_action(torch.from_numpy(state).float().to(device))
        action = action.cpu().numpy()
        next_state, reward, done, info = env.step(action)

        # Only the stored reward is scaled; ``score`` tracks the raw reward.
        transition = make_transition(
            state,
            action,
            np.array([reward*args.reward_scaling]),
            next_state,
            np.array([done]),
        )
        agent.put_data(transition)

        state = next_state
        score += reward

        # Start updating once the buffer index has passed learn_start_size.
        if agent.data.data_idx > agent_args.learn_start_size:
            agent.train_net(agent_args.batch_size, ep_i)

    if writer is not None:
        writer.add_scalar("score/score", score, ep_i)

    # Periodic evaluation; the histories are persisted after every round.
    if (ep_i+1) % args.eval_interval == 0:
        eval_ret, eval_vio, fpos, maxpos, minpos, traj_vio = eval_model(eval_env, agent, 20)
        eval_ret_lst.append(eval_ret)
        eval_vio_lst.append(eval_vio)
        eval_finalpos_lst.append(fpos)
        eval_maxpos_lst.append(maxpos)
        eval_minpos_lst.append(minpos)
        eval_traj_vio_lst.append(traj_vio)

        save(seed)

    checkpoint_path = checkpoint_paths.get(ep_i+1)
    if checkpoint_path is not None:
        # The directory is otherwise only created by save(); with a large
        # --eval_interval it may not exist yet when the first snapshot is due.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(agent.state_dict(), checkpoint_path)

# Make sure pending TensorBoard events reach the disk.
if writer is not None:
    writer.close()

