o
    hMS                     @   s   d Z ddlZddlZddlZddlZddlZddlmZ ddlm	Z	 ddl
mZ ddlmZmZ ddlmZ ddlmZ ddlmZmZ dd	lmZmZ ed
ZG dd dejZG dd deZdddZdddZe dkrue  dS dS )a  
Prototype of ensembling N models together on the same dataset

The main inference method is to run the normal transition sequence,
but sum the scores for the N models and use that to choose the highest
scoring transition

Example of how to run it to build a silver dataset
(or just parse a text file in general):

# first, use this tool to build a saved ensemble
python3 stanza/models/constituency/ensemble.py
   saved_models/constituency/wsj_inorder_?.pt
   --save_name saved_models/constituency/en_ensemble.pt

# then use the ensemble directly as a model in constituency_parser.py
python3 stanza/models/constituency_parser.py
   --save_name saved_models/constituency/en_ensemble.pt
   --mode parse_text
   --tokenized_file /nlp/scr/horatio/en_silver/en_split_100
   --predict_file /nlp/scr/horatio/en_silver/en_split_100.inorder.mrg
   --retag_package en_combined_bert
   --lang en

then, ideally, run a second time with a set of topdown models,
then take the trees which match from the files
    N)utils)FoundationCache)BaseTrainer	ModelType
MultiState)Trainer)build_optimizerbuild_scheduler)ParseResult
ScoredTreezstanza.constituency.trainerc                       s  e Zd Zd8 fdd	Zdd Zd9 fdd	Zed	d
 Zedd Zedd Z	dd Z
dd Zdd Zedd Zedd Zedd Zdd Zdd Zdd  Zd!d" Zd#d$ Zd%d& Zd'd( Zd)d* Zd+d, Zd9d-d.Zd:d0d1Zd2d3 Zd;d4d5Zd;d6d7Z  ZS )<EnsembleNc                    s.  t     | _|r3|rtddu rt t|tr|g}tdd	|  fdd|D }n|s9tdt
|| _t| jD ]\}}| jd  | krhtd	|d || | jd  | | jd j|jkrtd
|d  d||  d|d  d| jd j d||  d|j | jd j|jkrtd|d || f | jd j|jkrtd|d || f | jd  | krtd|d || f | jd j|jkrtd|d || f qD| jd j| _|   tdt| j | dtj
tjt| jt| jdd dS )z
        Loads each model in filenames

        If foundation_cache is None, we build one on our own,
        as the expectation is the models will reuse modules
        such as pretrain, charlm, bert
        z6both filenames and models set when making the EnsembleNzModels used for ensemble:
  %s
  c                    s    g | ]}t j| d djqS )F)load_optimizerfoundation_cache)r   loadmodel).0filenameargsr    ^/var/www/html/env_mimamsha/lib/python3.10/site-packages/stanza/models/constituency/ensemble.py
<listcomp>F        z%Ensemble.__init__.<locals>.<listcomp>z"filenames and models both not set!r   z,Models {} and {} are incompatible.  {} vs {}zModels z and z) are incompatible: different transitions
z:

z9Models %s and %s are incompatible: different constituentsz8Models %s and %s are incompatible: different root_labelsz6Models %s and %s are incompatible: different uses_xposz=Models %s and %s are incompatible: different reverse_sentencez$Number of models in the Ensemble: %dweighted_sumT)requires_grad)super__init__r   
ValueErrorr   
isinstancestrloggerinfojoinnn
ModuleListmodels	enumeratetransition_schemeformattransitionsconstituentsroot_labels	uses_xposreverse_sentence_reverse_sentencedetach_submodelsdebuglenregister_parametertorch	Parameterzeros)selfr   	filenamesr(   r   	model_idxr   	__class__r   r   r   1   s@   

*F0zEnsemble.__init__c                 C   s(   | j D ]}| D ]\}}d|_q	qd S )NF)r(   named_parametersr   )r9   r   _	parameterr   r   r   r2   b   s
   
zEnsemble.detach_submodelsTc                    s    t  | |r|   d S d S N)r   trainr2   )r9   moder<   r   r   rB   h   s   zEnsemble.trainc                 C      | j d jS Nr   )r(   r,   r9   r   r   r   r,   o      zEnsemble.transitionsc                 C   rD   rE   )r(   r.   rF   r   r   r   r.   s   rG   zEnsemble.root_labelsc                 C   s   t |  jS rA   )next
parametersdevicerF   r   r   r   rJ   w   s   zEnsemble.devicec                 C   s   t dd | jD S )zF
        Limit on the number of consecutive unary transitions
        c                 s   s    | ]}|  V  qd S rA   )unary_limitr   mr   r   r   	<genexpr>   s    z'Ensemble.unary_limit.<locals>.<genexpr>)minr(   rF   r   r   r   rK   {   s   zEnsemble.unary_limitc                 C      | j d  S rE   )r(   r*   rF   r   r   r   r*         zEnsemble.transition_schemec                 C   rP   rE   )r(   has_unary_transitionsrF   r   r   r   rR      rQ   zEnsemble.has_unary_transitionsc                 C   rD   rE   )r(   is_top_downrF   r   r   r   rS      rG   zEnsemble.is_top_downc                 C   s   | j S rA   )r1   rF   r   r   r   r0         zEnsemble.reverse_sentencec                 C   s   | j d jd S )Nr   retag_method)r(   r   rF   r   r   r   rU      s   zEnsemble.retag_methodc                 C   rP   rE   )r(   r/   rF   r   r   r   r/      rQ   zEnsemble.uses_xposc                 C      | j d |S rE   )r(   get_top_constituent)r9   r-   r   r   r   rW         zEnsemble.get_top_constituentc                 C   rV   rE   )r(   get_top_transition)r9   r,   r   r   r   rY      rX   zEnsemble.get_top_transitionc           	   	   C   s   dg}|   D ]-\}}|jr4|ds4t| dk  }dt|  }|d||||	 f  qt
| jD ]\}}| }t|dkrT|d|  || q:td| d S )	NNORMS FOR MODEL PARAMETERSmodels.gư>z%.6gz%s %s %d %dr   z  ---- MODEL %d ----r   )r>   r   
startswithr6   sumabsitemnormappendnelementr)   r(   	get_normsr4   extendr#   r$   r%   )	r9   linesnameparamr8   r`   r;   r   sublinesr   r   r   	log_norms   s   
zEnsemble.log_normsc                 C   sF   dg}|   D ]\}}|jr|d||j qtd| d S )NrZ   z{} {}r   )r>   r   ra   r+   shaper#   r$   r%   )r9   re   rf   rg   r   r   r   
log_shapes   s   zEnsemble.log_shapesc                 C   s0   |   }dd | D }|dd | jD dS )Nc                 S   s    i | ]\}}| d s||qS )r[   )r\   )r   kvr   r   r   
<dictcomp>   r   z'Ensemble.get_params.<locals>.<dictcomp>c                 S   s   g | ]}|  qS r   )
get_paramsr   xr   r   r   r      s    z'Ensemble.get_params.<locals>.<listcomp>)base_paramschildren_params)
state_dictitemsr(   )r9   model_stater   r   r   ro      s
   zEnsemble.get_paramsc                    s>    fdd| j D }tt| }dd t| D }|S )Nc                    s   g | ]	}|  qS r   )initial_state_from_preterminalsr   r   gold_sequences
gold_treespreterminal_listsr   r   r          z<Ensemble.initial_state_from_preterminals.<locals>.<listcomp>c                 S   s    g | ]\}}}t |||d qS )        r   )r   states	gold_treegold_sequencer   r   r   r      s    )r(   listzip)r9   r|   r{   rz   state_batchr   ry   r   rw      s   
z(Ensemble.initial_state_from_preterminalsc                    p   g  t |D ]}t|d}|du r n | qt dkr6 fdd| jD  tt   dd  D   S )z
        Read from the data_iterator batch_size tagged sentences and turn them into new parsing states

        Expects a list of list of (word, tag)
        Nr   c                       g | ]}|  qS r   )initial_state_from_wordsrx   r   r   r   r          z:Ensemble.build_batch_from_tagged_words.<locals>.<listcomp>c                 S      g | ]	}t |d d dqS Nr~   r   r   r   r   r   r   r      r}   rangerH   ra   r4   r(   r   r   )r9   
batch_sizedata_iteratorr?   sentencer   r   r   build_batch_from_tagged_words   s   
z&Ensemble.build_batch_from_tagged_wordsc                    r   )zk
        Read from the data_iterator batch_size trees and turn them into N lists of parsing states
        Nr   c                    r   r   )initial_state_from_gold_treesrx   r   r   r   r      r   z3Ensemble.build_batch_from_trees.<locals>.<listcomp>c                 S   r   r   r   r   r   r   r   r      r}   r   )r9   r   r   r?   r   r   r   r   build_batch_from_trees   s   
zEnsemble.build_batch_from_treesc                    s`  t tdd |D  }dd t| j|D }tj|dd}td|| j}tj|dd| }| jd  tj|ddtj	|
ddd}   fd	dtt|d D }|rtt|d |D ]A\}\}}	|	| s||d d f jd
d\}
}|D ]} j| | r j| ||< |||f ||<  n	qd ||< d ||< qf|||dfS )Nc                 S      g | ]}|j qS r   r   rp   r   r   r   r          z$Ensemble.predict.<locals>.<listcomp>c                 S   s   g | ]	\}}| |qS r   )forward)r   r   r   r   r   r   r      r}      )dimz
BTM,MT->BTr      c                    s   g | ]	} j |  qS r   )r,   )r   idxr   pred_maxr   r   r      r}   T)
descending)r   r   r(   r6   stackeinsumr   r]   argmaxtake_along_dim	unsqueezedetachcpur   r4   r)   is_legalsortr,   squeeze)r9   r   r   predictionsflat_predictionsscores
pred_transr   statetransr?   indicesindexr   r   r   predict   s0   
 zEnsemble.predictFc                    sZ   g }t tdd |D  } fddt| j|D }t t| }dd t||D }|S )Nc                 S   r   r   r   rp   r   r   r   r   
  r   z'Ensemble.bulk_apply.<locals>.<listcomp>c                    s    g | ]\}}|j | d qS ))fail)
bulk_applyr   rq   yr   r,   r   r   r     r   c                 S   s   g | ]
\}}|j |d qS )r   )_replacer   r   r   r   r         )r   r   r(   )r9   r   r,   r   
new_statesr   r   r   r   r     s   zEnsemble.bulk_applyc                 C   sL   t dt| |   t|}| j|| j|| jddd}dd |D }|S )a  
        This parses tagged words and returns a list of trees.

        `parse_tagged_words` is useful at Pipeline time -
          it takes words & tags and processes that into trees.

        The tagged words should be represented:
          one list per sentence
            each sentence is a list of (word, tag)
        The return value is a list of ParseTree objects

        TODO: this really ought to be refactored with base_model
        zProcessing %d sentencesF)
keep_statekeep_constituentsc                 S   s   g | ]}|j d  jqS )r   )r   tree)r   tr   r   r   r   $  s    z/Ensemble.parse_tagged_words.<locals>.<listcomp>)r#   r3   r4   evaliterparse_sentences_no_gradr   r   )r9   wordsr   sentence_iteratortreebankresultsr   r   r   parse_tagged_words  s   zEnsemble.parse_tagged_wordsc                    s  g }g }	|||}
t tt|
}tg }|rtt }t|
dkr||
\}}}| |
|}
t  t|
D ]2\}}|| rg|	| }| j
rJ| }|j}|t|t|dgdd |	||   | q5t dkr fddt|
D }
 fddt|D }t|t|
 D ]/}t|d}|s|||}t|dkr nt|}t|d}|
| |t|t|
  qt|
dks!t||	}|S )a
  
        Repeat transitions to build a list of trees from the input batches.

        The data_iterator should be anything which returns the data for a parse task via next()
        build_batch_fn is a function that turns that data into State objects
        This will be called to generate batches of size batch_size until the data is exhausted

        The return is a list of tuples: (gold_tree, [(predicted, score) ...])
        gold_tree will be left blank if the data did not include gold trees
        currently score is always 1.0, but the interface may be expanded
        to get a score from the result of the parsing

        transition_choice: which method of the model to use for
        choosing the next transition

        TODO: refactor with base_model
        r   Nc                       g | ]
\}}| vr|qS r   r   )r   r   r   remover   r   r   X  r   z,Ensemble.parse_sentences.<locals>.<listcomp>c                    r   r   r   )r   r   	batch_idxr   r   r   r   Y  r   )r   r   r4   r   defaultdictr   setr)   finishedget_treer0   reverser   ra   r   r   addrH   r   unsort)r9   r   build_batch_fnr   transition_choicer   r   keep_scoresr   treebank_indicesr   batch_indiceshorizon_iteratorr-   pred_scoresr,   r   r   r   predicted_treer   r?   horizon_statehorizon_batchr   r   r   parse_sentences'  sJ   







#zEnsemble.parse_sentencesc              
   C   sB   t   | |||||||W  d    S 1 sw   Y  d S rA   )r6   no_gradr   )r9   r   r   r   r   r   r   r   r   r   r   r   j  s   
$z Ensemble.parse_sentences_no_grad)NNN)T)F)FFF)__name__
__module____qualname__r   r2   rB   propertyr,   r.   rJ   rK   r*   rR   rS   r0   rU   r/   rW   rY   ri   rk   ro   rw   r   r   r   r   r   r   r   __classcell__r   r   r<   r   r   0   sB    1





	

!	
Cr   c                       s~   e Zd ZdZd fdd	Zeddd	Zd
d Zedd Z	dd Z
edd Zedd Zedd ZedddZ  ZS )EnsembleTrainerzj
    Stores a list of constituency models, useful for combining their results into one stronger model
    Nr   r~   Fc	           	   
      s   t  |||||||| d S rA   )r   r   )	r9   ensemble	optimizer	schedulerepochs_trainedbatches_trainedbest_f1
best_epochfirst_optimizerr<   r   r   r   r  s   zEnsembleTrainer.__init__c                 C   s(   t | ||d}|| dd }t|S )N)r   rJ   )r   togetr   )r   r:   r   r   r   r   r   
from_filesu  s   zEnsembleTrainer.from_filesc                 C   sR   g }| j jD ] }|jddr!ddlm} |||j|jd q|d  q|S )Nuse_peftFr   )get_peft_model_state_dict)adapter_name)	r   r(   r   r   peftr   ra   
bert_model	peft_name)r9   paramsr   r   r   r   r   get_peft_params{  s   zEnsembleTrainer.get_peft_paramsc                 C   s   t jS rA   )r   ENSEMBLErF   r   r   r   
model_type  rT   zEnsembleTrainer.model_typec                    sl   fdd| j jD  t fdd D r#td d t d S tddfd	d D   d S )
Nc                    r   r   )num_words_knownrL   r   r   r   r     r   z7EnsembleTrainer.log_num_words_known.<locals>.<listcomp>c                 3   s    | ]	}| d  kV  qdS )r   Nr   rp   )nwkr   r   rN     s    z6EnsembleTrainer.log_num_words_known.<locals>.<genexpr>zINumber of words in the training set known to each sub-model: %d out of %dr   zANumber of words in the training set known to the sub-models:
  %sr   c                    s   g | ]
}d |t  f qS )z%d/%d)r4   rp   r   r   r   r     r   )r   r(   allr#   r$   r4   r%   )r9   r   r   )r   r   r   log_num_words_known  s   &z#EnsembleTrainer.log_num_words_knownc                    s,    fdd}t   }||_t| ||}|S )Nc                  3   s,       D ]\} }| ds| |fV  qd S )Nr[   )r>   r\   )npr   r   r   fake_named_parameters  s   

z>EnsembleTrainer.build_optimizer.<locals>.fake_named_parameters)copyr>   r	   )r   r   r   r  
fake_modelr   r   r   r   r	     s
   
zEnsembleTrainer.build_optimizerc              
   C   sp   t | jd j| |}|dd d ur1z
||d  W |S  ty0 } ztd| |d }~ww td |S )Nr   optimizer_state_dictz Failed to load optimizer from %sz`Attempted to load optimizer to resume training, but optimizer not saved.  Creating new optimizer)	r   r	   r(   r   r   load_state_dictr    r#   r$   )r   
checkpointr   r   r   er   r   r   r     s   
zEnsembleTrainer.load_optimizerc                 C   s0   t | jd j||d}d|v r||d  |S )Nr   )r   scheduler_state_dict)r
   r(   r   r  )r   r   r  r   r   r   r   r   load_scheduler  s   zEnsembleTrainer.load_schedulerc           	         s   t | tr	| d n| }t | tr| d ni }|d u r!d gt| }|d u r,d gt| }t|t|kr@tdt| t|f t|t|krTtdt| t|f  fddt|||D }t |d}|j|dd	 | d
d }|S )Nrs   rr   z9Model file had params length %d and peft params length %dz7Model file had params length %d and peft name length %dc              	      s&   g | ]\}}}t j|| |d qS ))r   )r   model_from_params)r   model_param
peft_parampnamer   r   r   r     s    z5EnsembleTrainer.model_from_params.<locals>.<listcomp>)r(   F)strictrJ   )	r!   dictr4   r    r   r   r  r   r   )	r   peft_paramsr   r   r   rs   rr   r(   r   r   r   r   r
    s"   
z!EnsembleTrainer.model_from_params)NNr   r   r~   r   FrA   )NN)r   r   r   __doc__r   staticmethodr   r   r   r   r   r	   r   r	  r
  r   r   r   r<   r   r   n  s"    



r   c                 C   s   t  }|jdtd dd |jdtd dd |jdtd dd t| |jdd	d
d |jdtdd dd |jdtd ddd t| } | S )Nz--charlm_forward_filez$Exact path to use for forward charlm)typedefaulthelpz--charlm_backward_filez%Exact path to use for backward charlmz--wordvec_pretrain_filez'Exact name of the pretrain file to readz--langenzLanguage to use)r  r  r(   +zWhich model(s) to load)r  nargsr  r  z--save_nameTz#Where to save the combined ensemble)r  r  requiredr  )argparseArgumentParseradd_argumentr"   r   add_device_argsvars
parse_args)r   parserr   r   r   r    s   
r  c                 C   s6   t | } t }t| | d |}|j| d dd d S )Nr(   	save_nameF)save_optimizer)r  r   r   r   save)r   r   r   r   r   r   main  s   r$  __main__rA   )!r  r  r  loggingosr6   torch.nnr&   stanza.models.commonr   %stanza.models.common.foundation_cacher   'stanza.models.constituency.base_trainerr   r    stanza.models.constituency.stater   "stanza.models.constituency.trainerr    stanza.models.constituency.utilsr	   r
   stanza.server.parser_evalr   r   	getLoggerr#   Moduler   r   r  r$  r   r   r   r   r   <module>   s0    
  @
]

