o
    hM                     @   s   d dl mZ d dlmZ d dlZd dlZd dlZd dlZd dlZd dl	m
Z
 ddlmZ d dlmZmZ edZd	d
 ZedZedZedZG dd dZG dd deZG dd de
ZdS )    )bisect_right)copyN)Dataset   )Vocab)sort_with_indicesunsortstanzac                 C   sR   g }t | D ] \}\}}|dkr|dkr| |d  d dkrq|||f q|S )Nr    r   )	enumerateappend)parafilteredicharlabel r   Z/var/www/html/env_mimamsha/lib/python3.10/site-packages/stanza/models/tokenization/data.pyfilter_consecutive_whitespaces   s   r   z\n\s*\nz^[\d]+([,\.]+[\d]+)*[,\.]*$z\sc                       sN   e Zd Zdddddddf fdd	Zdd Zdd	 Zd
d Zdd Z  ZS )TokenizationDatasetNtxtr   Fc                    sr  t  j|i | || _|| _|| _|| _|d }	|d }
|	d us'|d us'J |d u rIt|	}d| 	 }W d    n1 sCw   Y  n|}t
|}dd |D }dd |D }|
d urt|
$}d| 	 }t
|}dd |D }dd |D }W d    n1 sw   Y  nd	d |D }| jd
d  fddt||D | _dd | jD | _d S )Nr   r    c                 S      g | ]}|  qS r   rstrip.0ptr   r   r   
<listcomp>7       z0TokenizationDataset.__init__.<locals>.<listcomp>c                 S   s   g | ]}|r|qS r   r   r   r   r   r   r   8   r    c                 S   r   r   r   r   r   r   r   r   =   r    c                 S   s   g | ]	}|rt t|qS r   )mapintr   r   r   r   r   >       c                 S   s   g | ]	}d d |D qS )c                 S      g | ]}d qS r   r   )r   _r   r   r   r   @       ;TokenizationDataset.__init__.<locals>.<listcomp>.<listcomp>r   r   r   r   r   r   @   r#   skip_newlineFc                    s(   g | ]\}} fd dt ||D qS )c                    s,   g | ]\}} r|d kst d||fqS )
r
   )WHITESPACE_REsub)r   r   r   r)   r   r   r   C   s    r(   )zip)r   r   pcr-   r   r   r   C   s
    

c                 S   s   g | ]}t |qS r   )r   r   xr   r   r   r   H   r    )super__init__argseval
dictionaryvocabopenjoin	readlinesr   NEWLINE_WHITESPACE_REsplitgetr.   data)selftokenizer_argsinput_files
input_textr7   
evaluationr6   r4   kwargstxt_file
label_fileftexttext_chunkslabels	__class__r-   r   r3   "   s>   




zTokenizationDataset.__init__c                 C   s   dd | j D S )z
        Returns a list of the labels for all of the sentences in this DataLoader

        Used at eval time to compare to the results, for example
        c                 S   s$   g | ]}t td d |D qS )c                 s   s    | ]}|d  V  qdS r   Nr   r0   r   r   r   	<genexpr>P   s    z8TokenizationDataset.labels.<locals>.<listcomp>.<genexpr>nparraylist)r   sentr   r   r   r   P   s   $ z.TokenizationDataset.labels.<locals>.<listcomp>r>   r?   r   r   r   rJ   J   s   zTokenizationDataset.labelsc                 C   sH  t |}dd t| jd D }dd t| jd D }|| d }|| d }d}d}	td| jd d D ]i}
||
 |d krh|rh||||
  d  7 }|| jd v rWdnd}|||
d < || jd	 vrhd
}||
 dkr|	r|||
  d  | }|| jd v rdnd}|||
d < || jd vrd
}	|s|	s || S q6|| S )zT
        This function is to extract dictionary features for each character
        c                 S   r$   r%   r   r   r   r   r   r   r   X   r'   z9TokenizationDataset.extract_dict_feat.<locals>.<listcomp>num_dict_featc                 S   r$   r%   r   rV   r   r   r   r   Y   r'   r   Tr   wordsprefixesFsuffixes)lenranger4   lowerr6   )r?   r   idxlengthdict_forward_featsdict_backward_featsforward_wordbackward_wordprefixsuffixwindowfeatr   r   r   extract_dict_featR   s2    z%TokenizationDataset.extract_dict_featc                    s  g }g  j d D ]2}|dks|dkrq	|dkrdd }n|dkr&dd }n|d	kr/d
d }ntd| | q	 fdd}fdd}dj d v }dj d v }j d }	g }
g }g }t|D ]p\}\}}||}|r|t|d kr{dnd}|| |r|dkrdnd}|| |	r||}|| }|
| || || js|dks|dkrt|
j d kr|||
|| |
  |  |  qet|
dkrjst|
j d kr|||
|| |S )z7 Convert a paragraph to a list of processed sentences. 
feat_funcsend_of_parastart_of_paraspace_beforec                 S   s   |  drdS dS )Nr
   r   r   )
startswithr1   r   r   r   <lambda>       z7TokenizationDataset.para_to_sentences.<locals>.<lambda>capitalizedc                 S   s   | d   rdS dS )Nr   r   )isupperrn   r   r   r   ro      r    numericc                 S   s   t | d ur	dS dS )Nr   r   )
NUMERIC_REmatchrn   r   r   r   ro          z#Feature function "{}" is undefined.c                    s    fddD S )Nc                    s   g | ]}| qS r   r   )r   rG   rn   r   r   r      r    zKTokenizationDataset.para_to_sentences.<locals>.<lambda>.<locals>.<listcomp>r   rn   )funcsrn   r   ro      rp   c                    s0   t  fdd| D t |t |t| fS )Nc                    s   g | ]} j |qS r   )r7   unit2id)r   yrU   r   r   r          zSTokenizationDataset.para_to_sentences.<locals>.process_sentence.<locals>.<listcomp>rO   )
sent_unitssent_labels
sent_featsrU   r   r   process_sentence   s
   z?TokenizationDataset.para_to_sentences.<locals>.process_sentenceuse_dictionaryr   r         
max_seqlen)	r4   
ValueErrorformatr   r   r[   rh   r5   clear)r?   r   res	feat_funcfunccomposite_funcr~   use_end_of_parause_start_of_parar   current_unitscurrent_labelscurrent_featsr   unitr   featsrG   
dict_featsr   )rw   r?   r   para_to_sentencesw   s\   








z%TokenizationDataset.para_to_sentencesc                 C   s  | j d}| j d}|\}}}}|jd }	||kd }
tdd t||
D }tjt	||f|tj
d}tjt	||fdtjd}tjt	|||	ftjd}g }tt	|D ]q}t|| |
| ||< |||| |
| f ||d|
| ||  f< |||| |
| f ||d|
| ||  f< |||| |
| f ||d|
| ||  f< ||| || |
|  dg||
|  ||     q[||||fS )	a  
        Advance to a new position in a batch where we have partially processed the batch

        If we have previously built a batch of data and made predictions on them, then when we are trying to make
        prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
        In this case, eval_offsets index within the old_batch to advance the strings to process.
        <UNK><PAD>r   c                 s   s    | ]	\}}|| V  qd S Nr   )r   r   lr   r   r   rN          z8TokenizationDataset.advance_old_batch.<locals>.<genexpr>dtypeN)r7   rx   shapesumtolistmaxr.   torchfullr[   int64int32zerosfloat32r\   minr   )r?   eval_offsets	old_batchunkidpadidounitsolabels	ofeaturesoraw	feat_sizelenspad_lenunitsrJ   features	raw_unitsr   r   r   r   advance_old_batch   s"   	
000:z%TokenizationDataset.advance_old_batch)	__name__
__module____qualname__r3   rJ   rh   r   r   __classcell__r   r   rK   r   r   !   s    (%Br   c                       sd   e Zd ZdZdddddddf fdd	Zdd Zd	d
 Zdd Zdd Zdd Z	dddZ
  ZS )
DataLoaderz6
    This is the training version of the dataset.
    Nr   Fc                    sd   t  |||||| |d ur|n   _ fdd jD  _   tt	 j
 d d S )Nc                    s   g | ]}  |qS r   )r   )r   r   rU   r   r   r      rv   z'DataLoader.__init__.<locals>.<listcomp>z sentences loaded.)r2   r3   
init_vocabr7   r>   	sentencesinit_sent_idsloggerdebugr[   sentence_ids)r?   r4   rA   rB   r7   rC   r6   rK   rU   r   r3      s
   zDataLoader.__init__c                 C   
   t | jS r   )r[   r   rU   r   r   r   __len__      
zDataLoader.__len__c                 C   s   t | j| jd }|S )Nlang)r   r>   r4   )r?   r7   r   r   r   r      s   zDataLoader.init_vocabc                 C   sx   g | _ dg| _t| jD ]-\}}tt|D ]"}|  j ||fg7  _ |  j| jd t| j| | d  g7  _qqd S )Nr   r   )r   cumlenr   r   r\   r[   )r?   r   r   jr   r   r   r      s   .zDataLoader.init_sent_idsc                 C   s.   | j D ]}|D ]}|d dkr  dS qqdS )Nr   r   TFrT   )r?   sentencewordr   r   r   has_mwt   s   
zDataLoader.has_mwtc                 C   s"   | j D ]}t| q|   d S r   )r   randomshuffler   )r?   r   r   r   r   r     s   
zDataLoader.shuffle        c              	      sD  t  jd d d d } jd} jd}d jd f fdd	}|durd}|D ](}	|	 jd	 k rUt j|	d
 }
 j|
 }t|t |||	 j|
  dd }q-|d
7 } fdd|D } fdd|D } fddt	||D }t
t	||}nt jtt  j jd }dd |D } jd }tjt ||f|tjd}tjt ||fd	tjd}tjt |||ftjd}g }t|D ]@\}\}}||||d\}}}}|||dt |f< |||dt |f< |||dt |ddf< ||dg|t |    q|dkrI jsItj|j|k }d|||k< |||< tt |D ]}tt || D ]}|||f rEd|| |< q6q, jd r|dkr jstj|j|k }d|||k< tt |D ]}tt || D ]}|||f rd|||ddf< qwqmt|}t|}t|}||||fS )zf Get a batch of converted and padded PyTorch data from preprocessed raw text for training/prediction. r   r   r   r   r   c                    sP  j r| ntj\}}t fddj| | D g}j s*jdddkr,dn
t jddk }j sCjdddkrEdn
t jddk }t|d d }j s}|jd ks}J d	jd |d	
d
d tj| | D j rt|d tj| D ]"}	|tj| |	 d 7 }|j| |	  |jd kr nqn)	 tj\}
}	|tj|
 |	 d 7 }|j|
 |	  |jd krnq|r t|dkr |jd kr|d d }t|dkr dd tdt|d D }tjttt|tt|dd }|d |d  }tdd |D }tdd |D }tdd |D }dd |D }j shjd }|d | |d | |d | |d | f\}}}}|rt|dkr|d dkr|d dv r|d d |d d |d d |d d f\}}}}|d d |d< ||||fS )Nc                    s   g | ]}| d  qS r   r   r0   offsetr   r   r     rz   z=DataLoader.next.<locals>.strings_starting.<locals>.<listcomp>sent_drop_probr   Flast_char_drop_probr   z|The maximum sequence length {} is less than that of the longest sentence length ({}) in the data, consider increasing it! {}r
   c                 S   s   g | ]}d j | qS )z{}/{})r   r0   r   r   r   r     rv   r   Tr   c                 S   s   g | ]}d | qS )g      ?r   rV   r   r   r   r   .  r    )weightsc                 S      g | ]}|d  qS r%   r   r   sr   r   r   r   2  r    c                 S   r   r   r   r   r   r   r   r   3  r    c                 S   r   )r   r   r   r   r   r   r   4  r    c                 S   s   g | ]}|d  D ]}|qqS )   r   )r   r   r1   r   r   r   r   5      r   )r   r   )r5   r   choicer   r   r   r4   r=   r[   r   r9   r.   r\   r   choicesrR   reversedrP   concatenate)id_pairr   r   pidsidr   
drop_sentsdrop_last_char	total_lensid1pid1pcutoffr   rJ   r   r   rU   r   r   strings_starting  sP   "22J&
4*4z)DataLoader.next.<locals>.strings_startingNr   r   r   c                    s   g | ]
}t  j|d  qS r   )r   r   )r   eval_offsetrU   r   r   r   P  s    z#DataLoader.next.<locals>.<listcomp>c                    s   g | ]} j | qS r   )r   )r   pair_idrU   r   r   r   Q  rv   c                    s   g | ]\}}| j |  qS r   )r   )r   r   r   rU   r   r   r   R  r   
batch_sizec                 S   s   g | ]}d |fqS r%   r   r0   r   r   r   r   W  r    r   )r   r   r   )r[   r   r7   rx   r4   r   r   r   r   r.   rR   r   sampler   rP   r   r   r   r   r   r   r5   random_sampler   r\   r   
from_numpy)r?   r   unit_dropoutfeat_unit_dropoutr   r   r   r   r   r   r   pairid_pairspairsoffsetsoffsets_pairsr   rJ   r   r   r   r   u_l_f_r_maskr   	mask_featr   rU   r   next  sj   8
$ 



zDataLoader.next)Nr   r   )r   r   r   __doc__r3   r   r   r   r   r   r   r   r   r   rK   r   r      s    
r   c                       s@   e Zd ZdZ fddZdd Zdd Zdd	 Zd
d Z  Z	S )SortedDatasetal  
    Holds a TokenizationDataset for use in a torch DataLoader

    The torch DataLoader is different from the DataLoader defined here
    and allows for cpu & gpu parallelism.  Updating output_predictions
    to use this class as a wrapper to a TokenizationDataset means the
    calculation of features can happen in parallel, saving quite a
    bit of time.
    c                    s,   t    || _t| jjtd\| _| _d S )N)key)r2   r3   datasetr   r>   r[   indices)r?   r   rK   r   r   r3     s   
zSortedDataset.__init__c                 C   r   r   )r[   r>   rU   r   r   r   r     r   zSortedDataset.__len__c                 C   s   | j | j| S r   )r   r   r>   )r?   indexr   r   r   __getitem__  s   zSortedDataset.__getitem__c                 C   s   t || jS r   )r   r   )r?   arrr   r   r   r     s   zSortedDataset.unsortc                 C   sD  t dd |D rtd|d d d jd }| jjd}tdd |D d	 }tjt	||f|tj
d
}tjt	||fdtjd
}tjt	|||ftjd
}g }t|D ]D\}	}
|
d \}}}}t|||	d t	|f< t|||	d t	|f< t|||	d t	|d d f< ||dg|t	|    qW||||fS )Nc                 s   s    | ]	}t |d kV  qdS rM   r[   r0   r   r   r   rN     r   z(SortedDataset.collate.<locals>.<genexpr>z:Expected all paragraphs to have no preset sentence splits!r   r   r   r   c                 s   s     | ]}t |d  d V  qdS )r   r   Nr  r0   r   r   r   rN     s    r   r   )anyr   r   r   r7   rx   r   r   r   r[   r   r   r   r   r   r   r   )r?   samplesr   r   r   r   rJ   r   r   r   r   r   r   r   r   r   r   r   collate  s     zSortedDataset.collate)
r   r   r   r   r3   r   r  r   r  r   r   r   rK   r   r     s    	r   )bisectr   r   numpyrP   r   loggingrer   torch.utils.datar   r7   r   stanza.models.common.utilsr   r   	getLoggerr   r   compiler;   rt   r+   r   r   r   r   r   r   r   <module>   s(    



 8 +