o
    h                    @   s  d Z ddlmZ ddlZddlmZ ddlZddlZddlZddl	Z	ddl
mZ ddlmZ ddlmZ ddlmZ ddlmZmZ dd	lmZmZ dd
lmZ ddlmZ ddlmZ ddlmZ ddl m!Z! ddl"m#Z# ddl$m%Z% ddl&m'Z' ddl(m)Z) ddl*m+Z+m,Z, e-dZ.e-dZ/edddgZ0edg dZ1G dd deZ2G dd deZ3G dd  d eZ4G d!d" d"eej5Z6dS )#aJ  
A version of the BaseModel which uses LSTMs to predict the correct next transition
based on the current known state.

The primary purpose of this class is to implement the prediction of the next
transition, which is done by concatenating the output of an LSTM operated over
previous transitions, the words, and the partially built constituents.

A complete processing of a sentence is as follows:
  1) Run the input words through an encoder.
     The encoder includes some or all of the following:
       pretrained word embedding
       finetuned word embedding for training set words - "delta_embedding"
       POS tag embedding
       pretrained charlm representation
       BERT or similar large language model representation
       attention transformer over the previous inputs
       labeled attention transformer over the first attention layer
     The encoded input is then put through a bi-lstm, giving a word representation
  2) Transitions are put in an embedding, and transitions already used are tracked
     in an LSTM
  3) Constituents already built are also processed in an LSTM
  4) Every transition is chosen by taking the output of the current word position,
     the transition LSTM, and the constituent LSTM, and classifying the next
     transition
  5) Transitions are repeated (with constraints) until the sentence is completed
    )
namedtupleN)Enum)pack_padded_sequence)extract_bert_embeddingsMaxoutLinear)attach_bert_modelunsort)PAD_IDUNK_ID)	BaseModel)LabelAttentionModule)LSTMTreeStack)TransitionScheme)Tree)PartitionedTransformerModule)ConcatSinusoidalEncoding)TransformerTreeStack)	TreeStack)build_nonlinearityinitialize_linearstanzazstanza.constituency.trainerWordNodevaluehxConstituent)r   tree_hxtree_cxc                   @   s   e Zd ZdZdZdZdS )SentenceBoundary         N)__name__
__module____qualname__NONEWORDS
EVERYTHING r(   r(   `/var/www/html/env_mimamsha/lib/python3.10/site-packages/stanza/models/constituency/lstm_model.pyr   D   s    r   c                   @   s   e Zd ZdZdZdS )StackHistoryr   r    N)r"   r#   r$   LSTMATTNr(   r(   r(   r)   r*   I   s    r*   c                   @   s4   e Zd ZdZdZdZdZdZdZdZ	dZ
d	Zd
ZdS )ConstituencyCompositionr   r    r!                  	   
   N)r"   r#   r$   BILSTMMAX	TREE_LSTM
BILSTM_MAXBIGRAMr,   TREE_LSTM_CX
UNTIED_MAXKEY
UNTIED_KEYr(   r(   r(   r)   r-      s    r-   c                       s  e Zd Z fddZedd Zedd Zdd Zd	d
 Zdd Z	e
dd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd  Zd!d" Zd#d$ Zd%d& Zd'd( Zd)d* Zd+d, Zd-d. Zd/d0 Zd1d2 Zd<d4d5Zd6d7 Zd8d9 Z d<d:d;Z!  Z"S )=	LSTMModelc                    sp  t  j|d ||dd|d || _g | _|j}| dtjj	|dd dd	 t
|jD | _| d
tjtt|jdd |jd | _|jd | _tt|	| _| jd | _| jdtj| _| jtjtjtjfv r| jd | _| j| j dkr| j| j | j| j  | _|d t jkr| jd | _| j|d  dkr| j|d  t|d   | _| jtjkr| j| j dkrt!d| jd | _"|d t jkr| j"|d  dkrt#$dt"|d  | j"|d  | j"|d   | _"| jd | _%| jd | _&| jd | _'| j| j% | j' | _(|dur&| d| |  j(| j)* 7  _(|j+s%t!dnd| _)|durG| d| |  j(| j,* 7  _(|j+rFt!dnd| _,tt-|| _.d d	 t
| j.D | _/t0dkscJ t1dksjJ tjt| j.d! | j'dd"| _2tj3j4| j2j5d#d$ | d%tjtt| j.d! dd t-|| _6tt|
| _7| j%dkrd&d	 t
| j7D | _8tjt|
d! | j%dd"| _9tj3j4| j9j5d'd$ | d(tjtt| j7d! dd | jd) | _:| jd* | _;| jd+ | _<t=| jd, | _>t=| jd- | _?t=| jd. | _@| d/tA| j| j;  | d0tA| j:d| j | jd1 | _B| jBtCjDurT| Ed2tjFd3tjG| j(dd  | Ed4tjFd3tjG| j(dd  |pa| jd5 pa| jd6 | _HtI| ||| jd7d| jH || _J|dur|du rt!d8| jKjLj| _M|d9 r|d9 |jLjNkr|jLjNd |d9< tjO|d9 ddd:| _Ptj3Q| jPj5 nd| _P| j(| jM | _(d| _Rd| _StTU| jr| jd; d! d! | _StV| jd< | jS| jd= | jd> | jd? | jd@ | jdA | jdB | j(| jdC | jdD | jdE | jdF dG| _R|  j(| jS7  _(d| _WtTX| jr| jRdu r&t#YdH n[| jdI r1| j(| _Zn| jS| _Zt[| jZ| jdJ | jdK | jdK | jdL | jdM | jdN | jdO | jdP | jdQ | jdR | jSd! | jdS | jdT | jdU | _W| j(| jdM | jdL   | _(tj\| j(| j| j:d| j<dV| _]tO| jd! | j| j; | _^t_| j^| jdW | jd!  tt|| _`dXd	 t
| j`D | _a| dYtjtt|dd tjt|| j&dZ| _btj3j4| jbj5d'd$ |d t j\krtc| j&| j"| j:| j<| jBtCjdu | j@d[| _en!|d t jkrtf| j&| j"| j@d|d d\| _en	t!d]g|d tt|| _hd^d	 t
| jhD | _itjt| ji| jdZ| _jtj3j4| jjj5d3d$ |d t j\krdtc| j| j| j:| j<| jBtCjdu | j@d[| _kn!|d t jkr|tf| j| j| j@d|d d\| _kn	t!d_g|d |d` r| jj| _lntjt| ji| jdZ| _ltj3j4| jlj5d3d$ | datjtt|dd | jtjks| jtjmkrtj\| j| j| j:d| j<dV| _n| jtjkrtO| jd! | j| _ot_| jo| jdW | jd!  ntO| j| j| _ptO| j| j| _qt_| jp| jdW | j t_| jq| jdW | j n|| jtjrkr9tO| j| j| _ot_| jo| jdW | j n_| jtjskr| EdbtjFtjGt|| j| jdd | EdctjFtjGt|| jdd tt|D ]}tj3jt| ju| | jdW dd qntj3v| jwdd| jd! de   n| jtjxkrtO| j| j| _otO| jd! | j| _yt_| jo| jdW | j t_| jy| jdW | j n| jtjkrtz| j| j| _{n| jtjks| jtjkrT| jdf r| dft|| jdf dg n| dft}  tjO| j| jdf  | jdd:| _~tO| j| jdf  | j| _| jtjkr:| EdhtjFtjG| j| j| j ddd n^| EdhtjFtjGt|| j| j| j ddd nD| jtjkrjtj\| j| j| j;| j<di| _nn.| jtjkrtjt|
d! | j;| j dZ| _tj\| j| j| j;| j<di| _nnt!djg| jt| jdW | _| jdkd| _| | jdl t|| j| _dS )ma  
        pretrain: a Pretrain object
        transitions: a list of all possible transitions which will be
          used to build trees
        constituents: a list of all possible constituents in the treebank
        tags: a list of all possible tags in the treebank
        words: a list of all known words, used for a delta word embedding.
          note that there will be an attempt made to learn UNK words as well,
          and tags by themselves may help UNK words
        rare_words: a list of rare words, used to occasionally replace with UNK
        root_labels: probably ROOT, although apparently some treebanks like TOP or even s
        constituent_opens: a list of all possible open nodes which will go on the stack
          - this might be different from constituents if there are nodes
            which represent multiple constituents at once
        args: hidden_size, transition_hidden_size, etc as gotten from
          constituency_parser.py

        Note that it might look like a hassle to pass all of this in
        when it can be collected directly from the trees themselves.
        However, that would only work at train time.  At eval or
        pipeline time we will load the lists from the saved model.
        transition_schemereversedF)r?   unary_limitreverse_sentenceroot_labels	embeddingT)freezec                 S   s   i | ]\}}| d d|qS )     )replace.0iwordr(   r(   r)   
<dictcomp>   s    z&LSTMModel.__init__.<locals>.<dictcomp>vocab_tensorsrequires_gradr   r   hidden_sizeconstituency_compositionreduce_headsconstituent_stackconstituent_headsz6--reduce_heads and --constituent_heads not compatible!transition_hidden_sizetransition_stacktransition_headszEtransition_hidden_size %d %% transition_heads %d != 0.  reconfiguringtag_embedding_dimtransition_embedding_dimdelta_embedding_dimNforward_charlmz*Got a backward charlm as a forward charlm!backward_charlmz*Got a forward charlm as a backward charlm!c                 S      i | ]	\}}||d  qS r    r(   rI   r(   r(   r)   rM   +      r    )num_embeddingsembedding_dimpadding_idxg?)stddelta_tensorsc                 S   r^   r_   r(   rJ   rK   tr(   r(   r)   rM   B  r`   g      ?tag_tensorsnum_lstm_layersnum_tree_lstm_layerslstm_layer_dropoutword_dropoutpredict_dropoutlstm_input_dropout
word_zerosconstituent_zerossentence_boundary_vectorsword_start_embeddingg?word_end_embeddingbert_finetunestage1_bert_finetuneuse_peftz,Cannot have a bert model without a tokenizerbert_hidden_layers)biaspattn_d_modelpattn_num_layerspattn_num_heads
pattn_d_kv
pattn_d_ffpattn_relu_dropoutpattn_residual_dropoutpattn_attention_dropout
pattn_biaspattn_morpho_emb_dropoutpattn_timingpattn_encoder_max_len)d_modeln_headd_qkvd_ff
ff_dropoutresidual_dropoutattention_dropoutword_input_sizerx   morpho_emb_dropouttimingencoder_max_lenzLNot using Labeled Attention, as the Partitioned Attention module is not usedlattn_combined_inputlattn_d_input_proj
lattn_d_kv	lattn_d_llattn_d_projlattn_combine_as_selflattn_resdroplattn_q_as_matrixlattn_residual_dropoutlattn_attention_dropout
lattn_d_fflattn_relu_dropoutlattn_partitioned)
input_sizerQ   
num_layersbidirectionaldropoutnonlinearityc                 S      i | ]\}}||qS r(   r(   rf   r(   r(   r)   rM         transition_tensors)ra   rb   )r   rQ   ri   r   uses_boundary_vectorinput_dropout)r   output_sizer   use_position	num_headsz+Unhandled transition_stack StackHistory: {}c                 S   r   r(   r(   )rJ   rK   xr(   r(   r)   rM     r   z,Unhandled constituent_stack StackHistory: {}combined_dummy_embeddingconstituent_open_tensorsreduce_linear_weightreduce_linear_bias)r   g      ?reduce_position2   
reduce_key)r   rQ   r   r   %Unhandled ConstituencyComposition: {}maxout_knum_output_layers)super__init__getargsunsaved_modulesembadd_unsaved_modulenn	Embeddingfrom_pretrained	enumeratevocab	vocab_mapregister_buffertorchtensorrangelenshape
vocab_sizerb   sortedlistconstituentsrQ   r-   r5   rR   r,   r<   r=   rS   r*   
ValueErrorrV   loggerwarningrY   rZ   r[   r   r\   
hidden_dimis_forward_lmr]   setdelta_wordsdelta_word_mapr
   r   delta_embeddinginitnormal_weight
rare_wordstagstag_maptag_embeddingri   rj   rk   Dropoutrl   rm   rn   zerosrq   r   r%   register_parameter	Parameterrandnforce_bert_savedr   	peft_name
bert_modelconfigbert_dimnum_hidden_layersLinearbert_layer_mixzeros_partitioned_transformer_modulery   r>   
uses_pattnr   label_attention_module
uses_lattnerrorlattn_d_inputr   r+   	word_lstmword_to_constituentr   transitionstransition_maptransition_embeddingr   r'   rW   r   formatconstituent_opensconstituent_open_mapconstituent_open_embeddingrT   dummy_embeddingr8   constituent_reduce_lstmreduce_linearreduce_forwardreduce_backwardr6   r;   kaiming_normal_r   uniform_r   r9   reduce_bigramMultiheadAttentionreduce_attnr   Identityreduce_queryreduce_valuer7   r:   constituent_reduce_embeddingr   r   r   build_output_layersoutput_layers)selfpretrainr\   r]   r   bert_tokenizerr   r   r   r   r   wordsr   rC   r   rA   r   
emb_matrix	layer_idx	__class__r(   r)   r      s    

$
$$$












*& $ .4
 zLSTMModel.__init__c                 C   s,   |  ddo|  dddko|  dddkS )N	use_lattnTr   r   r   )r   r   r(   r(   r)   r   8  s   ,zLSTMModel.uses_lattnc                 C   s   | d dko| d dkS )Nr{   r   rz   r(   r  r(   r(   r)   r   <  s   zLSTMModel.uses_pattnc           	      C   sR  | j |j kr| j tjkrtd| j |j | D ]\}}|drT| j tjkrT|dkr1| j}n|dkr9| j}ntd|t	t
| jD ]}|| j|j qGq|dr| |}t|jjd |jjd }t|j}|jdd	|f |dd	|f< |j| qz| |j|j W q ty } ztd
| |d	}~ww d	S )a\  
        Copy parameters from the other model to this model

        word_lstm can change size if the other model didn't use pattn / lattn and this one does.
        In that case, the new values are initialized to 0.
        This will rebuild the model in such a way that the outputs will be
        exactly the same as the previous model.
        zbModels are incompatible: self.constituency_composition == {}, other.constituency_composition == {}zreduce_linear.zreduce_linear.weightzreduce_linear.biasz"Unexpected other parameter name {}zword_lstm.weight_ih_l0.NzCould not process %s)rR   r-   r;   r   r   named_parameters
startswithr   r   r   r   r   datacopy_get_parameterminr   r   
zeros_likeAttributeError)	r	  othernameother_parametermy_parameteridx	copy_size
new_valueser(   r(   r)   copy_with_new_structure@  s2   	

z!LSTMModel.copy_with_new_structurec           
         s   |d }| j | j | j  | j g| j g|  }| j g| |g } sDtdd t||D }t||D ]\}}	t|| jd |	 q4|S t fddt||D }|S )a  
        Build a ModuleList of Linear transformations for the given num_output_layers

        The final layer size can be specified.
        Initial layer size is the combination of word, constituent, and transition vectors
        Middle layer sizes are self.hidden_size
        r   c                 S   s   g | ]
\}}t ||qS r(   )r   r   rJ   r   r   r(   r(   r)   
<listcomp>t      z1LSTMModel.build_output_layers.<locals>.<listcomp>r   c                    s   g | ]
\}}t || qS r(   r   r%  r   r(   r)   r&  y  r'  )rQ   rj   rV   r   
ModuleListzipr   r   )
r	  r   final_layer_sizer   middle_layerspredict_input_sizepredict_output_sizer  output_layerr   r(   r(  r)   r  e  s   &
zLSTMModel.build_output_layersc                    s   t  fdd|D S )Nc                 3   s(    | ]}| j v p|  j v V  qd S N)r   lowerrJ   rL   r	  r(   r)   	<genexpr>~  s   & z,LSTMModel.num_words_known.<locals>.<genexpr>)sum)r	  r  r(   r3  r)   num_words_known}  s   zLSTMModel.num_words_knownc                 C   s
   | j d S )Nretag_methodr  r3  r(   r(   r)   r7    s   
zLSTMModel.retag_methodc                 C   s   | j d d uo| j d dkS )Nretag_packager7  xposr  r3  r(   r(   r)   	uses_xpos  s   zLSTMModel.uses_xposc                 C   sP   |  j |g7  _ t| || |dur"|dv r$| D ]\}}d|_qdS dS dS )z
        Adds a module which will not be saved to disk

        Best used for large models such as pretrained word embeddings
        N)r\   r]   F)r   setattrr  rP   )r	  r  module_	parameterr(   r(   r)   r     s   zLSTMModel.add_unsaved_modulec                 C   s   | dd | jv S )N.r   )splitr   )r	  r  r(   r(   r)   is_unsaved_module  s   zLSTMModel.is_unsaved_modulec              
      s6  g }t   | jtjkr8ddh |d t| jD ]\}}|d|t| j	| 
 t| j| 
 f  q fdd|  D }t|dkrK|S tt| tdd	 |D }td
d	 |D }dt| d t| d }|D ]%\}}	t|	 dk 
 }
dt|	
  }|||||
|	 f  qs|S )Nr   r   zreduce_linear:z  %s weight %.6g bias %.6gc                    s&   g | ]\}}|j r| vr||fqS r(   rO   rJ   r  paramskipr(   r)   r&       & z'LSTMModel.get_norms.<locals>.<listcomp>r   c                 s   s    | ]	\}}t |V  qd S r0  r   rB  r(   r(   r)   r4    s    z&LSTMModel.get_norms.<locals>.<genexpr>c                 s   s*    | ]\}}t d t|  V  qdS )%.6gN)r   r   normitemrB  r(   r(   r)   r4    s   ( z%-z
s   norm %zs  zeros %d / %dgư>rH  )r   rR   r-   r;   appendr   r   r   rI  r   rJ  r   r  r   printmaxstrr5  absnelement)r	  linesc_idxc_openactive_paramsmax_name_lenmax_norm_lenformat_stringr  rC  r   rI  r(   rD  r)   	get_norms  s&   
6zLSTMModel.get_normsc                 C   s(   dg}| |   td| d S )NNORMS FOR MODEL PARAMETERS
)extendrX  r   infojoin)r	  rQ  r(   r(   r)   	log_norms  s   zLSTMModel.log_normsc                 C   sF   dg}|   D ]\}}|jr|d||j qtd| d S )NrY  z{} {}rZ  )r  rP   rK  r   r   r   r\  r]  )r	  rQ  r  rC  r(   r(   r)   
log_shapes  s   zLSTMModel.log_shapesc                    sT  t  j}jfdd g }dd |D }t|D ]r\}}|| }t fdd|D }|}	jrDfdd|D }
n|}
tfdd|
D }	|}|	|g}j
dkrjrmfd	d|D }nd
d |D }tfdd|D }|}|| || qjdurj|}t||D ]	\}}|| qjdurĈj|}t||D ]	\}}|| qdd |D }jtjurjdjdfdd|D }jdur2tjd jj||jtjujdurjjndjd  ojd  jd	}jdur(fdd|D }dd t||D }jdurHd|}dd t||D }jdurkjd r[||}n||}dd t||D }fdd|D }tj j!j"j#|dd}$|\}}tj j!j"%|\}}g }t|D ]\}}jtjur|dt&|d |ddf n|dt&||ddf '(jtjurt)ddddf g}|fddt|D 7 }|t)dt&|d ddf  nt)dj*g}|fddt|D 7 }|t)dj* j+r!t,t-|}|| q|S )z
        Produce initial word queues out of the model's LSTMs for use in the tagged word lists.

        Operates in a batched fashion to reduce the runtime for the LSTM operations
        c                    s(     | d }|d ur|S   |  tS r0  )r   r1  r   )rL   r   )r   r(   r)   map_word  s   z/LSTMModel.initial_word_queues.<locals>.map_wordc                 S      g | ]	}d d |D qS )c                 S   s   g | ]}|j d  jqS r   )childrenlabelr2  r(   r(   r)   r&        z<LSTMModel.initial_word_queues.<locals>.<listcomp>.<listcomp>r(   )rJ   tagged_wordsr(   r(   r)   r&    s    z1LSTMModel.initial_word_queues.<locals>.<listcomp>c                    s"   g | ]}j  |jd  j qS rb  )rN   rc  rd  r2  )r`  r	  r(   r)   r&       " c                    s0   g | ]}| j v rt  jd  k rdn|qS )rare_word_unknown_frequencyN)r   randomr   r2  r3  r(   r)   r&    s    (c                        g | ]} j  j|t qS r(   )re   r   r   r   r2  r3  r(   r)   r&         r   c                    s(   g | ]}t    jd  k rdn|jqS )tag_unknown_frequencyN)ri  r   rd  r2  r3  r(   r)   r&       ( c                 S      g | ]}|j qS r(   )rd  r2  r(   r(   r)   r&        c                    rj  r(   )rh   r   r   r   )rJ   tagr3  r(   r)   r&    rk  Nc                 S   s   g | ]	}t j|d dqS )r   dimr   catrJ   word_inputsr(   r(   r)   r&    r`   c                    s    g | ]}t j| gd dqS r   rq  rs  ru  )word_end
word_startr(   r)   r&    rk  r   rt   ru   )keep_endpointsr   detachr   c                    s0   g | ]}  |d |jd d j j  qS )r    axis)r   squeezer5  in_features)rJ   featurer3  r(   r)   r&    s   0 c                 S   s"   g | ]\}}t j||fd dqS )r   r|  rs  rJ   r   yr(   r(   r)   r&    rg  c                 S   8   g | ]\}}t j||d |jd d d f fddqS Nr   r   r|  r   rt  r   r  r(   r(   r)   r&  
     8 r   c                 S   r  r  r  r  r(   r(   r)   r&    r  c                       g | ]}  |qS r(   )rl   ru  r3  r(   r)   r&    r   Fenforce_sortedr    c                    s*   g | ]\}}t | |d  ddf qS r   Nr   rJ   r   tag_nodesentence_outputr(   r)   r&  +  s    r   c                    s&   g | ]\}}t | |d d f qS r0  r  r  r  r(   r)   r&  0  s    ).next
parametersdevicer   r   r   stackrD   trainingr   rY   r   rK  r\   build_char_representationr*  r]   rq   r   r%   rr   	unsqueezers   r   r   r   r  r   r  r   r   r   r   utilsrnnpack_sequencer   pad_packed_sequencer   r   r   r   ro   rB   r   r@   )r	  tagged_word_listsr  all_word_inputsall_word_labelssentence_idxrf  word_labelsword_idx
word_inputdelta_labels	delta_idxdelta_inputrv  
tag_labelstag_idx	tag_inputall_forward_charsforward_charsall_backward_charsbackward_charsbert_embeddingspartitioned_embeddingslabeled_representationspacked_word_inputword_outputr=  word_output_lensword_queues
word_queuer(   )r`  r	  r  r   rx  ry  r)   initial_word_queues  s   








 

&zLSTMModel.initial_word_queuesc                 C   s
   | j  S )zA
        Return an initial TreeStack with no transitions
        )rW   initial_stater3  r(   r(   r)   initial_transitions:  s   
zLSTMModel.initial_transitionsc                 C   s   | j td| j| jS )zB
        Return an initial TreeStack with no constituents
        N)rT   r  r   rp   r3  r(   r(   r)   initial_constituents@  s   zLSTMModel.initial_constituentsc                 C   s   |j S r0  r   )r	  	word_noder(   r(   r)   get_wordF  s   zLSTMModel.get_wordc                 C   s   | |j}|j}| jtjkr#t||j| j	| j
| j| j	| j
S | jtjkrT|j}|j| j	| j
}| j| j|t }| |}|| j	| j
}t|||| S t||jd | j
 dd S Nr   )r  word_positionr   rR   r-   r7   r   r   viewrj   rQ   ro   r:   rd  rh   r   r   r   r  r  )r	  stater  rL   rp  r   
tag_tensorr   r(   r(   r)   transform_word_to_constituentI  s   (
z'LSTMModel.transform_word_to_constituentc                 C   s2   |j }| j| j|  }| |}t||dd S r  )rd  r   r   r   r   r  )r	  dummyrd  
open_indexr   r(   r(   r)   dummy_constituentY  s   
zLSTMModel.dummy_constituentc                     s  j tjksj tjkrdd |D }fdd|D }tdd |D tjj|d jdfddt	||D }fd	d|D }tj
|d
d}tjjjj|dd |D dd}| j tjkr d
 d   dddddf } dddddf }tj||fd
d}	nDtjjj d \ }
 fddt	tt|
|
D  tj
dd  D dd  dddjf  ddjdf  }	|	d}d}n:j tjkrdd |D }fdd|D }tj
|d
d}|}	|	}d}nj tjkrNdd |D }fdd|D }fdd|D }fddt	||D }	tj
|	dd}	|	d}	|	}d}n͈j tjkrdd |D }g }|D ]H}tj|dd}|jd d
krtj|ddddf |d
dddf fd
d}|d }tj||fdd}|t|dj  q`tj
|ddd}|}	|	}d}nZj tj!krdd |D }fdd|D }dd |D }dd t	||D }fd d|D }fd!d|D }tj
|dd}	|	d}d}nj tj"ksj tj#krd"d |D }fd#d|D }fd$d|D }fd%d|D }j tj"krOfd&d|D }nfd'd|D }fd(dt	||D }d)d |D }fd*d|D }fd+d|D }d,d t	||D }d-d |D }tj
|ddd}	|	}d}n{j tj$tj%fv rfd.d|D }t
|d}td/d |D d0d |D }fd1d|D }d2d |D }tj
d3d |D d
d}d4d |D }d5d |D }d6d t	||D }tj
|d
d}|||f\}\}}nt&d7'j g }t(t	||D ]M\}\}}d8d |D }t)|t*r>t+||d9}nt,|D ]}t+||d9}|}qB|t-||dd|ddf |durl|dd|ddf nd q$|S ):a2  
        Build new constituents with the given label from the list of children

        labels is a list of labels for each of the new nodes to construct
        children_lists is a list of children that go under each of the new nodes
        lists of each are used so that we can stack operations
        c                 S   ra  )c                 S   s   g | ]	}|j jd qS rb  )r   r   r~  rJ   childr(   r(   r)   r&  l  r`   ;LSTMModel.build_constituents.<locals>.<listcomp>.<listcomp>r(   rJ   rc  r(   r(   r)   r&  l  r`   z0LSTMModel.build_constituents.<locals>.<listcomp>c                    "   g | ]}   j j|  qS r(   r   r   r   rJ   rd  r3  r(   r)   r&  m  rg  c                 s       | ]}t |V  qd S r0  rG  r  r(   r(   r)   r4  o      z/LSTMModel.build_constituents.<locals>.<genexpr>r   r  c                    s2   g | ]\}}|g| |g g t |   qS r(   rG  rJ   lhxnhx)
max_lengthr   r(   r)   r&  r  s   2 c                       g | ]
}  t|qS r(   rn   r   r  rJ   r  r3  r(   r)   r&  s      r   r|  c                 S   s   g | ]}t |d  qS r_   rG  rJ   r   r(   r(   r)   r&  u  re  Fr  Nr  c                    s*   g | ]\}} d |d  |ddf qS r  r(   )rJ   r   length)lstm_outputr(   r)   r&       * c                 S   s   g | ]	}t |d jqS rb  )r   rM  valuesr  r(   r(   r)   r&    r`   c                 S   ra  )c                 S      g | ]}|j jqS r(   r   r   r  r(   r(   r)   r&        r  r(   r  r(   r(   r)   r&    r`   c              	      &   g | ]}  tt|d jqS rb  rn   r   rM  r  r  r  r3  r(   r)   r&    rF  c                 S   ra  )c                 S   r  r(   r  r  r(   r(   r)   r&    r  r  r(   r  r(   r(   r)   r&    r`   c              	      r  rb  r  r  r3  r(   r)   r&    rF  c                       g | ]} j | qS r(   r   r  r3  r(   r)   r&    r   c                    s2   g | ]\}}t  j| |d  j|  qS rb  )r   matmulr   r~  r   )rJ   	label_idxhx_layerr3  r(   r)   r&    s    &c                 S   ra  )c                 S   r  r(   r  r  r(   r(   r)   r&    r  r  r(   r  r(   r(   r)   r&    r`   r    c                 S   ra  )c                 S   r  r(   r  r  r(   r(   r)   r&    r  r  r(   r  r(   r(   r)   r&    r`   c                    r  r(   r  r  r3  r(   r)   r&    rg  c                 S   s   g | ]}t |qS r(   r   r  r  r(   r(   r)   r&    r   c                 S   s.   g | ]\}}t j|d d |fd dqS )r   r|  )r   rt  r  r  r(   r(   r)   r&    s   . c                    s$   g | ]}  |||d  dqS )r   r   )r  r~  r  r3  r(   r)   r&       $ c                    s    g | ]}  t|d jqS rb  )rn   r   rM  r  r  r3  r(   r)   r&    rk  c                 S       g | ]}t d d |D qS )c                 S   r  r(   r  r  r(   r(   r)   r&    r  r  r  r  r(   r(   r)   r&    rk  c                    s$   g | ]}  ||jd  dqS )r   r  )r   reshaper   r  r3  r(   r)   r&    r  c                    r  r(   )r  r  r3  r(   r)   r&    r   c                    *   g | ]}| |jd   jdd dqS r   r  r   r  r   rS   	transposer  r3  r(   r)   r&    r  c                    s   g | ]	}t | jqS r(   r   r  r   r  r3  r(   r)   r&    r`   c                    r  r(   r  r  r3  r(   r)   r&    r   c                    s"   g | ]\}}t | j| qS r(   r  )rJ   r  r  r3  r(   r)   r&    rg  c                 S   s&   g | ]}t jjj|d dd dqS )r   rq  r    )r   r   
functionalsoftmaxr  r  r(   r(   r)   r&    rF  c                    r  r(   )r  r  r3  r(   r)   r&    r   c                    r  r  r  r  r3  r(   r)   r&    r  c                 S   s"   g | ]\}}t ||d qS )r   )r   r  r~  )rJ   r   r  r(   r(   r)   r&    rg  c                 S   s   g | ]}| d qS )r  )r  r  r(   r(   r)   r&    r   c              	      s(   g | ]}    j j|  qS r(   )rn   r   r   r   r  r3  r(   r)   r&    rm  c                 s   r  r0  rG  r  r(   r(   r)   r4    r  c                 S   ra  )c                 S   r  r(   r  r  r(   r(   r)   r&    r  r  r(   r  r(   r(   r)   r&    r`   c                    r  r(   r  r  r3  r(   r)   r&    r  c                 S   s   g | ]}|j d dqS rw  )rM  r  r(   r(   r)   r&    re  c                 S   rn  r(   )r  r  r(   r(   r)   r&    ro  c                 S   r  )c                 S   r  r(   )r   r   r  r(   r(   r)   r&    r  r  r  r  r(   r(   r)   r&    rk  c                 S   s   g | ]}|j d qS rb  )indicesr  )rJ   uhxr(   r(   r)   r&    re  c                 S   s"   g | ]\}}| d |d qS rb  )gatherr~  )rJ   ncxncir(   r(   r)   r&    rg  r   c                 S   s   g | ]}|j j qS r(   r  r  r(   r(   r)   r&    r  )rd  rc  ).rR   r-   r5   r8   rM  r   r   rQ   r  r*  r  r   r  r  r   r   r   rt  r  r   r   r   r   r   r  r6   r;   r9   rn   r   r   rK  r  r,   r<   r=   r7   r:   r   r   r   
isinstancerN  r   r@   r   ) r	  labelschildren_listsnode_hxlabel_hxunpacked_hx	packed_hx
forward_hxbackward_hxr   lstm_lengthslstm_hxlstm_cxlabel_indicesr  stacked_nhx	bigram_hxquery_hxqueriesweightsvalue_hxnode_cxnode_cx_indicesunpacked_cx	packed_cxr=  r   r   rd  rc  noder   r(   )r  r  r	  r   r)   build_constituents`  s   
 
 8


	

2


HzLSTMModel.build_constituentsc                 C   s6   dd |D }t jdd |D dd}| j|||S )Nc                 S   rn  r(   r  )rJ   r  r(   r(   r)   r&    ro  z/LSTMModel.push_constituents.<locals>.<listcomp>c                 S   s   g | ]	}|j d d qS )r  N)r   r  r(   r(   r)   r&    r`   r   r|  )r   r  rT   push_states)r	  constituent_stacksr   current_nodesconstituent_inputr(   r(   r)   push_constituents  s   
zLSTMModel.push_constituentsc                 C   s
   |j j j S )z
        Extract only the top constituent from a state's constituent
        sequence, even though it has multiple addition pieces of
        information
        r  )r	  r   r(   r(   r)   get_top_constituent  s   
zLSTMModel.get_top_constituentc                    s8   t  fdd|D } |d} j|||S )z
        Push all of the given transitions on to the stack as a batch operations.

        Significantly faster than doing one transition at a time.
        c                    s   g | ]
} j  j|  qS r(   )r   r   )rJ   
transitionr3  r(   r)   r&    r  z.LSTMModel.push_transitions.<locals>.<listcomp>r   )r   r  r   r  rW   r  )r	  transition_stacksr   transition_idxtransition_inputr(   r3  r)   push_transitions  s   zLSTMModel.push_transitionsc                 C   s   |j j S )z
        Extract only the top transition from a state's transition
        sequence, even though it has multiple addition pieces of
        information
        r  )r	  r   r(   r(   r)   get_top_transition  s   zLSTMModel.get_top_transitionc                    s   t dd |D }t  fdd|D }t  fdd|D }t j|||fdd}t jD ]\}} |} jsK|t jd k rK |}||}q1|S )a  
        Return logits for a prediction of what transition to make next

        We've basically done all the work analyzing the state as
        part of applying the transitions, so this method is very simple

        return shape: (num_states, num_transitions)
        c                 S   s   g | ]	}| |jjqS r(   )r  r  r   rJ   r  r(   r(   r)   r&  +  r`   z%LSTMModel.forward.<locals>.<listcomp>c                       g | ]	} j |jqS r(   )rW   outputr   r  r3  r(   r)   r&  ,  r`   c                    r  r(   )rT   r  r   r  r3  r(   r)   r&  2  r`   r   r|  )	r   r  rt  r   r  rm   r   r   r   )r	  statesword_hxtransition_hxconstituent_hxr   r   r/  r(   r3  r)   forward"  s   	


zLSTMModel.forwardTc                    s    |}tj|dd tj| ddd}     fddtt|D }|rvt	t
||D ]A\}\}}||su||ddf jdd\}	}
|
D ]}j| |rlj| ||< |||f ||<  n	qPd||< d||< q4|||dfS )a5  
        Generate and return predictions, along with the transitions those predictions represent

        If is_legal is set to True, will only return legal transitions.
        This means returning None if there are no legal transitions.
        Hopefully the constraints prevent that from happening
        r   rq  c                    s   g | ]	}j  |  qS r(   )r   rJ   r   pred_maxr	  r(   r)   r&  I  r`   z%LSTMModel.predict.<locals>.<listcomp>NT)
descending)r"  r   argmaxtake_along_dimr  r{  cpur   r   r   r*  is_legalsortr   r~  )r	  r  r*  predictionsscores
pred_transr   r  transr=  r  indexr(   r$  r)   predict<  s&   
zLSTMModel.predictc           	         s     |}g }g }t||D ]C\} fddt|jd D }t|dkr-|d q|| }tj|dd}t|d}|| }| j	|  |||  qt
|}|||fS )z
        Generate and return predictions, and randomly choose a prediction weighted by the scores

        TODO: pass in a temperature
        c                    s"   g | ]} j |  r|qS r(   )r   r*  r#  r	  r  r(   r)   r&  c  rg  z-LSTMModel.weighted_choice.<locals>.<listcomp>r   Nrq  r   )r"  r*  r   r   r   rK  r   r  multinomialr   r  )	r	  r  r,  r.  
all_scores
prediction	legal_idxr-  r   r(   r2  r)   weighted_choiceY  s    



zLSTMModel.weighted_choicec                    s\     |}dd |D }tj fdd|D |jd}tj||ddd}|||dfS )zK
        For each State, return the next item in the gold_sequence
        c                 S   s   g | ]}|j |j qS r(   )gold_sequencenum_transitions)rJ   r  r(   r(   r)   r&  u  re  z*LSTMModel.predict_gold.<locals>.<listcomp>c                    r  r(   )r   )rJ   rg   r3  r(   r)   r&  v  r   r  r   rq  )r"  r   r   r  r(  r  r~  )r	  r  r,  r   r  r-  r(   r3  r)   predict_goldp  s
   
zLSTMModel.predict_goldc                    s      }|r fdd| D }|D ]}||= qt j}|d j|d< |d j|d< |d j|d< |d j|d< |d j|d< t jtsJJ |d|d	d  j	D  j
 j jt j j j  d
}|S )z7
        Get a dictionary for saving the model
        c                    s   g | ]	}  |r|qS r(   )rA  )rJ   kr3  r(   r)   r&    r`   z(LSTMModel.get_params.<locals>.<listcomp>rq   rR   rW   rT   r?   r+   c                 S   s   g | ]}t |qS r(   )reprr  r(   r(   r)   r&    r  )model
model_typer   r   r   r   r  r   rC   r   rA   )
state_dictkeyscopydeepcopyr   r  r  r   r   r   r   r   r   r   rC   r   rA   )r	  skip_modulesmodel_stateskippedr;  r   paramsr(   r3  r)   
get_paramsz  s2   zLSTMModel.get_params)T)#r"   r#   r$   r   staticmethodr   r   r$  r  r6  propertyr7  r:  r   rA  rX  r^  r_  r  r  r  r  r  r  r  r  r  r  r  r"  r1  r7  r:  rG  __classcell__r(   r(   r  r)   r>      sJ      c

%
  	
	

r>   )7__doc__collectionsr   rA  enumr   loggingmathri  r   torch.nnr   torch.nn.utils.rnnr   #stanza.models.common.bert_embeddingr   "stanza.models.common.maxout_linearr   stanza.models.common.utilsr   r	   stanza.models.common.vocabr
   r   %stanza.models.constituency.base_modelr   *stanza.models.constituency.label_attentionr   *stanza.models.constituency.lstm_tree_stackr   ,stanza.models.constituency.parse_transitionsr   %stanza.models.constituency.parse_treer   2stanza.models.constituency.partitioned_transformerr   .stanza.models.constituency.positional_encodingr   1stanza.models.constituency.transformer_tree_stackr   %stanza.models.constituency.tree_stackr    stanza.models.constituency.utilsr   r   	getLoggerr   tloggerr   r   r   r*   r-   Moduler>   r(   r(   r(   r)   <module>   sB    

 