"""
Processor for performing tokenization
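
A rough usage sketch, assuming the English models have already been downloaded
(the example text is illustrative):

    import stanza

    nlp = stanza.Pipeline('en', processors='tokenize')
    doc = nlp('This is a test.  This is another sentence.')
    for sentence in doc.sentences:
        print([token.text for token in sentence.tokens])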
"""

import io
import logging

import torch

from stanza.models.tokenization.data import TokenizationDataset
from stanza.models.tokenization.trainer import Trainer
from stanza.models.tokenization.utils import output_predictions
from stanza.pipeline._constants import *
from stanza.pipeline.processor import UDProcessor, register_processor
from stanza.pipeline.registry import PROCESSOR_VARIANTS
from stanza.models.common import doc

# these imports trigger the "register_variant" decorations
from stanza.pipeline.external.jieba import JiebaTokenizer
from stanza.pipeline.external.spacy import SpacyTokenizer
from stanza.pipeline.external.sudachipy import SudachiPyTokenizer
from stanza.pipeline.external.pythainlp import PyThaiNLPTokenizer

logger = logging.getLogger('stanza')

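# tokens whose text is longer than max_seqlen are replaced with this placeholder
# in process() to avoid GPU memory issues in downstream processors such as POS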
TOKEN_TOO_LONG_REPLACEMENT = "<UNK>"

# class for running the tokenizer
@register_processor(name=TOKENIZE)
class TokenizeProcessor(UDProcessor):

    # set of processor requirements this processor fulfills
    PROVIDES_DEFAULT = set([TOKENIZE])
    # set of processor requirements for this processor
    REQUIRES_DEFAULT = set([])
    # default max sequence length
    MAX_SEQ_LENGTH_DEFAULT = 1000

    def _set_up_model(self, config, pipeline, device):
        # set up trainer
        if config.get('pretokenized'):
            self._trainer = None
        else:
            self._trainer = Trainer(model_file=config['model_path'], device=device)

        # get and typecheck the postprocessor
        postprocessor = config.get('postprocessor')
        if postprocessor and callable(postprocessor):
            self._postprocessor = postprocessor
        elif not postprocessor:
            self._postprocessor = None
        else:
            raise ValueError("Tokenizer recieved 'postprocessor' option of unrecognized type; postprocessor must be callable. Got %s" % postprocessor)

    def process_pre_tokenized_text(self, input_src):
        """
        Pretokenized text can be provided in two ways:

        1) a str, tokenized by whitespace, with sentences split by newline
        2) a list of token lists, where each token list represents a sentence

        Returns the raw text and a document data structure (a list of sentences,
        each a list of token dicts).
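
        For example, the following (illustrative) list input describes two
        pretokenized sentences:

            [["This", "is", "a", "test", "."],
             ["Another", "sentence", "."]]

        The equivalent str input puts each sentence on its own line, with tokens
        separated by whitespace.  Each token becomes a dict such as
        {doc.ID: (1, ), doc.TEXT: "This", doc.MISC: 'start_char=0|end_char=4'},
        with character offsets computed as if tokens were separated by single
        spaces.  This path is used when the pipeline is created with
        tokenize_pretokenized=True.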
        """

        document = []
        if isinstance(input_src, str):
            sentences = [sent.strip().split() for sent in input_src.strip().split('\n') if len(sent.strip()) > 0]
        elif isinstance(input_src, list):
            sentences = input_src
        else:
            raise ValueError("Pretokenized input must be a str or a list of token lists.  Got %s" % str(type(input_src)))
        idx = 0
        for sentence in sentences:
            sent = []
            for token_id, token in enumerate(sentence):
                sent.append({doc.ID: (token_id + 1, ), doc.TEXT: token, doc.MISC: f'start_char={idx}|end_char={idx + len(token)}'})
                idx += len(token) + 1
            document.append(sent)
        raw_text = ' '.join([' '.join(sentence) for sentence in sentences])
        return raw_text, document

    def process(self, document):
        if not (isinstance(document, (str, doc.Document)) or self.config.get('pretokenized') or self.config.get('no_ssplit', False)):
            raise ValueError("If neither the 'pretokenized' nor the 'no_ssplit' option is enabled, the input to the TokenizeProcessor must be a string or a Document object.  Got %s" % str(type(document)))

        if isinstance(document, doc.Document):
            if self.config.get('pretokenized'):
                return document
            document = document.text

        if self.config.get('pretokenized'):
            raw_text, document = self.process_pre_tokenized_text(document)
            return doc.Document(document, raw_text)

        if hasattr(self, '_variant'):
            return self._variant.process(document)

        raw_text = '\n\n'.join(document) if isinstance(document, list) else document

        max_seq_len = self.config.get('max_seqlen', TokenizeProcessor.MAX_SEQ_LENGTH_DEFAULT)

        # set up batches
        batches = TokenizationDataset(self.config, input_text=raw_text, vocab=self.vocab, evaluation=True, dictionary=self.trainer.dictionary)
        # get dict data
        with torch.no_grad():
            _, _, _, document = output_predictions(None, self.trainer, batches, self.vocab, None,
                                                   max_seq_len,
                                                   orig_text=raw_text,
                                                   no_ssplit=self.config.get('no_ssplit', False),
                                                   num_workers=self.config.get('num_workers', 0),
                                                   postprocessor=self._postprocessor)

        # replace excessively long tokens with <UNK> to avoid downstream GPU memory issues in POS
        for sentence in document:
            for token in sentence:
                if len(token['text']) > max_seq_len:
                    token['text'] = TOKEN_TOO_LONG_REPLACEMENT

        return doc.Document(document, raw_text)

    def bulk_process(self, docs):
        """
        The tokenizer cannot use UDProcessor's sentence-level cross-document batching interface, so it requires special handling.
        This method concatenates the text of multiple documents, separated by blank lines (two newlines), tokenizes the combined
        text with the neural tokenizer, then splits the result back into the original Documents and restores the original
        character offsets.
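
        A rough worked example of the offset arithmetic (the documents and values
        are illustrative):

            docs = [Document(..., text="Hello there."), Document(..., text="Bye.")]
            # combined text: "Hello there.", two newlines, then "Bye."
            # "Bye." starts at offset 14 in the combined text (12 characters plus
            # the 2 joining newlines), so 14 is subtracted from the start_char and
            # end_char of every token assigned to the second document.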
        """
        if hasattr(self, '_variant'):
            return self._variant.bulk_process(docs)

        if self.config.get('pretokenized'):
            res = []
            for document in docs:
                raw_text, document = self.process_pre_tokenized_text(document.text)
                res.append(doc.Document(document, raw_text))
            return res

        combined_text = '\n\n'.join([thisdoc.text for thisdoc in docs])
        processed_combined = self.process(doc.Document([], text=combined_text))

        # postprocess sentences and tokens to reset back pointers and char offsets
        charoffset = 0
        sentst = senten = 0
        for thisdoc in docs:
            while senten < len(processed_combined.sentences) and processed_combined.sentences[senten].tokens[-1].end_char - charoffset <= len(thisdoc.text):
                senten += 1

            sentences = processed_combined.sentences[sentst:senten]
            thisdoc.sentences = sentences
            for sent in sentences:
                # fix doc back pointers for sentences
                sent._doc = thisdoc

                # fix char offsets for tokens and words
                for token in sent.tokens:
                    token._start_char -= charoffset
                    token._end_char -= charoffset
                    if token.words:  # not-yet-processed MWT can leave empty tokens
                        for word in token.words:
                            word._start_char -= charoffset
                            word._end_char -= charoffset

            # Fix up the SpacesAfter for the very last token of this document
            # and the SpacesBefore for its first token.
            # After all, the texts were connected with \n\n, so the whitespace
            # around each document's text needs to be restored from that document.
            # We need to be careful about this - in a case such as
            #   " -text one- "
            #   " -text two- "
            # we want the SpacesBefore for the second document to reflect
            # the extra space at the start of its text,
            # and the SpacesAfter for the first document to reflect
            # the whitespace after its text.
            if len(sentences) > 0:
                last_token = sentences[-1].tokens[-1]
                last_whitespace = thisdoc.text[last_token.end_char:]
                last_token.spaces_after = last_whitespace

                first_token = sentences[0].tokens[0]
                first_whitespace = thisdoc.text[:first_token.start_char]
                first_token.spaces_before = first_whitespace

            thisdoc.num_tokens = sum(len(sent.tokens) for sent in sentences)
            thisdoc.num_words = sum(len(sent.words) for sent in sentences)
            sentst = senten

            charoffset += len(thisdoc.text) + 2

        return docs
