import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from buffer import ReplayBuffer

class ActorNet(nn.Module):
    """Deterministic policy MLP: two ReLU hidden layers and a tanh-squashed output.

    Every output element lies in [-1, 1] because of the final ``torch.tanh``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super().__init__()
        # Layers are created in the order l1, l2, out so that seeded
        # initialisation and state_dict keys stay the same.
        self.l1 = nn.Linear(input_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return torch.tanh(self.out(hidden))

class Quantile_QNet(nn.Module):
    """Critic MLP mapping a (state, action) pair to ``quantile_dim`` return quantiles.

    State and action are concatenated on the last axis. The raw head outputs
    are sorted ascending along that axis, so the returned quantile estimates
    are monotonically non-decreasing.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, quantile_dim):
        super().__init__()
        self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, quantile_dim)

    def forward(self, s, a):
        joint = torch.cat([s, a], dim=-1)
        features = F.relu(self.l2(F.relu(self.l1(joint))))
        return torch.sort(self.out(features), dim=-1).values

def soft_update(target, source, tau):
    """Polyak-average ``source``'s parameters into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally via ``.parameters()``; module buffers
    are not touched.
    """
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * (1.0 - tau) + src.data * tau
        dst.data.copy_(blended)

class Policy():
    """DDPG-style agent with a quantile-regression critic and a CVaR actor objective.

    The critic (``q_net``) produces ``quantile_dim`` quantile estimates of the
    return for a (state, action) pair. The actor is trained to maximise the
    mean of the lowest ``cvar_alpha`` fraction of those quantiles (a CVaR
    estimate). Target copies of both networks follow the online networks via
    soft (Polyak) updates with coefficient ``soft_update_coef``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, quantile_dim, actor_lr, qf_lr, cvar_alpha, discount):
        self.actor_net = ActorNet(state_dim, action_dim, hidden_dim)
        self.actor_optimizer = optim.Adam(self.actor_net.parameters(), lr=actor_lr)
        self.target_actor_net = ActorNet(state_dim, action_dim, hidden_dim)
        self.target_actor_net.load_state_dict(self.actor_net.state_dict())

        self.q_net = Quantile_QNet(state_dim, action_dim, hidden_dim, quantile_dim)
        self.q_optimizer = optim.Adam(self.q_net.parameters(), lr=qf_lr)
        self.target_q_net = Quantile_QNet(state_dim, action_dim, hidden_dim, quantile_dim)
        self.target_q_net.load_state_dict(self.q_net.state_dict())

        self.quantile_dim = quantile_dim
        self.cvar_alpha = cvar_alpha
        self.discount = discount

        self.soft_update_coef = 0.005

        # buffer
        self.replay_buffer = ReplayBuffer(state_dim, action_dim)
        #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_action(self, state):
        """Return the online actor's deterministic action for ``state``."""
        mu = self.actor_net(state)
        return mu

    def get_quantiles(self, state, action):
        """Return the online critic's quantile estimates for (state, action)."""
        return self.q_net(state, action)

    def qr_loss(self, curr_v, target_v):
        """Quantile-regression Huber loss (kappa = 1) between current and target quantiles.

        Args:
            curr_v: predicted quantiles, reshaped to ``(batch, quantile_dim)``.
            target_v: target quantile samples, reshaped to ``(batch, quantile_dim)``.

        Returns:
            Scalar loss. The pairwise loss is averaged over current quantiles,
            summed over target samples, and then averaged over the batch.
            Because both axes have length ``quantile_dim``, this equals the
            textbook sum-over-quantiles / mean-over-targets form.
        """
        n = self.quantile_dim
        # After expansion: target_v[b, i, j] = target sample i,
        #                  curr_v[b, i, j]   = current quantile j.
        target_v = target_v.view(-1, n, 1).expand(-1, n, n)
        curr_v = curr_v.view(-1, 1, n).expand(-1, n, n)

        # Quantile midpoints tau_j = (j + 0.5) / n on the last (current-quantile)
        # axis. They are built on curr_v's device/dtype so they do not force a
        # CPU float32 tensor into the computation.
        tau = ((torch.arange(n, dtype=curr_v.dtype, device=curr_v.device) + 0.5) / n).view(1, 1, n)
        error_loss = target_v - curr_v
        # Element-wise Huber loss. reduction="none" is essential: the default
        # "mean" collapses it to a scalar. The (non-differentiable) quantile
        # weights would then be a constant factor, and the critic would
        # regress every output to the mean instead of to its quantile.
        huber_loss = F.smooth_l1_loss(curr_v, target_v, reduction="none")
        # Asymmetric quantile weight |tau_j - 1{u < 0}|.
        value_loss = (tau - (error_loss < 0).to(curr_v.dtype)).abs() * huber_loss
        value_loss = value_loss.mean(dim=2).sum(dim=1).mean()

        return value_loss

    def cvar_value(self, state, action):
        """Batch-mean CVaR estimate from the online critic.

        Averages the lowest ``int(quantile_dim * cvar_alpha)`` quantiles (at
        least one) for each sample, then averages over the batch. The lowest
        quantiles come first because ``Quantile_QNet`` sorts its output
        ascending.
        """
        quantiles = self.q_net(state, action)  # (batch, quantile_dim)
        # int(n * alpha) can be 0 for a small alpha; the mean of an empty
        # slice is NaN, so keep at least the lowest quantile.
        idx = max(1, int(self.quantile_dim * self.cvar_alpha))
        cvar = quantiles[:, :idx].mean(-1)
        return cvar.mean()

    def train(self, bs):
        """Sample ``bs`` transitions and take one critic step and one actor step.

        Both target networks are soft-updated after their respective steps.
        """
        state, action, reward, next_state, not_done = self.replay_buffer.sample(bs)

        curr_quantiles = self.q_net(state, action)
        with torch.no_grad():
            # Distributional Bellman target. Assumes reward and not_done are
            # (batch, 1) so they broadcast against (batch, quantile_dim);
            # TODO confirm against ReplayBuffer.sample.
            next_quantiles = self.target_q_net(next_state, self.target_actor_net(next_state))
            target_quantiles = reward + self.discount * next_quantiles * not_done

        td_loss = self.qr_loss(curr_quantiles, target_quantiles)
        self.q_optimizer.zero_grad()
        td_loss.backward()
        self.q_optimizer.step()
        soft_update(self.target_q_net, self.q_net, self.soft_update_coef)

        # The actor loss backpropagates through q_net as well. Those critic
        # gradients are discarded by q_optimizer.zero_grad() on the next call.
        pred_action = self.actor_net(state)
        actor_loss = - self.cvar_value(state, pred_action)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        soft_update(self.target_actor_net, self.actor_net, self.soft_update_coef)



