import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy

class ActorNet(nn.Module):
    """Two-hidden-layer MLP producing the parameters of a diagonal Gaussian.

    Args:
        input_dim: size of the input feature vector.
        output_dim: number of output (action) dimensions.
        hidden_dim: width of both hidden layers.
        trainable_std: when True, a learnable ``logstd`` parameter of shape
            ``(1, output_dim)`` (initialised to zero) is registered and used
            for the standard deviation; otherwise the std is fixed at 1.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, trainable_std=True):
        super().__init__()
        # Layers are created in this order so seeded initialisation is stable.
        self.l1 = nn.Linear(input_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)

        self.trainable_std = trainable_std
        if trainable_std:
            self.logstd = nn.Parameter(torch.zeros(1, output_dim))

    def forward(self, x):
        """Return ``(mu, std)`` for the input batch ``x``.

        ``mu`` is squashed by tanh and scaled into ``[-3, 3]``. ``std`` is
        ``exp(logstd)`` with shape ``(1, output_dim)`` when the std is
        trainable, and a tensor of ones shaped like ``mu`` otherwise.
        """
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        mu = 3. * torch.tanh(self.out(hidden))

        if not self.trainable_std:
            # exp(0) == 1, so the fixed std is simply a tensor of ones.
            return mu, torch.ones_like(mu)
        return mu, self.logstd.exp()

class Policy():
    """Policy-gradient agent that optimises the lower tail of the return.

    Whole trajectories are collected with :meth:`put_data`;
    :meth:`update_policy` then takes one gradient step using only the
    lowest-return fraction ``alpha`` of the buffered trajectories, each
    weighted by ``(return - alpha-quantile of returns)``, and clears the
    buffer.

    Args:
        state_dim: dimensionality of a state vector.
        action_dim: dimensionality of an action vector.
        hidden_dim: hidden width of the actor network.
        alpha: tail fraction, passed to ``np.quantile`` and used to size
            the set of worst trajectories.
        actor_lr: learning rate of the Adam optimiser.
        discount: discount factor used for trajectory returns.
        n_episodes: maximum number of buffered trajectories used per update.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, alpha, actor_lr, discount, n_episodes):
        self.actor_net = ActorNet(state_dim, action_dim, hidden_dim)
        self.actor_optimizer = optim.Adam(self.actor_net.parameters(), lr=actor_lr)

        self.alpha = alpha
        self.discount = discount
        self.n_episodes = n_episodes

        # Per-trajectory buffers; entry i of each list belongs to trajectory i.
        self.state_buf, self.action_buf, self.reward_buf, self.done_buf = [], [], [], []

    def get_action(self, state):
        """Return the Gaussian parameters ``(mu, sigma)`` for ``state``.

        No action is sampled here; the caller draws from the distribution.
        """
        mu, sigma = self.actor_net(state)
        return mu, sigma

    def put_data(self, state_lst, action_lst, reward_lst, done_lst):
        """Append one complete trajectory (per-step lists) to the buffer."""
        self.state_buf.append(state_lst)
        self.action_buf.append(action_lst)
        self.reward_buf.append(reward_lst)
        self.done_buf.append(done_lst)

    def get_quantile(self, ret_lst):
        """Return the ``alpha``-quantile of the given returns."""
        return np.quantile(ret_lst, self.alpha)

    def _discounted_return(self, reward_lst):
        """Return the discounted sum of ``reward_lst`` from the first step."""
        ret = 0.
        # Accumulate backwards so each step applies the discount exactly once.
        for reward in reversed(reward_lst):
            ret = reward + self.discount * ret
        return ret

    def update_policy(self):
        """Take one gradient step on the tail objective and clear the buffer.

        Uses at most ``n_episodes`` buffered trajectories (the first ones
        stored). If the buffer is empty the call is a no-op. At least one
        trajectory is always selected, so a small ``alpha * n_episodes``
        no longer leaves the loss without any terms.
        """
        # Guard against fewer trajectories being buffered than configured.
        n_traj = min(self.n_episodes, len(self.reward_buf))
        if n_traj == 0:
            return

        # Discounted return of each trajectory.
        ret_lst = np.array(
            [self._discounted_return(self.reward_buf[i]) for i in range(n_traj)]
        )
        # Ascending order: the worst trajectories come first.
        sort_idx = np.argsort(ret_lst)

        # Tail size is alpha * batch size, clamped to [1, n_traj] so that
        # torch.cat below never receives an empty list.
        choose_size = min(n_traj, max(1, int(n_traj * self.alpha)))
        quantile_alpha = self.get_quantile(ret_lst)

        cvar_grad_lst = []
        for idx in sort_idx[:choose_size]:
            state_t = torch.FloatTensor(np.array(self.state_buf[idx]))
            action_t = torch.FloatTensor(np.array(self.action_buf[idx]))

            mu, sigma = self.actor_net(state_t)
            dist = torch.distributions.Normal(mu, sigma)
            # Sum log-probabilities over action dimensions, then over time,
            # giving the log-likelihood of the whole trajectory (shape [1]).
            # NOTE(review): assumes actions are stored with an explicit
            # action dimension so they line up with mu -- confirm at caller.
            log_pi = dist.log_prob(action_t).sum(dim=-1, keepdim=True)
            sum_logpi = log_pi.sum(dim=0)

            # Plain Python float keeps the weight out of the autograd graph
            # and avoids numpy-scalar/tensor arithmetic.
            weight = float(ret_lst[idx] - quantile_alpha)
            cvar_grad_lst.append(weight * sum_logpi)

        cvar_grad = torch.cat(cvar_grad_lst)
        # Negated because the optimiser minimises.
        cvar_loss = -cvar_grad.mean()

        self.actor_optimizer.zero_grad()
        cvar_loss.backward()
        self.actor_optimizer.step()

        # Clear the buffer so the next update only sees fresh trajectories.
        self.state_buf, self.action_buf, self.reward_buf, self.done_buf = [], [], [], []
        