"""Stanza models classifier data functions."""

import collections
from collections import namedtuple
import logging
import json
import random
import re
from typing import List

from stanza.models.classifiers.utils import WVType
from stanza.models.common.vocab import PAD, PAD_ID, UNK, UNK_ID
import stanza.models.constituency.tree_reader as tree_reader

logger = logging.getLogger('stanza')

class SentimentDatum:
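    """
    A single classifier example: a sentiment label, a list of tokens, and an
    optional constituency tree.

    For example (illustrative), SentimentDatum("1", ["a", "nice", "movie"])
    could represent a positive example with no constituency tree attached.
    """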
    def __init__(self, sentiment, text, constituency=None):
        self.sentiment = sentiment
        self.text = text
        self.constituency = constituency

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, SentimentDatum):
            return False
        return self.sentiment == other.sentiment and self.text == other.text and self.constituency == other.constituency

    def __str__(self):
        return str(self._asdict())

    def _asdict(self):
        if self.constituency is None:
            return {'sentiment': self.sentiment, 'text': self.text}
        else:
            return {'sentiment': self.sentiment, 'text': self.text, 'constituency': str(self.constituency)}

def update_text(sentence: List[str], wordvec_type: WVType) -> List[str]:
    """
    Process a tokenized sentence (a list of tokens, originally separated by whitespace)
    into a cleaned up list of lowercased strings suitable for the given word vectors.
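
    For example (illustrative values):
      update_text(["Spider-Man", "3/4"], WVType.WORD2VEC) -> ["spider", "man", "3", "4"]
      update_text(["Spider-Man", "3/4"], WVType.GOOGLE)   -> ["spider", "man", "#", "#"]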
    """
    # the Stanford Sentiment dataset has a lot of stray - and / characters
    # split on those characters and flatten the resulting sublists back into a single list
    sentence = [y for x in sentence for y in x.split("-") if y]
    sentence = [y for x in sentence for y in x.split("/") if y]
    sentence = [x.strip() for x in sentence]
    sentence = [x for x in sentence if x]
    if not sentence:
        # removed too much
        sentence = ["-"]
    # our current word vectors are all entirely lowercased
    sentence = [word.lower() for word in sentence]
    if wordvec_type == WVType.WORD2VEC:
        return sentence
    elif wordvec_type == WVType.GOOGLE:
        new_sentence = []
        for word in sentence:
            if word != '0' and word != '1':
                word = re.sub('[0-9]', '#', word)
            new_sentence.append(word)
        return new_sentence
    elif wordvec_type == WVType.FASTTEXT:
        return sentence
    elif wordvec_type == WVType.OTHER:
        return sentence
    else:
        raise ValueError("Unknown wordvec_type {}".format(wordvec_type))


def read_dataset(dataset, wordvec_type: WVType, min_len: int) -> List[SentimentDatum]:
    """
    Reads one or more JSON files (comma-separated in `dataset`) and returns a
    list of SentimentDatum, dropping items with fewer than min_len tokens if
    min_len is set.
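
    Each file should hold a JSON list of objects with 'sentiment' and 'text'
    fields and an optional 'constituency' field, e.g. (illustrative):
      [{"sentiment": "1", "text": ["a", "nice", "movie"]},
       {"sentiment": "0", "text": ["too", "long"], "constituency": "(ROOT ...)"}]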
    """
    lines = []
    for filename in str(dataset).split(","):
        with open(filename, encoding="utf-8") as fin:
            new_lines = json.load(fin)
        new_lines = [(str(x['sentiment']), x['text'], x.get('constituency', None)) for x in new_lines]
        lines.extend(new_lines)
    # TODO: maybe do this processing later, once the model is built.
    # then move the processing into the model so we can use
    # overloading to support other model types in the future
    lines = [SentimentDatum(x[0], update_text(x[1], wordvec_type), tree_reader.read_trees(x[2])[0] if x[2] else None) for x in lines]
    if min_len:
        lines = [x for x in lines if len(x.text) >= min_len]
    return lines

def dataset_labels(dataset):
    """
    Returns a sorted list of label names.
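
    For example, all-integer labels such as {"2", "10", "1"} are sorted
    numerically to ["1", "2", "10"]; otherwise the labels are sorted lexically.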
    """
    labels = set([x.sentiment for x in dataset])
    if all(re.match("^[0-9]+$", label) for label in labels):
        # if all of the labels are integers, sort numerically
        # maybe not super important, but it would be nicer than having
        # 10 before 2
        labels = [str(x) for x in sorted(map(int, list(labels)))]
    else:
        labels = sorted(list(labels))
    return labels

def dataset_vocab(dataset):
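    """
    Returns a list of all words seen in the dataset, with PAD and UNK at the
    front so that their positions match PAD_ID and UNK_ID.

    For example (illustrative), a dataset whose texts contain only "good" and
    "bad" produces [PAD, UNK, "good", "bad"] (the word order after PAD and UNK
    is not fixed, since the words are collected in a set).
    """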
    vocab = set()
    for line in dataset:
        for word in line.text:
            vocab.add(word)
    vocab = [PAD, UNK] + list(vocab)
    if vocab[PAD_ID] != PAD or vocab[UNK_ID] != UNK:
        raise ValueError("Unexpected values for PAD and UNK!")
    return vocab

def sort_dataset_by_len(dataset, keep_index=False):
    """
    Returns a dict mapping length -> list of items of that length.

    An OrderedDict is used so that the keys go from smallest to largest length.
    If keep_index is True, each item is stored as an (item, original_index) pair.
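
    For example (illustrative), three items of lengths 2, 5 and 2 become
    {2: [item0, item2], 5: [item1]}.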
    """
    sorted_dataset = collections.OrderedDict()
    lengths = sorted(list(set(len(x.text) for x in dataset)))
    for l in lengths:
        sorted_dataset[l] = []
    for item_idx, item in enumerate(dataset):
        if keep_index:
            sorted_dataset[len(item.text)].append((item, item_idx))
        else:
            sorted_dataset[len(item.text)].append(item)
    return sorted_dataset

def shuffle_dataset(sorted_dataset, batch_size, batch_single_item):
    """
    Given a dataset sorted by len, shuffles within each length bucket and
    groups items into batches of batch_size, so each batch holds items of
    roughly the same length.  Items of length >= batch_single_item (when
    batch_single_item > 0) get a batch of their own.  Returns a shuffled
    list of batches.
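
    A typical usage sketch (train_set and the numbers are illustrative):
      batches = shuffle_dataset(sort_dataset_by_len(train_set), 32, 100)
      for batch in batches:
          ...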
    """
    dataset = []
    for l in sorted_dataset.keys():
        items = list(sorted_dataset[l])
        random.shuffle(items)
        dataset.extend(items)
    batches = []
    next_batch = []
    for item in dataset:
        if batch_single_item > 0 and len(item.text) >= batch_single_item:
            batches.append([item])
        else:
            next_batch.append(item)
            if len(next_batch) >= batch_size:
                batches.append(next_batch)
                next_batch = []
    if len(next_batch) > 0:
        batches.append(next_batch)
    random.shuffle(batches)
    return batches


def check_labels(labels, dataset):
    """
    Check that all of the labels in the dataset are in the known labels.

    Actually, unknown labels could be acceptable if we just treated the model as always wrong on them.
    However, this is a good sanity check to make sure the datasets match.
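
    For example, if the known labels are ["0", "1", "2"] and the dataset
    contains a "3", this raises a RuntimeError.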
    """
    new_labels = dataset_labels(dataset)
    not_found = [i for i in new_labels if i not in labels]
    if not_found:
        raise RuntimeError('Dataset contains labels which the model does not know about: ' + str(not_found))

