import numpy as np
import copy

class Policy:
    """Tabular softmax policy over a 2-D grid, trained by a CVaR policy gradient.

    ``theta[x, y]`` holds one logit per action for grid cell ``(x, y)``; the
    action distribution in that cell is ``softmax(theta[x, y])``.

    Trajectories are collected with :meth:`put_data` and consumed by
    :meth:`gcvar`. Each trajectory is a sequence of steps where
    ``step[0]`` is the state ``(x, y)``, ``step[1]`` the action index and
    ``step[2]`` the scalar reward of that step.
    """

    def __init__(self, n_traj, alpha, grid_shape=(8, 8), n_actions=4):
        """Create a policy with all-zero logits (uniform action choice).

        Args:
            n_traj: number of buffered trajectories used per :meth:`gcvar` update.
            alpha: CVaR tail level, passed to ``np.quantile`` so it must lie
                in ``[0, 1]``.
            grid_shape: ``(width, height)`` of the state grid; defaults to 8x8.
            n_actions: number of discrete actions; defaults to 4.
        """
        self.n_traj = n_traj
        self.alpha = alpha
        self.actions = np.arange(n_actions)

        # One logit per (x, y, action).
        self.theta = np.zeros(tuple(grid_shape) + (n_actions,))
        self.traj_buf = []

    def put_data(self, traj):
        """Append one trajectory to the buffer consumed by :meth:`gcvar`."""
        self.traj_buf.append(traj)

    def get_theta_prob(self, s):
        """Return the softmax action probabilities for state ``s = (x, y)``."""
        x, y = s
        logits = self.theta[x, y]
        # Shift by the max before exponentiating: mathematically identical,
        # but avoids overflow to inf (and NaN probabilities) for large logits.
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / np.sum(exp_logits)

    def get_action(self, s):
        """Sample an action index from the policy's distribution at ``s``."""
        prob = self.get_theta_prob(s)
        action = np.random.choice(self.actions, size=1, p=prob)[0]
        return action

    def eval_action(self, s):
        """Return the most probable (greedy) action index at ``s``."""
        prob = self.get_theta_prob(s)
        return np.argmax(prob)

    def der_theta_logpi(self, s, a):
        """Return d log pi(a | s) / d theta as an array shaped like ``theta``.

        For a tabular softmax only the ``(x, y)`` slice is non-zero and equals
        ``onehot(a) - pi(. | s)``.
        """
        pi_theta = self.get_theta_prob(s)

        x, y = s
        der_theta = np.zeros_like(self.theta)
        der_theta[x, y, a] = 1.
        der_theta[x, y] -= pi_theta
        return der_theta

    def get_quantile(self, ret_lst):
        """Return the empirical ``alpha``-quantile (VaR estimate) of ``ret_lst``."""
        return np.quantile(ret_lst, self.alpha)

    def gcvar(self, lr, gamma):
        """Take one gradient-ascent step on the CVaR of the discounted return.

        Uses the first ``n_traj`` buffered trajectories. With ``q`` the
        empirical ``alpha``-quantile of their returns, the gradient estimate is
        the mean, over the lowest-return trajectories, of
        ``(R - q) * sum_t grad log pi(a_t | s_t)``. This is the likelihood-ratio
        form used by sampling-based CVaR estimators (e.g. Tamar et al., 2015).
        The whole buffer is cleared afterwards.

        Args:
            lr: step size for the update of ``theta``.
            gamma: discount factor applied to per-step rewards.

        Raises:
            IndexError: if fewer than ``n_traj`` trajectories are buffered.
        """
        trajs = [self.traj_buf[i] for i in range(self.n_traj)]

        # Discounted return, accumulated back to front:
        # G = r_0 + gamma * (r_1 + gamma * (...)).
        ret_lst = []
        for traj in trajs:
            ret = 0.
            for step in reversed(traj):
                ret = step[2] + gamma * ret
            ret_lst.append(ret)

        # The alpha-tail is the lowest-return trajectories. Keep at least one
        # so the average below never divides by zero, which would write NaN
        # into theta.
        sort_idx = np.argsort(ret_lst)
        choose_size = max(1, int(self.n_traj * self.alpha))
        quantile_alpha = self.get_quantile(ret_lst)

        der_cvar_theta = np.zeros_like(self.theta)
        for idx in sort_idx[:choose_size]:
            sum_der_theta_logpi = np.zeros_like(self.theta)
            for step in trajs[idx]:
                sum_der_theta_logpi += self.der_theta_logpi(step[0], step[1])
            der_cvar_theta += (ret_lst[idx] - quantile_alpha) * sum_der_theta_logpi
        der_cvar_theta /= choose_size

        # Gradient ascent on CVaR.
        self.theta += lr * der_cvar_theta

        # Clear the buffer; this also drops any trajectories beyond n_traj.
        self.traj_buf = []