import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from iql_agent import IQL
from buffer import ReplayBuffer
import copy
import numpy as np

class Actor_p(nn.Module):
    """Gaussian policy head that also emits a state-dependent mixing probability.

    Two ReLU hidden layers feed three outputs:

    * a tanh-squashed action mean (one value per action dimension),
    * a log standard deviation that does not depend on the state
      (a single learnable row of shape ``(1, action_size)``),
    * a sigmoid probability ``p`` (one value per input row).
    """

    def __init__(self, state_size, action_size, hidden_size):
        super().__init__()
        # Layers are created in a fixed order so parameter initialisation
        # consumes the global RNG identically; attribute names are kept
        # because they define the state_dict keys.
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.p = nn.Linear(hidden_size, 1)
        self.mu = nn.Linear(hidden_size, action_size)
        self.log_std = nn.Parameter(torch.zeros(1, action_size))

    def forward(self, state):
        """Run the network on ``state``.

        Returns:
            tuple ``(mu, log_std, p)``: tanh-bounded mean, the shared
            ``log_std`` parameter itself (not a copy), and the sigmoid
            mixing probability.
        """
        hidden = self.fc2(self.fc1(state).relu()).relu()
        mean = self.mu(hidden).tanh()
        mix_prob = self.p(hidden).sigmoid()
        return mean, self.log_std, mix_prob

class Policy_mix:
    """Mixture policy: a CVaR-trained Gaussian actor blended with an IQL actor.

    ``get_action`` samples from the locally trained actor (``self.actor``)
    with probability ``p`` (an output of that actor) and from the IQL agent's
    ``actor_local`` otherwise.  ``gcvar`` updates ``self.actor`` with a
    likelihood-ratio CVaR gradient computed over trajectories collected with
    ``add_traj``; ``train_iql`` updates the IQL agent from ``replay_buffer``.
    """

    # Margin used to clamp the mixing probability before taking logs, so that
    # log(p) and log(1 - p) stay finite (and differentiable) when the sigmoid
    # saturates to exactly 0 or 1 in float32.
    _P_EPS = 1e-6

    def __init__(self,
                state_size,
                action_size,
                hidden_size,
                cvar_learning_rate,
                iql_policy_learning_rate,
                iql_value_learning_rate,
                cvar_alpha,
                iql_tau,
                iql_temperature,
                iql_expectile,
                gamma,
                device):
        self.device = device
        self.iql_agent = IQL(state_size,
                            action_size,
                            iql_policy_learning_rate,
                            iql_value_learning_rate,
                            hidden_size,
                            iql_tau,
                            iql_temperature,
                            iql_expectile,
                            gamma,
                            device)
        self.actor = Actor_p(state_size,
                            action_size,
                            hidden_size).to(device)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=cvar_learning_rate)

        # Fraction of lowest-return episodes used for the CVaR gradient.
        self.cvar_alpha = cvar_alpha
        # Per-episode lists of states / actions / rewards, consumed by gcvar().
        self.state_buf, self.action_buf, self.reward_buf = [], [], []
        self.replay_buffer = ReplayBuffer(state_size, action_size)
        self.gamma = gamma

    def add_traj(self, state_lst, action_lst, reward_lst):
        """Buffer one episode's states, actions and rewards for ``gcvar``."""
        self.state_buf.append(state_lst)
        self.action_buf.append(action_lst)
        self.reward_buf.append(reward_lst)

    def get_action(self, state):
        """Sample one action from the mixture policy.

        A single uniform draw is compared against ``p[0]``. If it is
        smaller, the action comes from ``self.actor``'s Gaussian;
        otherwise it comes from the IQL actor's Gaussian.

        Args:
            state: tensor accepted by both actors (presumably a single
                observation already on ``self.device`` — TODO confirm).

        Returns:
            The sampled action as a numpy array.
        """
        with torch.no_grad():
            mu, log_std, p = self.actor(state)
            # log_std has shape (1, action_size); [0] drops the leading axis so
            # an unbatched state yields an unbatched sample.
            dist = Normal(mu, log_std.exp()[0])

            mu_iql, log_std_iql = self.iql_agent.actor_local(state)
            dist_iql = Normal(mu_iql, log_std_iql.exp())

        rnd = torch.rand(1)
        if rnd[0] < p[0]:
            # sample from pi_1
            action = dist.sample()
        else:
            # sample from iql
            action = dist_iql.sample()
        return action.cpu().numpy()

    @staticmethod
    def _discounted_return(rewards, gamma):
        """Return ``sum_t gamma**t * rewards[t]``, accumulated back to front."""
        ret = 0.
        for reward in reversed(rewards):
            ret = reward + gamma * ret
        return ret

    @classmethod
    def _mixture_log_prob(cls, p, log_prob, log_prob_iql):
        """Return ``log(p * exp(log_prob) + (1 - p) * exp(log_prob_iql))`` stably.

        Evaluating the mixture density directly underflows to 0 (and its log
        to ``-inf``) as soon as both per-row log-densities are very negative.
        Working in log space with ``logsumexp`` avoids that.  ``p`` is clamped
        to ``[_P_EPS, 1 - _P_EPS]`` so neither log term is ``-inf``.  This
        matters because a ``-inf`` term would give a NaN gradient through
        ``log``.

        All three tensors must be broadcast-compatible; the result has their
        broadcast shape.
        """
        p = p.clamp(cls._P_EPS, 1. - cls._P_EPS)
        weighted = torch.stack(
            torch.broadcast_tensors(torch.log(p) + log_prob,
                                    torch.log1p(-p) + log_prob_iql),
            dim=0)
        return torch.logsumexp(weighted, dim=0)

    def _trajectory_log_prob(self, state, action):
        """Return the mixture-policy log-likelihood of one trajectory.

        Only ``self.actor`` receives gradients; the IQL actor is evaluated
        under ``no_grad``.  Assumes ``state`` is (T, state_size) and
        ``action`` is (T, action_size) — TODO confirm against callers.
        The result has shape ``(1,)``.
        """
        mu, log_std, p = self.actor(state)
        dist = Normal(mu, log_std.exp())
        with torch.no_grad():
            mu_iql, log_std_iql = self.iql_agent.actor_local(state)
            dist_iql = Normal(mu_iql, log_std_iql.exp())

        # Per-step joint log-density over action dimensions: (T, 1).
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)
        log_prob_iql = dist_iql.log_prob(action).sum(dim=-1, keepdim=True)
        step_log_prob = self._mixture_log_prob(p, log_prob, log_prob_iql)
        return step_log_prob.sum(dim=0)

    def gcvar(self):
        """Take one CVaR policy-gradient step on ``self.actor``.

        The method computes each buffered episode's discounted return and
        selects the ``int(n_episodes * cvar_alpha)`` lowest-return episodes.
        It then ascends the estimator
        ``mean_i (R_i - q_alpha) * grad log pi(tau_i)`` over those episodes,
        where ``q_alpha`` is the ``cvar_alpha`` quantile of all returns.
        After a successful update the episode buffers are cleared.

        If the tail would be empty (too few episodes for the given
        ``cvar_alpha``, including an empty buffer), the method makes no
        update and keeps the buffers.  No update means the policy is
        unchanged, so the stored trajectories remain on-policy and can be
        reused on the next call.
        """
        n_episodes = len(self.state_buf)
        choose_size = int(n_episodes * self.cvar_alpha)
        if choose_size == 0:
            return

        ret_lst = [self._discounted_return(rewards, self.gamma)
                   for rewards in self.reward_buf]

        # Ascending order: the first `choose_size` indices form the lower tail.
        sort_idx = np.argsort(ret_lst)
        quantile = np.quantile(ret_lst, self.cvar_alpha)

        cvar_grad_lst = []
        for idx in sort_idx[:choose_size]:
            state = torch.as_tensor(np.array(self.state_buf[idx]),
                                    dtype=torch.float32, device=self.device)
            action = torch.as_tensor(np.array(self.action_buf[idx]),
                                     dtype=torch.float32, device=self.device)
            advantage = float(ret_lst[idx] - quantile)
            cvar_grad_lst.append(advantage * self._trajectory_log_prob(state, action))

        cvar_loss = -torch.cat(cvar_grad_lst).mean()

        self.actor_optimizer.zero_grad()
        cvar_loss.backward()
        self.actor_optimizer.step()

        self.state_buf, self.action_buf, self.reward_buf = [], [], []

    def train_iql(self, iql_update, iql_sample_size, batch_size):
        """Run ``iql_update`` passes of IQL learning over replay-buffer mini-batches.

        Each pass iterates every mini-batch yielded by
        ``replay_buffer.make_mini_batch(iql_sample_size, batch_size)``.
        """
        for _ in range(iql_update):
            for states, actions, rewards, next_states, dones in self.replay_buffer.make_mini_batch(iql_sample_size, batch_size):
                self.iql_agent.learn((states, actions, rewards, next_states, dones))
        
    
        


