"""
A module to use a Constituency Parser to make an embedding for a tree

The embedding can be produced just from the words and the top of the
tree, or it can be done with a form of attention over the nodes

Can be done over an existing parse tree or unparsed text
"""


import torch
import torch.nn as nn

from stanza.models.constituency.trainer import Trainer

class TreeEmbedding(nn.Module):
    """Build tree embeddings from the internal states of a constituency parser.

    The parser analyzes the input trees.  From the resulting states a "key"
    is concatenated out of:
      - word hx (either the words at position 0 and at state.word_position,
        or every word in the word queue when all_words is set)
      - the transition stack output
      - a constituent hx (the constituent stack output when top_layer is
        set, otherwise the tree_hx of the second-to-last constituent)

    If node_attn is set, the key is projected and used to attend over the
    tree_hx of every constituent node of the tree.

    Config keys (all required in the args passed to __init__):
      all_words: one key row per word instead of a single begin/end row
      backprop:  if False, the parser runs under torch.no_grad()
      node_attn: apply attention over the constituent nodes
      top_layer: take the constituent hx from the constituent stack output
    """

    def __init__(self, constituency_parser, args):
        """Wrap an already-built constituency parser model.

        constituency_parser: the parser model whose states are read.  Here
          only hidden_size, transition_hidden_size and num_tree_lstm_layers
          are read from it.
        args: dict-like holding all_words, backprop, node_attn, top_layer.
          Only these keys are copied into self.config (which get_params saves).
        """
        super().__init__()

        self.config = {
            "all_words":   args["all_words"],
            "backprop":    args["backprop"],
            #"batch_norm":  args["batch_norm"],
            "node_attn":   args["node_attn"],
            "top_layer":   args["top_layer"],
        }

        self.constituency_parser = constituency_parser

        # Width of one key row built in embed_trees:
        # word_lstm:         hidden_size * num_tree_lstm_layers * 2 (start & end)
        # transition_stack:  transition_hidden_size
        # constituent_stack: hidden_size
        parser = self.constituency_parser
        word_size = parser.hidden_size * parser.num_tree_lstm_layers
        self.hidden_size = parser.hidden_size + parser.transition_hidden_size
        if self.config["all_words"]:
            # each row holds a single word hx
            self.hidden_size += word_size
        else:
            # the single row holds both the begin and the end word hx
            self.hidden_size += word_size * 2

        if self.config["node_attn"]:
            # queries and values come from constituent tree_hx;
            # keys come from the full concatenated key row
            self.query = nn.Linear(parser.hidden_size, parser.hidden_size)
            self.key = nn.Linear(self.hidden_size, parser.hidden_size)
            self.value = nn.Linear(parser.hidden_size, parser.hidden_size)

            # TODO: cat transition and constituent hx as well?
            self.output_size = parser.hidden_size * parser.num_tree_lstm_layers
        else:
            self.output_size = self.hidden_size

        # TODO: maybe have batch_norm, maybe use Identity
        #if self.config["batch_norm"]:
        #    self.input_norm = nn.BatchNorm1d(self.output_size)

    def embed_trees(self, inputs):
        """Run the parser over inputs and build one embedding per input.

        inputs is passed unchanged to constituency_parser.analyze_trees.

        Returns, depending on config:
          node_attn off, all_words off: a single tensor with the per-input
            key rows stacked on dim 0 and unsqueezed at dim 1
          node_attn off, all_words on: a list with one tensor per input,
            one row per word in that state's word_queue
          node_attn on: a list with one tensor per input, one row per key
            row, each row an attention-weighted sum of node values
        """
        if self.config["backprop"]:
            states = self.constituency_parser.analyze_trees(inputs)
        else:
            with torch.no_grad():
                states = self.constituency_parser.analyze_trees(inputs)

        constituent_lists = [x.constituents for x in states]
        states = [x.state for x in states]

        word_begin_hx = torch.stack([state.word_queue[0].hx for state in states])
        word_end_hx = torch.stack([state.word_queue[state.word_position].hx for state in states])
        transition_hx = torch.stack([self.constituency_parser.transition_stack.output(state.transitions) for state in states])
        # go down one layer to get the embedding off the top of the S, not the ROOT
        # (in terms of the typical treebank)
        # the idea being that the ROOT has no additional information
        # and may even have 0s for the embedding in certain circumstances,
        # such as after learning UNTIED_MAX long enough
        if self.config["top_layer"]:
            constituent_hx = torch.stack([self.constituency_parser.constituent_stack.output(state.constituents) for state in states])
        else:
            # NOTE(review): assumes each constituent list has >= 2 entries
            # and that tree_hx carries a leading dim of 1 so cat on dim 0
            # yields one row per input -- confirm against the parser
            constituent_hx = torch.cat([constituents[-2].tree_hx for constituents in constituent_lists], dim=0)

        if self.config["all_words"]:
            # need B matrices of N x hidden_size
            key = [torch.stack([torch.cat([word.hx, thx, chx]) for word in state.word_queue], dim=0)
                   for state, thx, chx in zip(states, transition_hx, constituent_hx)]
        else:
            key = torch.cat((word_begin_hx, word_end_hx, transition_hx, constituent_hx), dim=1).unsqueeze(1)

        if not self.config["node_attn"]:
            return key
        key = [self.key(x) for x in key]

        node_hx = [torch.stack([con.tree_hx for con in constituents], dim=0) for constituents in constituent_lists]
        queries = [self.query(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]
        values = [self.value(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx]
        # TODO: could pad to make faster here
        # scores: (num_nodes, num_key_rows), softmaxed over the nodes, so
        # each key row receives a weighted average of the node values
        attn = [torch.matmul(q, k.transpose(0, 1)) for q, k in zip(queries, key)]
        attn = [torch.softmax(x, dim=0) for x in attn]
        previous_layer = [torch.matmul(weight.transpose(0, 1), value) for weight, value in zip(attn, values)]
        return previous_layer

    def forward(self, inputs):
        """nn.Module entry point; returns exactly what embed_trees returns."""
        # must be a method call: there is no module-level embed_trees
        return self.embed_trees(inputs)

    def get_norms(self):
        """Return "name norm" strings for the parser and this module's own trainable params."""
        lines = ["constituency_parser." + x for x in self.constituency_parser.get_norms()]
        for name, param in self.named_parameters():
            # parser params (if registered as a submodule) are already
            # covered by the parser's own get_norms above
            if param.requires_grad and not name.startswith('constituency_parser.'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        return lines

    def get_params(self, skip_modules=True):
        """Return a dict with this module's state, the parser's params and the config.

        Keys: 'model' (state_dict minus constituency_parser.* entries),
        'constituency' (constituency_parser.get_params(skip_modules)),
        'config' (self.config).
        """
        model_state = self.state_dict()
        # skip all of the constituency parameters here -
        # we will add them by calling the model's get_params()
        skipped = [k for k in model_state.keys() if k.startswith("constituency_parser.")]
        for k in skipped:
            del model_state[k]

        parser = self.constituency_parser.get_params(skip_modules)

        params = {
            'model':         model_state,
            'constituency':  parser,
            'config':        self.config,
        }
        return params

    @staticmethod
    def from_parser_file(args, foundation_cache=None):
        """Load a parser via Trainer.load(args['model'], ...) and wrap its model using args."""
        constituency_parser = Trainer.load(args['model'], args, foundation_cache)
        return TreeEmbedding(constituency_parser.model, args)

    @staticmethod
    def model_from_params(params, args, foundation_cache=None):
        """Rebuild a TreeEmbedding from the dict produced by get_params."""
        # TODO: integrate with peft
        constituency_parser = Trainer.model_from_params(params['constituency'], None, args, foundation_cache)
        model = TreeEmbedding(constituency_parser, params['config'])
        # strict=False: the saved 'model' state has the parser entries removed
        model.load_state_dict(params['model'], strict=False)
        return model
