import json
import random
import torch


class DataLoader:
    """
    Class for loading language ID data and providing batches.

    Attempts to recreate the data pre-processing from: https://github.com/AU-DIS/LSTM_langid

    Uses methods from: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py

    The data format is the same as LSTM_langid's.
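
    Example usage (paths and index contents are placeholders):
        loader = DataLoader()
        loader.load_data(batch_size=32, data_files=["train.jsonl"],
                         char_index={"UNK": 0, "a": 1}, tag_index={"en": 0})
        batch = loader.next()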
    """

    def __init__(self, device=None):
        self.batches = None
        self.batches_iter = None
        self.tag_to_idx = None
        self.idx_to_tag = None
        self.lang_weights = None
        self.device = device

    def load_data(self, batch_size, data_files, char_index, tag_index, randomize=False, randomize_range=(5,20),
                  max_length=None):
        """
        Load sequence data and labels, and calculate weights for the weighted cross entropy loss.
        Data is stored in files, one JSON example per line.
        Example: {"text": "Hello world.", "label": "en"}
        """

        # set up examples from data files
        examples = []
        for data_file in data_files:
            with open(data_file, encoding="utf-8") as f:
                examples += [x for x in f.read().split("\n") if x.strip()]
        random.shuffle(examples)
        examples = [json.loads(x) for x in examples]

        # add additional labels in this data set to tag index
        tag_index = dict(tag_index)
        new_labels = {x["label"] for x in examples} - set(tag_index)
        for new_label in new_labels:
            tag_index[new_label] = len(tag_index)
        self.tag_to_idx = tag_index
        self.idx_to_tag = [tag for _, tag in sorted((idx, tag) for tag, idx in self.tag_to_idx.items())]
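        # e.g. tag_to_idx {"en": 0, "fr": 1} gives idx_to_tag ["en", "fr"]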
        
        # language counts, used to compute weights for the cross entropy loss
        lang_counts = [0 for _ in tag_index]

        # optionally limit text to max length
        if max_length is not None:
            examples = [{"text": x["text"][:max_length], "label": x["label"]} for x in examples]

        # randomize data
        if randomize:
            split_examples = []
            for example in examples:
                sequence = example["text"]
                label = example["label"]
                sequences = DataLoader.randomize_data([sequence], upper_lim=randomize_range[1], 
                                                      lower_lim=randomize_range[0])
                split_examples += [{"text": seq, "label": label} for seq in sequences]
            examples = split_examples
            random.shuffle(examples)

        # break into equal-length batches: bucketing sequences by length lets each
        # batch be stacked into a single rectangular tensor without padding
        batch_lengths = {}
        for example in examples:
            sequence = example["text"]
            label = example["label"]
            if len(sequence) not in batch_lengths:
                batch_lengths[len(sequence)] = []
            sequence_as_list = [char_index.get(c, char_index["UNK"]) for c in sequence]
            batch_lengths[len(sequence)].append((sequence_as_list, tag_index[label]))
            lang_counts[tag_index[label]] += 1
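        # shuffle within each length bucket so batches get a random mix of examples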
        for length in batch_lengths:
            random.shuffle(batch_lengths[length])

        # create final set of batches
        batches = []
        for length in batch_lengths:
            for i in range(0, len(batch_lengths[length]), batch_size):
                batches.append(batch_lengths[length][i:i + batch_size])

        self.batches = [self.build_batch_tensors(batch) for batch in batches]

        # set up lang weights
        most_frequent = max(lang_counts)
        # weight = most_frequent / lang_count, or 0.0 for languages that never occur
        lang_counts = [(most_frequent * x) / (max(1, x) ** 2) for x in lang_counts]
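        # e.g. counts [100, 50, 0] -> weights [1.0, 2.0, 0.0]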
        self.lang_weights = torch.tensor(lang_counts, device=self.device, dtype=torch.float)

        # shuffle batches to mix up lengths
        random.shuffle(self.batches)
        self.batches_iter = iter(self.batches)

    @staticmethod
    def randomize_data(sentences, upper_lim=20, lower_lim=5):
        """
        Takes the original data and creates random-length examples with lengths between
        lower_lim and upper_lim; after each cut, the remainder is advanced past the next space
        From LSTM_langid: https://github.com/AU-DIS/LSTM_langid/blob/main/src/language_datasets.py
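
        Example (actual splits depend on the RNG): "hello there general kenobi" with
        lower_lim=5, upper_lim=10 might yield ["hello ther", "general", "kenobi"] in
        shuffled order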
        """

        new_data = []
        for sentence in sentences:
            remaining = sentence
            while lower_lim < len(remaining):
                lim = random.randint(lower_lim, upper_lim)
                m = min(len(remaining), lim)
                new_sentence = remaining[:m]
                new_data.append(new_sentence)
                split = remaining[m:].split(" ", 1)
                if len(split) <= 1:
                    break
                remaining = split[1]
        random.shuffle(new_data)
        return new_data

    def build_batch_tensors(self, batch):
        """
        Helper to turn batches into tensors
        """

        batch_tensors = dict()
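        # all sequences in a batch share one length, so the batch stacks into
        # "sentences" of shape (batch_size, seq_len) and "targets" of shape (batch_size,)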
        batch_tensors["sentences"] = torch.tensor([s[0] for s in batch], device=self.device, dtype=torch.long)
        batch_tensors["targets"] = torch.tensor([s[1] for s in batch], device=self.device, dtype=torch.long)

        return batch_tensors

    def next(self):
        """Return the next batch of tensors; raises StopIteration once all batches are consumed."""
        return next(self.batches_iter)
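

# Hypothetical smoke test, not part of the original module: writes a tiny
# temporary data file and pulls one batch through the loader. The example
# sentences, labels, and batch size below are made-up placeholders.
if __name__ == "__main__":
    import os
    import tempfile

    demo_examples = [{"text": "hello there", "label": "en"},
                     {"text": "bonjour toi", "label": "fr"}]
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
        f.write("\n".join(json.dumps(x) for x in demo_examples))
        demo_path = f.name

    # build a character index over the demo text, reserving 0 for "UNK"
    chars = sorted({c for x in demo_examples for c in x["text"]})
    demo_char_index = {"UNK": 0}
    demo_char_index.update({c: i for i, c in enumerate(chars, start=1)})

    loader = DataLoader()
    loader.load_data(batch_size=2, data_files=[demo_path], char_index=demo_char_index,
                     tag_index={})
    batch = loader.next()
    print(batch["sentences"].shape, batch["targets"].shape)  # shapes: (2, 11) and (2,)
    os.remove(demo_path)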

