 
 UNK = "<unk>"
 PAD = "<pad>"
-PRETRAINED_ALIASES = {
+GLOVE_WORD_EMBEDDING = {
     "glove.42B.300d",
     "glove.840B.300d",
     "glove.6B.50d",
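
This renamed constant is the gate that decides whether an `embed_file` argument is treated as a downloadable GloVe alias or as a local path. A minimal standalone sketch of that check, assuming the set simply holds alias strings (only the aliases visible in this hunk are listed here):

```python
# Hypothetical standalone sketch; the real set in the module contains more aliases.
GLOVE_WORD_EMBEDDING = {"glove.42B.300d", "glove.840B.300d", "glove.6B.50d"}

embed_file = "glove.6B.50d"
if embed_file in GLOVE_WORD_EMBEDDING:
    print(f"{embed_file} is a known GloVe alias; it would be downloaded.")
else:
    print(f"{embed_file} is treated as a local file path.")
```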
@@ -164,6 +164,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
     Args:
         data (Union[str, pandas.DataFrame]): Training, test, or validation data in file or dataframe.
         is_test (bool, optional): Whether the data is for test or not. Defaults to False.
+        tokenize_text (bool, optional): Whether to tokenize text. Defaults to True.
         remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
             This is effective only when is_test=False. Defaults to False.
@@ -281,7 +282,7 @@ def load_or_build_text_dict(
         dataset (list): List of training instances with index, label, and tokenized text.
         vocab_file (str, optional): Path to a file holding vocabularies. Defaults to None.
         min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
-        embed_file (str): Path to a file holding pre-trained embeddings.
+        embed_file (str): Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding. Defaults to None.
         embed_cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
         silent (bool, optional): Enable silent mode. Defaults to False.
         normalize_embed (bool, optional): Whether the embedding of each word is normalized to a unit vector. Defaults to False.
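
A call sketch reflecting the updated docstring: `embed_file` may now be either a path or a GloVe alias. Only the parameters documented in this hunk are used; that the function returns the word dictionary together with the embedding weights is an assumption, not shown here:

```python
# Assumed call pattern; "glove.6B.50d" is one of the aliases listed above,
# and .vector_cache is the default cache location introduced later in this commit.
train_dataset = [...]  # tokenized training instances, format per the docstring
word_dict, embed_vecs = load_or_build_text_dict(
    dataset=train_dataset,
    vocab_file=None,
    min_vocab_freq=1,
    embed_file="glove.6B.50d",
    embed_cache_dir=".vector_cache",
    silent=True,
    normalize_embed=False,
)
```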
@@ -319,13 +320,13 @@ def load_or_build_text_dict(
 
 
 def _build_word_dict(vocab_list, min_vocab_freq=1, specials=None):
-    r"""Build word dictionary, modified from `torchtext.vocab.build_vocab_from_iterator`
+    r"""Build word dictionary, modified from `torchtext.vocab.build_vocab_from_iterator`
     (https://docs.pytorch.org/text/stable/vocab.html#build-vocab-from-iterator)
 
     Args:
         vocab_list: List of words.
         min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
-        specials: Special tokens (e.g., <unk>, <pad>) to add.
+        specials: Special tokens (e.g., <unk>, <pad>) to add. Defaults to None.
 
     Returns:
         dict: A dictionary which maps tokens to indices.
@@ -339,7 +340,7 @@ def _build_word_dict(vocab_list, min_vocab_freq=1, specials=None):
     sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
     ordered_dict = OrderedDict(sorted_by_freq_tuples)
 
-    # add special tokens at the beginning
+    # add special tokens at the beginning
     tokens = specials or []
     for token, freq in ordered_dict.items():
         if freq >= min_vocab_freq:
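
A self-contained sketch of the dictionary-building technique shown in this hunk: count tokens, sort by descending frequency with alphabetical tie-breaking, place the special tokens first, then keep tokens meeting `min_vocab_freq`. Only the names visible in the hunk are reused; the surrounding lines are filled in as an assumption:

```python
from collections import Counter, OrderedDict

def build_word_dict_sketch(vocab_list, min_vocab_freq=1, specials=None):
    counter = Counter(vocab_list)
    # sort by frequency (descending), then token (ascending) for stable ties
    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    ordered_dict = OrderedDict(sorted_by_freq_tuples)

    # add special tokens at the beginning
    tokens = specials or []
    for token, freq in ordered_dict.items():
        if freq >= min_vocab_freq:
            tokens.append(token)
    return {token: i for i, token in enumerate(tokens)}

word_dict = build_word_dict_sketch(["b", "a", "b", "c"], min_vocab_freq=2, specials=["<unk>", "<pad>"])
print(word_dict)  # {'<unk>': 0, '<pad>': 1, 'b': 2}
```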
@@ -388,27 +389,29 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False):
 
 
 def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache_dir=None):
-    """If the word exists in the embedding file, load the pretrained word embedding.
-    Otherwise, assign a zero vector to that word.
+    """Obtain the word embeddings from file. If the word exists in the embedding file,
+    load the pretrained word embedding. Otherwise, assign a zero vector to that word.
+    If the given `embed_file` is the name of a pretrained GloVe embedding, the function
+    will first download the corresponding file.
 
     Args:
         word_dict (dict): A dictionary for mapping tokens to indices.
-        embed_file (str): Path to a file holding pre-trained embeddings.
+        embed_file (str): Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding.
         silent (bool, optional): Enable silent mode. Defaults to False.
         cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
 
     Returns:
         torch.Tensor: Embedding weights (vocab_size, embed_size).
     """
 
-    if embed_file in PRETRAINED_ALIASES:
-        embed_file = _download_pretrained_embedding(embed_file, cache_dir=cache_dir)
+    if embed_file in GLOVE_WORD_EMBEDDING:
+        embed_file = _download_glove_embedding(embed_file, cache_dir=cache_dir)
     elif not os.path.isfile(embed_file):
         raise ValueError(
-            "Got embed_file {}, but allowed pretrained embeddings are {}".format(embed_file, PRETRAINED_ALIASES)
+            "Got embed_file {}, but allowed pretrained embeddings are {}".format(embed_file, GLOVE_WORD_EMBEDDING)
         )
 
-    logging.info(f"Load pretrained embedding from file: {embed_file}.")
+    logging.info(f"Load pretrained embedding from {embed_file}.")
     with open(embed_file) as f:
         word_vectors = f.readlines()
     embed_size = len(word_vectors[0].split()) - 1
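
The `embed_size` line infers the dimension from the first row of the file: each line of a GloVe-style text file is a word followed by its vector components, so splitting and subtracting one gives the dimension. A minimal parsing sketch under that format assumption:

```python
# One line of a GloVe-style text file: "<word> v1 v2 ... vn"
line = "the 0.1 -0.2 0.3"
fields = line.split()
word, vector = fields[0], [float(v) for v in fields[1:]]
embed_size = len(fields) - 1  # same computation as in the diff
print(word, embed_size, vector)  # the 3 [0.1, -0.2, 0.3]
```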
@@ -433,31 +436,36 @@ def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache_d
         embedding_weights[word_dict[word]] = vector_dict[word]
         vec_counts += 1
 
-    logging.info(f"loaded {vec_counts}/{len(word_dict)} word embeddings")
+    logging.info(f"Loaded {vec_counts}/{len(word_dict)} word embeddings")
 
     return embedding_weights
 
 
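
A usage sketch of the updated function: passing an alias triggers the GloVe download path, while any other string must be an existing file. The tiny `word_dict` here is a made-up example:

```python
# Hypothetical vocabulary; indices must be contiguous from 0 so the
# (vocab_size, embed_size) weight tensor lines up with them.
word_dict = {"<unk>": 0, "<pad>": 1, "the": 2, "cat": 3}

# Alias form: downloads glove.6B.zip into .vector_cache on first use.
weights = get_embedding_weights_from_file(word_dict, "glove.6B.50d", cache_dir=".vector_cache")
print(weights.shape)  # torch.Size([4, 50])
```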
-def _download_pretrained_embedding(embed_file, cache_dir=None):
+def _download_glove_embedding(embed_name, cache_dir=None):
     """Download a pretrained GloVe embedding from https://huggingface.co/stanfordnlp/glove/tree/main.
 
+    Args:
+        embed_name (str): The name of the pretrained GloVe embedding.
+        cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
+
     Returns:
-        str: Path to the cached or downloaded embedding file.
+        str: Path to the file that contains the cached embeddings.
     """
-    cached_embed_file = f"{cache_dir}/{embed_file}.txt"
+    cache_dir = ".vector_cache" if cache_dir is None else cache_dir
+    cached_embed_file = f"{cache_dir}/{embed_name}.txt"
     if os.path.isfile(cached_embed_file):
         return cached_embed_file
     os.makedirs(cache_dir, exist_ok=True)
 
-    remote_embed_file = re.sub(r"6B.*", "6B", embed_file) + ".zip"
+    remote_embed_file = re.sub(r"6B.*", "6B", embed_name) + ".zip"
     url = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{remote_embed_file}"
-    logging.info(f"Downloading pretrained embedding from {url}.")
+    logging.info(f"Downloading pretrained embeddings from {url}.")
     try:
         zip_file, _ = urlretrieve(url, f"{cache_dir}/{remote_embed_file}")
         with zipfile.ZipFile(zip_file, "r") as zf:
             zf.extractall(cache_dir)
     except Exception as e:
         os.remove(zip_file)
         raise e
-
+    logging.info(f"Downloaded pretrained embeddings {embed_name} to {cached_embed_file}.")
     return cached_embed_file
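
The `re.sub(r"6B.*", "6B", ...)` line exists because all `glove.6B.*` dimensions ship in a single `glove.6B.zip` archive on the Hugging Face mirror, whereas the 42B and 840B variants each have their own archive. A quick standalone check of the mapping, using only the stdlib call from the diff:

```python
import re

for name in ["glove.6B.50d", "glove.6B.300d", "glove.42B.300d", "glove.840B.300d"]:
    remote = re.sub(r"6B.*", "6B", name) + ".zip"
    print(f"{name} -> https://huggingface.co/stanfordnlp/glove/resolve/main/{remote}")
# glove.6B.*   -> glove.6B.zip   (one archive for all 6B dimensions)
# glove.42B.*  -> glove.42B.300d.zip
# glove.840B.* -> glove.840B.300d.zip
```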