"""
The full encoder-decoder model, built on top of the base seq2seq modules.
"""

import logging
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

import stanza.models.common.seq2seq_constant as constant
from stanza.models.common import utils
from stanza.models.common.seq2seq_modules import LSTMAttention
from stanza.models.common.beam import Beam
from stanza.models.common.seq2seq_constant import UNK_ID

logger = logging.getLogger('stanza')

class Seq2SeqModel(nn.Module):
    """
    A complete encoder-decoder model, with optional attention.

    A parent class which makes use of the contextual_embedding (such as a charlm)
    can make use of unsaved_modules when saving.
    """
    def __init__(self, args, emb_matrix=None, contextual_embedding=None):
        """
        Build the embedding, Bi-LSTM encoder, attentional decoder and output layers.

        args: dict of hyperparameters.  Keys read without a default are
          'vocab_size', 'emb_dim', 'hidden_dim', 'num_layers', 'dropout',
          'max_dec_len' and 'attn_type'.  Keys read with a default are
          'emb_dropout' (0.0), 'top' (1e10), 'pos' (False), 'pos_dim' (0),
          'pos_vocab_size' (0), 'pos_dropout' (0), 'edit' (False),
          'num_edit' (0) and 'copy' (False).
        emb_matrix: optional pretrained embedding matrix; init_weights copies
          it into the embedding layer and checks it is vocab_size x emb_dim.
        contextual_embedding: optional module attached via add_unsaved_module.
          When present, its hidden_dim() is added to the encoder input size.
        """
        super().__init__()

        # names of attributes registered through add_unsaved_module
        self.unsaved_modules = []

        self.vocab_size = args['vocab_size']
        self.emb_dim = args['emb_dim']
        self.hidden_dim = args['hidden_dim']
        self.nlayers = args['num_layers'] # encoder layers, decoder layers = 1
        self.emb_dropout = args.get('emb_dropout', 0.0)
        self.dropout = args['dropout']
        self.pad_token = constant.PAD_ID
        self.max_dec_len = args['max_dec_len']
        # compared against vocab_size in init_weights to decide how much of
        # the embedding is finetuned
        self.top = args.get('top', 1e10)
        self.args = args
        self.emb_matrix = emb_matrix
        self.add_unsaved_module("contextual_embedding", contextual_embedding)

        logger.debug("Building an attentional Seq2Seq model...")
        logger.debug("Using a Bi-LSTM encoder")
        self.num_directions = 2
        # each encoder direction gets half of hidden_dim; encode() concatenates
        # the two directions before handing the state to the decoder
        self.enc_hidden_dim = self.hidden_dim // 2
        self.dec_hidden_dim = self.hidden_dim

        self.use_pos = args.get('pos', False)
        self.pos_dim = args.get('pos_dim', 0)
        self.pos_vocab_size = args.get('pos_vocab_size', 0)
        self.pos_dropout = args.get('pos_dropout', 0)
        self.edit = args.get('edit', False)
        self.num_edit = args.get('num_edit', 0)
        self.copy = args.get('copy', False)

        self.emb_drop = nn.Dropout(self.emb_dropout)
        self.drop = nn.Dropout(self.dropout)
        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
        self.input_dim = self.emb_dim
        if self.contextual_embedding is not None:
            self.input_dim += self.contextual_embedding.hidden_dim()
        # inter-layer dropout is only passed to the LSTM when there is more than one layer
        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
                bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)
        # the decoder consumes plain token embeddings (emb_dim), not input_dim
        self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \
                batch_first=True, attn_type=self.args['attn_type'])
        self.dec2vocab = nn.Linear(self.dec_hidden_dim, self.vocab_size)
        # NOTE(review): pos_embedding is only created when pos_dim > 0, but
        # init_weights() and embed() access it whenever use_pos is set, so
        # 'pos' with the default pos_dim of 0 fails with an AttributeError
        if self.use_pos and self.pos_dim > 0:
            logger.debug("Using POS in encoder")
            self.pos_embedding = nn.Embedding(self.pos_vocab_size, self.pos_dim, self.pad_token)
            self.pos_drop = nn.Dropout(self.pos_dropout)
        if self.edit:
            # small MLP over the final encoder state, producing num_edit logits
            edit_hidden = self.hidden_dim//2
            self.edit_clf = nn.Sequential(
                    nn.Linear(self.hidden_dim, edit_hidden),
                    nn.ReLU(),
                    nn.Linear(edit_hidden, self.num_edit))

        if self.copy:
            # scalar gate per decoder step: copy from the source vs. generate
            self.copy_gate = nn.Linear(self.dec_hidden_dim, 1)

        # registered as a buffer so it follows the module across devices;
        # its .device is also used elsewhere to place newly created tensors
        SOS_tensor = torch.LongTensor([constant.SOS_ID])
        self.register_buffer('SOS_tensor', SOS_tensor)

        self.init_weights()

    def add_unsaved_module(self, name, module):
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def init_weights(self):
        """
        Initialize the embedding layers and decide how the token embedding is finetuned.

        The token embedding is copied from emb_matrix when one was given
        (a numpy array is first converted to a tensor), otherwise it is
        drawn uniformly from [-EMB_INIT_RANGE, EMB_INIT_RANGE].  The value
        of self.top then selects between freezing the embedding, hooking
        its gradient through utils.keep_partial_grad, or leaving it fully
        trainable.  The POS embedding, if used, is also drawn uniformly.
        """
        init_range = constant.EMB_INIT_RANGE
        emb_weight = self.embedding.weight

        # token embeddings: random unless a pretrained matrix was supplied
        if self.emb_matrix is None:
            emb_weight.data.uniform_(-init_range, init_range)
        else:
            if isinstance(self.emb_matrix, np.ndarray):
                self.emb_matrix = torch.from_numpy(self.emb_matrix)
            expected_size = (self.vocab_size, self.emb_dim)
            assert self.emb_matrix.size() == expected_size, \
                    "Input embedding matrix must match size: {} x {}".format(self.vocab_size, self.emb_dim)
            emb_weight.data.copy_(self.emb_matrix)

        # finetuning policy for the token embeddings
        if self.top <= 0:
            logger.debug("Do not finetune embedding layer.")
            emb_weight.requires_grad = False
        elif self.top < self.vocab_size:
            logger.debug("Finetune top {} embeddings.".format(self.top))
            emb_weight.register_hook(lambda grad: utils.keep_partial_grad(grad, self.top))
        else:
            logger.debug("Finetune all embeddings.")

        # POS embeddings are always randomly initialized
        if self.use_pos:
            self.pos_embedding.weight.data.uniform_(-init_range, init_range)

    def zero_state(self, inputs):
        batch_size = inputs.size(0)
        device = self.SOS_tensor.device
        h0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
        c0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
        return h0, c0

    def encode(self, enc_inputs, lens):
        """ Encode source sequence. """
        h0, c0 = self.zero_state(enc_inputs)

        packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
        packed_h_in, (hn, cn) = self.encoder(packed_inputs, (h0, c0))
        h_in, _ = nn.utils.rnn.pad_packed_sequence(packed_h_in, batch_first=True)
        hn = torch.cat((hn[-1], hn[-2]), 1)
        cn = torch.cat((cn[-1], cn[-2]), 1)
        return h_in, (hn, cn)

    def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None, never_decode_unk=False):
        """ Decode a step, based on context encoding and source context states."""
        dec_hidden = (hn, cn)
        decoder_output = self.decoder(dec_inputs, dec_hidden, ctx, ctx_mask, return_logattn=self.copy)
        if self.copy:
            h_out, dec_hidden, log_attn = decoder_output
        else:
            h_out, dec_hidden = decoder_output

        h_out_reshape = h_out.contiguous().view(h_out.size(0) * h_out.size(1), -1)
        decoder_logits = self.dec2vocab(h_out_reshape)
        decoder_logits = decoder_logits.view(h_out.size(0), h_out.size(1), -1)
        log_probs = self.get_log_prob(decoder_logits)

        if self.copy:
            copy_logit = self.copy_gate(h_out)
            if self.use_pos:
                # can't copy the UPOS
                log_attn = log_attn[:, :, 1:]

            # renormalize
            log_attn = torch.log_softmax(log_attn, -1)
            # calculate copy probability for each word in the vocab
            log_copy_prob = torch.nn.functional.logsigmoid(copy_logit) + log_attn
            # scatter logsumexp
            mx = log_copy_prob.max(-1, keepdim=True)[0]
            log_copy_prob = log_copy_prob - mx
            # here we make space in the log probs for vocab items
            # which might be copied from the encoder side, but which
            # were not known at training time
            # note that such an item cannot possibly be predicted by
            # the model as a raw output token
            # however, the copy gate might score high on copying a
            # previously unknown vocab item
            copy_prob = torch.exp(log_copy_prob)
            copied_vocab_shape = list(log_probs.size())
            if torch.max(src) >= copied_vocab_shape[-1]:
                copied_vocab_shape[-1] = torch.max(src) + 1
            copied_vocab_prob = log_probs.new_zeros(copied_vocab_shape)
            scattered_copy = src.unsqueeze(1).expand(src.size(0), copy_prob.size(1), src.size(1))
            # fill in the copy tensor with the copy probs of each character
            # the rest of the copy tensor will be filled with -largenumber
            copied_vocab_prob = copied_vocab_prob.scatter_add(-1, scattered_copy, copy_prob)
            zero_mask = (copied_vocab_prob == 0)
            log_copied_vocab_prob = torch.log(copied_vocab_prob.masked_fill(zero_mask, 1e-12)) + mx
            log_copied_vocab_prob = log_copied_vocab_prob.masked_fill(zero_mask, -1e12)

            # combine with normal vocab probability
            log_nocopy_prob = -torch.log(1 + torch.exp(copy_logit))
            if log_probs.shape[-1] < copied_vocab_shape[-1]:
                # for previously unknown vocab items which are in the encoder,
                # we reuse the UNK_ID prediction
                # this gives a baseline number which we can combine with
                # the copy gate prediction
                # technically this makes log_probs no longer represent
                # a probability distribution when looking at unknown vocab
                # this is probably not a serious problem
                # an example of this usage is in the Lemmatizer, such as a
                # plural word in English with the character "ã" in it instead of "a"
                # if "ã" is not known in the training data, the lemmatizer would
                # ordinarily be unable to output it, and thus the seq2seq model
                # would have no chance to depluralize "ãntennae" -> "ãntenna"
                # however, if we temporarily add "ã" to the encoder vocab,
                # then let the copy gate accept that letter, we find the Lemmatizer
                # seq2seq model will want to copy that particular vocab item
                # this allows the Lemmatizer to produce "ã" instead of requiring
                # that it produces UNK, then going back to the input text to
                # figure out which UNK it intended to produce
                new_log_probs = log_probs.new_zeros(copied_vocab_shape)
                new_log_probs[:, :, :log_probs.shape[-1]] = log_probs
                new_log_probs[:, :, log_probs.shape[-1]:] = new_log_probs[:, :, UNK_ID].unsqueeze(2)
                log_probs = new_log_probs
            log_probs = log_probs + log_nocopy_prob
            log_probs = torch.logsumexp(torch.stack([log_copied_vocab_prob, log_probs]), 0)

        if never_decode_unk:
            log_probs[:, :, UNK_ID] = float("-inf")
        return log_probs, dec_hidden

    def embed(self, src, src_mask, pos, raw):
        """
        Turn source ids (plus optional POS and raw text) into encoder inputs.

        Returns (enc_inputs, batch_size, src_lens, src_mask).  When POS is
        used, enc_inputs gains one leading step holding the POS embedding and
        src_mask gains a matching leading column of zeros.  src_lens is a
        list with, per row, the number of src_mask entries equal to PAD_ID.
        """
        # ids outside the known vocab are looked up as UNK; src itself is not modified
        known_src = src.clone()
        known_src[known_src >= self.vocab_size] = UNK_ID
        enc_inputs = self.emb_drop(self.embedding(known_src))
        batch_size = enc_inputs.size(0)

        if self.use_pos:
            assert pos is not None, "Missing POS input for seq2seq lemmatizer."
            pos_step = self.pos_drop(self.pos_embedding(pos)).unsqueeze(1)
            enc_inputs = torch.cat([pos_step, enc_inputs], dim=1)
            leading_mask = src_mask.new_zeros([batch_size, 1])
            src_mask = torch.cat([leading_mask, src_mask], dim=1)

        if raw is not None and self.contextual_embedding is not None:
            raw_inputs = self.contextual_embedding(raw)
            if self.use_pos:
                # one zero step keeps the sequence length equal to enc_inputs
                # NOTE(review): the zero step is appended at the end, whereas
                # the POS step above is prepended - confirm this is intended
                pad_shape = (raw_inputs.shape[0], 1, raw_inputs.shape[2])
                raw_inputs = torch.cat([raw_inputs, raw_inputs.new_zeros(pad_shape)], dim=1)
            enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2)

        unpadded_counts = src_mask.data.eq(constant.PAD_ID).long().sum(1)
        src_lens = list(unpadded_counts)
        return enc_inputs, batch_size, src_lens, src_mask

    def forward(self, src, src_mask, tgt_in, pos=None, raw=None):
        """
        Run the encoder over src and the decoder over tgt_in in one pass.

        Returns (log_probs, edit_logits); edit_logits is None unless the
        edit classifier is enabled, in which case it is computed from the
        final encoder hidden state.
        """
        # embed and encode the source side
        enc_inputs, _, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

        edit_logits = self.edit_clf(hn) if self.edit else None

        # decode the whole target input, attending over the encoder outputs
        tgt_embedded = self.emb_drop(self.embedding(tgt_in))
        log_probs, _ = self.decode(tgt_embedded, hn, cn, h_in, src_mask, src=src)
        return log_probs, edit_logits

    def get_log_prob(self, logits):
        logits_reshape = logits.view(-1, self.vocab_size)
        log_probs = F.log_softmax(logits_reshape, dim=1)
        if logits.dim() == 2:
            return log_probs
        return log_probs.view(logits.size(0), logits.size(1), logits.size(2))

    def predict_greedy(self, src, src_mask, pos=None, raw=None, never_decode_unk=False):
        """
        Predict with greedy decoding.

        Starting from the SOS embedding, the highest scoring token is fed
        back as the next decoder input until every sequence in the batch has
        produced EOS or max_dec_len steps have been taken.

        Returns (output_seqs, edit_logits): output_seqs is a list with one
        list of token ids per batch item (EOS excluded); edit_logits is None
        unless the edit classifier is enabled.
        """
        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

        # encode source
        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

        edit_logits = self.edit_clf(hn) if self.edit else None

        # greedy decode by step
        dec_inputs = self.embedding(self.SOS_tensor)
        dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1))

        done = [False] * batch_size
        total_done = 0
        max_len = 0
        output_seqs = [[] for _ in range(batch_size)]

        while total_done < batch_size and max_len < self.max_dec_len:
            log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
            assert log_probs.size(1) == 1, "Output must have 1-step of output."
            _, preds = log_probs.squeeze(1).max(1, keepdim=True)
            # if a unlearned character is predicted via the copy mechanism,
            # use the UNK embedding for it
            dec_inputs = preds.clone()
            dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
            dec_inputs = self.embedding(dec_inputs) # update decoder inputs
            max_len += 1
            # pull the whole step to the host in one transfer, instead of
            # one .item() call (and device sync) per batch element
            step_tokens = preds.squeeze(1).tolist()
            for i, token in enumerate(step_tokens):
                if done[i]:
                    continue
                if token == constant.EOS_ID:
                    done[i] = True
                    total_done += 1
                else:
                    output_seqs[i].append(token)
        return output_seqs, edit_logits

    def predict(self, src, src_mask, pos=None, beam_size=5, raw=None, never_decode_unk=False):
        """
        Predict with beam search.

        A beam_size of 1 is delegated to predict_greedy.  Otherwise the
        encoder outputs, mask, source ids and decoder states are tiled
        beam_size times along the batch dimension (row k * batch_size + b
        belongs to beam slot k of batch item b) and one Beam per batch item
        is advanced for at most max_dec_len steps.

        Returns (all_hyp, edit_logits): all_hyp holds, per batch item, the
        best hypothesis as a list of token ids after utils.prune_hyp;
        edit_logits is None unless the edit classifier is enabled.
        """
        if beam_size == 1:
            return self.predict_greedy(src, src_mask, pos, raw, never_decode_unk=never_decode_unk)

        enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)

        # (1) encode source
        h_in, (hn, cn) = self.encode(enc_inputs, src_lens)

        edit_logits = self.edit_clf(hn) if self.edit else None

        # (2) set up beam
        with torch.no_grad():
            h_in = h_in.data.repeat(beam_size, 1, 1) # repeat data for beam search
            src_mask = src_mask.repeat(beam_size, 1)
            # the source ids must be tiled exactly like the encoder states:
            # the copy mechanism in decode() scatters with src as the index,
            # and scatter_add only fills as many rows as the index has, so an
            # untiled src would drop the copy scores of all but the first beam slot
            src = src.repeat(beam_size, 1)
            # repeat decoder hidden states
            hn = hn.data.repeat(beam_size, 1)
            cn = cn.data.repeat(beam_size, 1)
        device = self.SOS_tensor.device
        beam = [Beam(beam_size, device) for _ in range(batch_size)]

        def update_state(states, idx, positions, beam_size):
            """ Select the states according to back pointers. """
            for e in states:
                br, d = e.size()
                # view as (beam, batch, dim) and pick out batch item idx
                s = e.contiguous().view(beam_size, br // beam_size, d)[:,idx]
                # reorder the beam slots of that item in place
                s.data.copy_(s.data.index_select(0, positions))

        # (3) main loop
        for _ in range(self.max_dec_len):
            dec_inputs = torch.stack([b.get_current_state() for b in beam]).t().contiguous().view(-1, 1)
            # if a unlearned character is predicted via the copy mechanism,
            # use the UNK embedding for it
            dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
            dec_inputs = self.embedding(dec_inputs)
            log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
            log_probs = log_probs.view(beam_size, batch_size, -1).transpose(0,1).contiguous() # [batch, beam, V]

            # advance each beam
            done = []
            for b in range(batch_size):
                is_done = beam[b].advance(log_probs.data[b])
                if is_done:
                    done += [b]
                # update beam state
                update_state((hn, cn), b, beam[b].get_current_origin(), beam_size)

            if len(done) == batch_size:
                break

        # back trace and find hypothesis
        all_hyp = []
        for b in range(batch_size):
            _, ks = beam[b].sort_best()
            # only the top ranked hypothesis of each beam is returned
            hyp = beam[b].get_hyp(ks[0])
            hyp = utils.prune_hyp(hyp)
            all_hyp.append([token.item() for token in hyp])

        return all_hyp, edit_logits

