@@ -1,6 +1,7 @@
 # coding: utf-8
 # 2021/8/1 @ tongshiwei

+import torch
 import json
 import os.path
 from typing import List, Tuple
@@ -59,12 +60,12 @@ class I2V(object):
     """

     def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None,
-                 pretrained_t2v=False, model_dir=MODEL_DIR, **kwargs):
+                 pretrained_t2v=False, model_dir=MODEL_DIR, device='cpu', **kwargs):
         if pretrained_t2v:
             logger.info("Use pretrained t2v model %s" % t2v)
-            self.t2v = get_t2v_pretrained_model(t2v, model_dir)
+            self.t2v = get_t2v_pretrained_model(t2v, model_dir, device)
         else:
-            self.t2v = T2V(t2v, *args, **kwargs)
+            self.t2v = T2V(t2v, device=device, *args, **kwargs)
         if tokenizer == 'bert':
             self.tokenizer = BertTokenizer.from_pretrained(
                 **tokenizer_kwargs if tokenizer_kwargs is not None else {})
@@ -82,31 +83,53 @@ def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None,
                                                       **tokenizer_kwargs if tokenizer_kwargs is not None else {})
         self.params = {
             "tokenizer": tokenizer,
-            "tokenizer_kwargs": tokenizer_kwargs,
             "t2v": t2v,
             "args": args,
+            "tokenizer_kwargs": tokenizer_kwargs,
+            "pretrained_t2v": pretrained_t2v,
+            "model_dir": model_dir,
             "kwargs": kwargs,
-            "pretrained_t2v": pretrained_t2v
         }
+        self.device = torch.device(device)

     def __call__(self, items, *args, **kwargs):
         """transfer item to vector"""
         return self.infer_vector(items, *args, **kwargs)

     def tokenize(self, items, *args, key=lambda x: x, **kwargs) -> list:
-        # """tokenize item"""
+        """
+        Tokenize the input items.
+        Parameters
+        ----------
+        items: a list of questions
+        Returns
+        -------
+        tokens: list
+        """
         return self.tokenizer(items, *args, key=key, **kwargs)

     def infer_vector(self, items, key=lambda x: x, **kwargs) -> tuple:
+        """
+        Get the question embedding.
+        Not implemented in the base class.
+        """
         raise NotImplementedError

     def infer_item_vector(self, tokens, *args, **kwargs) -> ...:
+        """Return the item-level vector (the first element returned by infer_vector)."""
         return self.infer_vector(tokens, *args, **kwargs)[0]

     def infer_token_vector(self, tokens, *args, **kwargs) -> ...:
+        """Return the token-level vectors (the second element returned by infer_vector)."""
         return self.infer_vector(tokens, *args, **kwargs)[1]

     def save(self, config_path):
+        """
+        Save the model configuration (self.params) as JSON to config_path.
+        Parameters
+        ----------
+        config_path: str
+        """
         with open(config_path, "w", encoding="utf-8") as wf:
             json.dump(self.params, wf, ensure_ascii=False, indent=2)

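A note on the device handling added in this hunk: `__init__` now normalizes the argument with `torch.device`, so plain strings such as `'cpu'` or `'cuda:0'` are stored as proper device objects. A minimal standalone sketch of that normalization (pure PyTorch, independent of this module):

    import torch

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dev = torch.device(device)   # what __init__ now stores as self.device
    print(dev.type, dev.index)   # e.g. 'cuda' 0, or 'cpu' None

Because `self.params` now also records `pretrained_t2v` and `model_dir`, the JSON written by `save()` carries everything needed to rebuild the wrapper.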
@@ -123,6 +146,7 @@ def load(cls, config_path, *args, **kwargs):

     @classmethod
     def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+        """NotImplemented"""
         raise NotImplementedError

     @property
@@ -327,13 +351,13 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         return self.t2v.infer_vector(inputs, *args, **kwargs), self.t2v.infer_tokens(inputs, *args, **kwargs)

     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {"tokenizer_config_dir": model_path}
-        return cls("elmo", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("elmo", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)


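The same device plumbing recurs in the Bert, DisenQ and QuesNet classmethods below; one hedged sketch of the intended call pattern (the pretrained name, cache directory and item text are placeholders, and the `EduNLP.I2V` export is assumed, not confirmed by this diff):

    import torch
    from EduNLP.I2V import Elmo  # assumed public export; not confirmed by this diff

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # 'elmo_public' and './downloads' are placeholders for a real pretrained name and cache dir.
    i2v = Elmo.from_pretrained('elmo_public', model_dir='./downloads', device=device)
    item_vec, token_vec = i2v(['If $x + y = 1$ and $x - y = 3$, find $x$.'])  # __call__ delegates to infer_vector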
@@ -386,17 +410,19 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         --------
         vector:list
         """
+        is_batch = isinstance(items, list)
+        items = items if is_batch else [items]
         inputs = self.tokenize(items, key=key, return_tensors=return_tensors)
         return self.t2v.infer_vector(inputs, *args, **kwargs), self.t2v.infer_tokens(inputs, *args, **kwargs)

     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {"tokenizer_config_dir": model_path}
-        return cls("bert", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("bert", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)


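With the `is_batch` guard added above, Bert's `infer_vector` accepts either a single item or a list and wraps a single item into a one-element batch before tokenizing. A small sketch of both call styles, assuming an `i2v` instance built as in the previous sketch, default keyword arguments, and invented question texts:

    single = 'What is the value of $1 + 1$?'
    batch = ['What is the value of $1 + 1$?', 'Solve $x^2 = 4$ for $x$.']

    i_vec, t_vec = i2v.infer_vector(single)   # wrapped into [single] by the is_batch guard
    i_vecs, t_vecs = i2v.infer_vector(batch)  # a list is passed through unchanged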
@@ -452,7 +478,7 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         return i_vec, t_vec

     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
@@ -461,7 +487,7 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, **kwargs):
         tokenizer_kwargs = {
             "tokenizer_config_dir": model_path,
         }
-        return cls("disenq", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("disenq", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs, **kwargs)


@@ -495,18 +521,20 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         token embeddings
         question embedding
         """
+        is_batch = isinstance(items, list)
+        items = items if is_batch else [items]
         encodes = self.tokenize(items, key=key, meta=meta, *args, **kwargs)
         return self.t2v.infer_vector(encodes), self.t2v.infer_tokens(encodes)

     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {
             "tokenizer_config_dir": model_path}
-        return cls("quesnet", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("quesnet", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)


@@ -520,7 +548,7 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
 }


-def get_pretrained_i2v(name, model_dir=MODEL_DIR):
+def get_pretrained_i2v(name, model_dir=MODEL_DIR, device='cpu'):
     """
     It is a good idea if you want to switch item to vector easily.

@@ -560,4 +588,4 @@ def get_pretrained_i2v(name, model_dir=MODEL_DIR):
         )
     _, t2v = get_pretrained_model_info(name)
     _class, *params = MODEL_MAP[t2v], name
-    return _class.from_pretrained(*params, model_dir=model_dir)
+    return _class.from_pretrained(*params, model_dir=model_dir, device=device)
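Finally, a hedged end-to-end sketch of the new device argument flowing through get_pretrained_i2v into the selected I2V subclass (the pretrained name, cache directory and item text are placeholders, not taken from this diff):

    import torch
    from EduNLP.I2V import get_pretrained_i2v  # assumed public export; not confirmed by this diff

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # 'bert_math' and './downloads' are placeholders for a real pretrained name and cache dir.
    i2v = get_pretrained_i2v('bert_math', model_dir='./downloads', device=device)
    item_vecs, token_vecs = i2v(['A triangle has sides $3$, $4$ and $5$; find its area.'])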