"""
Utility functions for data transformations.
"""

import logging
import random

import torch

import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.doc import HEAD, ID, UPOS

logger = logging.getLogger('stanza')

def map_to_ids(tokens, vocab):
    """ Map each token to its id in vocab, falling back to UNK_ID for unknown tokens. """
    ids = [vocab[t] if t in vocab else constant.UNK_ID for t in tokens]
    return ids
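
# Illustrative sketch of map_to_ids (the vocab here is a plain dict standing in
# for a real stanza vocab; any mapping supporting `in` and `[]` works):
#   map_to_ids(['the', 'cat'], {'the': 4, 'dog': 5})  ->  [4, constant.UNK_ID]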

def get_long_tensor(tokens_list, batch_size, pad_id=constant.PAD_ID):
    """ Convert (list of )+ tokens to a padded LongTensor. """
    sizes = []
    x = tokens_list
    while isinstance(x[0], list):
        sizes.append(max(len(y) for y in x))
        x = [z for y in x for z in y]
    # TODO: pass in a device parameter and put it directly on the relevant device?
    # that might be faster than creating it and then moving it
    tokens = torch.LongTensor(batch_size, *sizes).fill_(pad_id)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens
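
# Illustrative sketch of get_long_tensor on a toy batch (token ids made up):
#   get_long_tensor([[1, 2, 3], [4, 5]], batch_size=2)
# returns a 2x3 LongTensor in which the second row is padded with PAD_ID:
#   [[1, 2, 3],
#    [4, 5, PAD_ID]]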

def get_float_tensor(features_list, batch_size):
    """ Convert a list of per-word feature vectors to a zero-padded FloatTensor. """
    if features_list is None or features_list[0] is None:
        return None
    seq_len = max(len(x) for x in features_list)
    feature_len = len(features_list[0][0])
    features = torch.FloatTensor(batch_size, seq_len, feature_len).zero_()
    for i,f in enumerate(features_list):
        features[i,:len(f),:] = torch.FloatTensor(f)
    return features
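
# Illustrative sketch of get_float_tensor (feature values made up): two sentences
# with 2 and 1 words respectively, each word carrying a 2-dimensional feature vector.
#   get_float_tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6]]], batch_size=2)
# returns a 2x2x2 FloatTensor in which the missing second word of the second
# sentence is zero-padded.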

def sort_all(batch, lens):
    """ Sort all fields by descending order of lens, and return the original indices. """
    if batch == [[]]:
        return [[]], []
    unsorted_all = [lens] + [range(len(lens))] + list(batch)
    sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
    return sorted_all[2:], sorted_all[1]
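
# Illustrative sketch of sort_all: a single field sorted by descending length,
# together with the original indices needed to undo the sort later.
#   sort_all([['a', 'b', 'c']], [2, 5, 3])  ->  ([['b', 'c', 'a']], [1, 2, 0])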

def get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate, desired_ratio=0.1, max_ratio=0.5):
    """
    Returns a ratio X such that randomly augmenting X of the eligible
    sentences leaves roughly desired_ratio (10% by default) of the final
    dataset without the feature in question (for example, final punct).

    The ratio is chosen under the assumption that the final dataset
    is of size N rather than N + X * N.

    should_augment_predicate: returns True if the sentence has some
      feature which we may want to change occasionally.  For example,
      depparse sentences which end in punct.
    can_augment_predicate: in the depparse example, it is technically
      possible for the final punct to be the parent of some other word
      in the sentence.  In that case, the sentence should not be
      chosen.  This predicate should be at least as restrictive as
      should_augment_predicate.
    """
    n_data = len(train_data)
    n_should_augment = sum(should_augment_predicate(sentence) for sentence in train_data)
    n_can_augment = sum(can_augment_predicate(sentence) for sentence in train_data)
    n_error = sum(can_augment_predicate(sentence) and not should_augment_predicate(sentence)
                  for sentence in train_data)
    if n_error > 0:
        raise AssertionError("can_augment_predicate allowed sentences not allowed by should_augment_predicate")

    if n_can_augment == 0:
        logger.warning("Found no sentences which matched can_augment_predicate {}".format(can_augment_predicate))
        return 0.0
    n_needed = n_data * desired_ratio - (n_data - n_should_augment)
    # if we want 10%, for example, and more than 10% of the data already lacks the feature, we can skip
    if n_needed < 0:
        return 0.0
    ratio = n_needed / n_can_augment
    if ratio > max_ratio:
        return max_ratio
    return ratio
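
# Worked example (numbers made up): if all 100 training sentences end with
# punctuation (n_should_augment == 100), 80 of them can safely have it removed
# (n_can_augment == 80), and desired_ratio is 0.1, then
#   n_needed = 100 * 0.1 - (100 - 100) = 10
#   ratio    = 10 / 80 = 0.125
# so augmenting each eligible sentence with probability 0.125 yields roughly 10
# punct-free sentences, i.e. about 10% of the original dataset size.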


def should_augment_nopunct_predicate(sentence):
    """
    Returns True if the sentence ends with PUNCT, ie it is a candidate for augmentation
    """
    last_word = sentence[-1]
    return last_word.get(UPOS, None) == 'PUNCT'

def can_augment_nopunct_predicate(sentence):
    """
    Check that the sentence ends with PUNCT and also doesn't have any words which depend on the last word
    """
    last_word = sentence[-1]
    if last_word.get(UPOS, None) != 'PUNCT':
        return False
    # don't cut off MWT
    if len(last_word[ID]) > 1:
        return False
    if any(len(word[ID]) == 1 and word[HEAD] == last_word[ID][0] for word in sentence):
        return False
    return True
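
# Illustrative sketch of the conll-style dicts these predicates expect (a
# simplified, hypothetical sentence; real documents carry more fields):
#   sentence = [{ID: (1,), HEAD: 2, UPOS: 'NOUN'},
#               {ID: (2,), HEAD: 0, UPOS: 'VERB'},
#               {ID: (3,), HEAD: 2, UPOS: 'PUNCT'}]
# can_augment_nopunct_predicate(sentence) is True here: the sentence ends with a
# single-word PUNCT token which is not the head of any other word.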

def augment_punct(train_data, augment_ratio,
                  should_augment_predicate=should_augment_nopunct_predicate,
                  can_augment_predicate=can_augment_nopunct_predicate,
                  keep_original_sentences=True):
    """
    Adds extra training data to compensate for some models having all sentences end with PUNCT

    Some of the models (for example, UD_Hebrew-HTB) have the flaw that
    all of the training sentences end with PUNCT.  The model therefore
    learns to finish every sentence with punctuation, even if it is
    given a sentence with non-punct at the end.

    One simple way to fix this is to train on some fraction of the training data with the final punct removed.

    Params:
    train_data: list of list of dicts, eg a conll doc
    augment_ratio: the fraction to augment.  if None, a best guess is made to get to 10%

    should_augment_predicate: a function which returns True if a sentence ends with PUNCT and so is a candidate for augmentation
    can_augment_predicate: a function which returns True if it is safe to remove the final PUNCT from the sentence

    TODO: do this dynamically, as part of the DataLoader or elsewhere?
    One complication is the data comes back from the DataLoader as
    tensors & indices, so it is much more complicated to manipulate
    """
    if len(train_data) == 0:
        return []

    if augment_ratio is None:
        augment_ratio = get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate)

    if augment_ratio <= 0:
        if keep_original_sentences:
            return list(train_data)
        else:
            return []

    new_data = []
    for sentence in train_data:
        if can_augment_predicate(sentence) and len(sentence) > 1 and random.random() < augment_ratio:
            # todo: could deep copy the words
            #       or not deep copy any of this
            new_sentence = list(sentence[:-1])
            new_data.append(new_sentence)
        elif keep_original_sentences:
            new_data.append(sentence)

    return new_data
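
# Hedged usage sketch: with augment_ratio=None the ratio is estimated so that
# roughly 10% of the returned sentences lack final punctuation; sentences which
# are not augmented are kept unchanged by default.
#   train_data = augment_punct(train_data, augment_ratio=None)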
