import numpy as np
import copy
# Print numpy arrays (e.g. the Q-table dump in main) with 3 decimal places.
np.set_printoptions(precision=3)

# Row-major lookup table for a 6x6 grid: state_idx[row][col] == row * 6 + col,
# stored as float64 (same dtype np.zeros would give).
state_idx = np.arange(6 * 6, dtype=np.float64).reshape(6, 6)
# Value the running counter ends at once every cell has been numbered.
idx = int(state_idx.size)

class Qlearning:
    """Tabular Q-value learner for a fixed, deterministic 8x8 grid maze.

    The Q-table is fitted with the known transition model (``transition``):
    each sweep updates every (free cell, action) pair, so no environment
    interaction is needed to train it.
    """

    def __init__(self, gamma):
        """Set up the maze layout, an all-zero Q-table and the action table.

        Args:
            gamma: discount factor applied to the bootstrapped next-state value.
        """
        # 1 marks a wall, 0 a free cell. The border is all walls, so a single
        # step from any interior cell stays inside the 8x8 array.
        self.map = [[1, 1, 1, 1, 1, 1, 1, 1],
                    [1, 0, 0, 0, 0, 0, 0, 1],
                    [1, 0, 1, 1, 1, 1, 0, 1],
                    [1, 0, 0, 0, 0, 1, 0, 1],
                    [1, 0, 0, 0, 0, 1, 0, 1],
                    [1, 0, 0, 0, 0, 1, 0, 1],
                    [1, 0, 0, 0, 0, 0, 0, 1],
                    [1, 1, 1, 1, 1, 1, 1, 1]]
        # Indexed as q_table[row][col][action].
        self.q_table = np.zeros((8, 8, 4))
        # (row delta, col delta) per action: up, right, down, left
        self.ACTION = {0: [-1, 0], 1: [0, 1], 2: [1, 0], 3: [0, -1]}
        self.gamma = gamma
        self.goal = (2, 6)

    def transition(self, s, a):
        """Return ``((row, col), reward)`` for taking action ``a`` in cell ``s``.

        Moving into a wall leaves the agent where it was. Landing on the goal
        pays +10; every other step pays -1.
        """
        d_row, d_col = self.ACTION[a]
        x, y = s
        new_x, new_y = x + d_row, y + d_col

        # next state: walls block movement
        if self.map[new_x][new_y] == 1:
            new_x, new_y = x, y

        # reward
        r = 10 if (new_x, new_y) == self.goal else -1
        return (new_x, new_y), r

    def qlearning_iteration(self, lr=0.1):
        """Run one synchronous sweep of Q-updates over all free, non-goal cells.

        Every target is computed from a snapshot of the table taken at the
        start of the sweep, so the order of updates within a sweep does not
        matter. A step into the goal is terminal: its target is the reward
        alone, with no bootstrapped value.

        Args:
            lr: step size of each update (default 0.1).
        """
        q_snapshot = self.q_table.copy()

        for x in range(1, 7):
            for y in range(1, 7):
                if (x, y) == self.goal or self.map[x][y] == 1:
                    continue
                for a in range(4):
                    (nx, ny), r = self.transition((x, y), a)
                    if (nx, ny) == self.goal:
                        next_q = 0.0
                    else:
                        next_q = np.max(q_snapshot[nx][ny])
                    curr_q = q_snapshot[x][y][a]
                    self.q_table[x][y][a] += lr * (r + self.gamma * next_q - curr_q)

    def qlearning(self, n_iters=200, lr=0.1):
        """Fit the Q-table with ``n_iters`` sweeps of ``qlearning_iteration``.

        Args:
            n_iters: number of sweeps (default 200).
            lr: step size passed to every sweep (default 0.1).
        """
        for _ in range(n_iters):
            self.qlearning_iteration(lr)

    @staticmethod
    def _table_coords(env_state):
        """Convert an env state to (row, col) indices into ``q_table``."""
        # The row is flipped (7 - row); presumably the env counts rows from
        # the opposite edge of the map -- TODO confirm against the env.
        return 7 - env_state[0], env_state[1]

    def q_policy(self, env_state):
        """Return the greedy action (argmax Q) for ``env_state``."""
        x, y = self._table_coords(env_state)
        return np.argmax(self.q_table[x][y])

    def softmax_policy(self, env_state):
        """Sample an action from a Boltzmann (softmax) distribution over Q."""
        x, y = self._table_coords(env_state)
        qs = self.q_table[x][y]
        # Shifting by the max leaves the distribution unchanged but prevents
        # np.exp overflow (inf/inf -> NaN probabilities) for large Q-values.
        exp = np.exp(qs - np.max(qs))
        prob = exp / np.sum(exp)
        actions = [0, 1, 2, 3]
        return np.random.choice(actions, size=1, p=prob)[0]

    def save(self, path='Q_star.npy'):
        """Write the Q-table to ``path`` in .npy format (default 'Q_star.npy')."""
        with open(path, 'wb') as f:
            np.save(f, self.q_table)
                    
# agent = Qlearning(0.999)
# agent.qlearning()
# agent.save()

def main():
    """Train a Qlearning agent and roll out its softmax policy once.

    Registers 'GuardedMazeEnv-v0' (entry point 'Maze_Discrete:GuardedMaze',
    a project module made importable via the parent directory), fits the
    Q-table, prints the per-cell max Q-values, runs one episode and shows
    the trajectory.
    """
    import sys
    sys.path.append('..')
    import gym
    from gym.envs.registration import register
    import matplotlib.pyplot as plt

    env_kwargs = dict(
        mode=1,
        max_steps=100,
        guard_prob=1.0,
        goal_reward=10.,
        stochastic_trans=False,
    )
    register(
        id='GuardedMazeEnv-v0',
        entry_point='Maze_Discrete:GuardedMaze',
        kwargs=env_kwargs,
    )
    env = gym.make('GuardedMazeEnv-v0')
    env.seed(1)

    agent = Qlearning(0.999)
    agent.qlearning()
    print(np.max(agent.q_table, axis=-1))

    # One episode; env.step is unpacked as the 4-tuple
    # (obs, reward, done, info).
    state = env.reset()
    done = False
    while not done:
        action = agent.softmax_policy(state)
        state, _, done, _ = env.step(action)

    env.show(show_traj=True)
    plt.show()


class Policy:
    """Stochastic grid policy mixing a learnable softmax with a fixed Q* guide.

    For a cell s = (x, y) the action distribution is

        pi(a|s) = w(s) * softmax(theta[x, y])[a]
                  + (1 - w(s)) * softmax(3 * Q*[x, y])[a],

    where w(s) = sigmoid(p[x, y]). ``gcvar`` updates both ``theta`` and ``p``
    with a likelihood-ratio gradient computed only from the worst
    ``alpha`` fraction of buffered trajectories (a CVaR-style objective).
    """

    def __init__(self, n_traj, alpha):
        """
        Args:
            n_traj: number of buffered trajectories consumed per ``gcvar`` call.
            alpha: tail fraction / quantile level used by ``gcvar``.
        """
        self.n_traj = n_traj
        self.alpha = alpha
        self.actions = np.arange(4)

        # Per-cell action logits of the learnable softmax component.
        self.theta = np.zeros((8, 8, 4))
        # Per-cell pre-sigmoid mixing weight (w = sigmoid(p)).
        self.p = np.zeros((8, 8))

        # Trajectories collected since the last update; each step is indexed
        # as step[0] = state (x, y), step[1] = action, step[2] = reward.
        self.traj_buf = []

        # Fixed Q-table of shape (8, 8, 4), e.g. the file Qlearning.save writes.
        self.Q_star = np.load('Q_star.npy')

    def put_data(self, traj):
        """Append one trajectory (a list of (s, a, r, ...) steps) to the buffer."""
        self.traj_buf.append(traj)

    def get_theta_prob(self, s):
        """Return softmax(theta[x, y]) for cell ``s``."""
        x, y = s
        theta = self.theta[x][y]
        # Max-shift keeps np.exp finite as theta grows under gradient ascent;
        # without it large logits give inf/inf = NaN probabilities.
        exp_theta = np.exp(theta - np.max(theta))
        return exp_theta / np.sum(exp_theta)

    def get_optimal_prob(self, s):
        """Return the Boltzmann distribution softmax(3 * Q*[x, y]) for ``s``."""
        x, y = s
        q_star = self.Q_star[x][y]

        logit = q_star - np.max(q_star)

        exp_q_star = np.exp(logit * 3)
        prob_optimal = exp_q_star / np.sum(exp_q_star)
        return prob_optimal

    def get_weight(self, s):
        """Return the mixing weight w(s) = sigmoid(p[x, y]) in [0, 1]."""
        x, y = s
        p = self.p[x][y]
        return 1 / (1 + np.exp(-p))

    def get_mix_prob(self, s):
        """Return w * softmax(theta) + (1 - w) * softmax(3 * Q*) for ``s``."""
        prob_theta = self.get_theta_prob(s)
        prob_optimal = self.get_optimal_prob(s)
        weight = self.get_weight(s)

        return weight * prob_theta + (1 - weight) * prob_optimal

    def get_action(self, s, sample_twice=False):
        """Sample an action for ``s``.

        With ``sample_twice`` the component is picked first (theta with
        probability w, else Q*), then an action from that component alone;
        otherwise the action is drawn directly from the mixed distribution.
        """
        if sample_twice:
            weight = self.get_weight(s)

            if np.random.rand() < weight:
                prob = self.get_theta_prob(s)
            else:
                prob = self.get_optimal_prob(s)
        else:
            prob = self.get_mix_prob(s)
        return np.random.choice(self.actions, size=1, p=prob)[0]

    def eval_action(self, s):
        """Return the most probable action under the mixed distribution."""
        return np.argmax(self.get_mix_prob(s))

    def der_theta_logpi(self, s, a):
        """Gradient of log pi(a|s) w.r.t. the full ``theta`` array (8, 8, 4).

        Only theta[x, y] is non-zero:
        (w * pi_theta(a) / pi(a)) * (onehot(a) - pi_theta).
        """
        pi_a = self.get_mix_prob(s)[a]
        weight = self.get_weight(s)
        pi_theta = self.get_theta_prob(s)
        pi_theta_a = pi_theta[a]

        x, y = s
        der_theta = np.zeros((8, 8, 4))
        der_theta[x][y][a] = 1.
        der_theta[x][y] -= pi_theta

        der_theta *= 1 / pi_a * weight * pi_theta_a
        return der_theta

    def der_p_logpi(self, s, a):
        """Gradient of log pi(a|s) w.r.t. the full ``p`` array (8, 8).

        Only p[x, y] is non-zero:
        (pi_theta(a) - pi_opt(a)) * w * (1 - w) / pi(a).
        """
        pi_a = self.get_mix_prob(s)[a]
        pi_theta_a = self.get_theta_prob(s)[a]
        pi_optimal_a = self.get_optimal_prob(s)[a]
        weight = self.get_weight(s)

        x, y = s
        der_p = np.zeros((8, 8))
        der_p[x][y] = 1.

        der_p *= 1 / pi_a * (pi_theta_a - pi_optimal_a) * weight * (1 - weight)
        return der_p

    def get_quantile(self, ret_lst):
        """Return the empirical ``alpha``-quantile of the returns."""
        return np.quantile(ret_lst, self.alpha)

    def gcvar(self, lr, gamma):
        """Do one CVaR gradient-ascent step on ``theta`` and ``p``, then clear the buffer.

        Uses the first ``n_traj`` buffered trajectories. The
        ``int(n_traj * alpha)`` lowest discounted returns (at least one) each
        contribute (R - q_alpha) * sum_t grad log pi(a_t|s_t), averaged.

        Args:
            lr: step size of the parameter update.
            gamma: discount factor for computing each trajectory's return.

        Raises:
            ValueError: if fewer than ``n_traj`` trajectories are buffered.
        """
        if len(self.traj_buf) < self.n_traj:
            raise ValueError(
                f'gcvar needs {self.n_traj} buffered trajectories, '
                f'got {len(self.traj_buf)}')
        trajs = self.traj_buf[:self.n_traj]

        # Discounted return of each trajectory, accumulated backwards in time.
        ret_lst = []
        for traj in trajs:
            ret = 0.
            for step in reversed(traj):
                ret = step[2] + gamma * ret
            ret_lst.append(ret)

        print('discounted ret')
        print(ret_lst)

        # Lower tail. Keep at least one trajectory: int(n_traj * alpha) can be
        # 0, and dividing the accumulated gradient by 0 would write NaN into
        # theta and p.
        sort_idx = np.argsort(ret_lst)
        choose_size = max(1, int(self.n_traj * self.alpha))
        quantile_alpha = self.get_quantile(ret_lst)

        der_cvar_theta = np.zeros((8, 8, 4))
        der_cvar_p = np.zeros((8, 8))

        for i in sort_idx[:choose_size]:
            sum_der_theta_logpi = np.zeros((8, 8, 4))
            sum_der_p_logpi = np.zeros((8, 8))

            for step in trajs[i]:
                s, a = step[0], step[1]
                sum_der_theta_logpi += self.der_theta_logpi(s, a)
                sum_der_p_logpi += self.der_p_logpi(s, a)

            advantage = ret_lst[i] - quantile_alpha
            der_cvar_theta += advantage * sum_der_theta_logpi
            der_cvar_p += advantage * sum_der_p_logpi

        der_cvar_theta /= choose_size
        der_cvar_p /= choose_size

        # update policy
        self.theta += lr * der_cvar_theta
        self.p += lr * der_cvar_p

        # clean buffer
        self.traj_buf = []
    
    
    
