
Commit 948f627

Update data_utils.py: (1) finalize doc strings (2) move ".vector_cache" from main to _down.. for API

1 parent 5db5943

File tree: 3 files changed (+29, -22 lines)


docs/cli/nn.rst

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ If a model was trained before by this package, the training procedure can start
 
 To use your own word embeddings or vocabulary set, specify the following parameters:
 
-- **embed_file**: choose one of the pretrained embeddings: `glove.6B.50d`, `glove.6B.100d`, `glove.6B.200d`, `glove.6B.300d`, `glove.42B.300d`, `glove.840B.300d`, or specify the path to your word embeddings with each line containing a word followed by its vectors.
+- **embed_file**: choose one of the pretrained embeddings: `glove.6B.50d`, `glove.6B.100d`, `glove.6B.200d`, `glove.6B.300d`, `glove.42B.300d`, `glove.840B.300d`, or specify the path to your word embeddings with each line containing a word followed by its vectors. Example:
 
 .. code-block::
 
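The hunk above cuts off before the docs' own example, so here is a minimal sketch of the custom embed_file format the bullet describes. The tokens and values are invented for illustration; the dimension-inference rule mirrors the len(word_vectors[0].split()) - 1 logic in data_utils.py below.

# Each line of a custom embed_file is a token followed by its vector
# components, whitespace-separated. These two lines are made-up sample data.
sample_lines = [
    "the 0.418 0.24968 -0.41242 0.1217",
    "cat 0.23088 0.28283 0.6318 -0.59411",
]

# data_utils.py infers the embedding size from the first line:
embed_size = len(sample_lines[0].split()) - 1  # subtract 1 for the token itself
assert embed_size == 4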

libmultilabel/nn/data_utils.py

Lines changed: 27 additions & 19 deletions
@@ -23,7 +23,7 @@
 
 UNK = "<unk>"
 PAD = "<pad>"
-PRETRAINED_ALIASES = {
+GLOVE_WORD_EMBEDDING = {
     "glove.42B.300d",
     "glove.840B.300d",
     "glove.6B.50d",
@@ -164,6 +164,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
     Args:
         data (Union[str, pandas.DataFrame]): Training, test, or validation data in file or dataframe.
         is_test (bool, optional): Whether the data is for test or not. Defaults to False.
+        tokenize_text (bool, optional): Whether to tokenize text. Defaults to True.
         remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
             This is effective only when is_test=False. Defaults to False.
 
@@ -281,7 +282,7 @@ def load_or_build_text_dict(
         dataset (list): List of training instances with index, label, and tokenized text.
         vocab_file (str, optional): Path to a file holding vocabularies. Defaults to None.
         min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
-        embed_file (str): Path to a file holding pre-trained embeddings.
+        embed_file (str, optional): Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding. Defaults to None.
         embed_cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
         silent (bool, optional): Enable silent mode. Defaults to False.
         normalize_embed (bool, optional): Whether the embedding of each word is normalized to a unit vector. Defaults to False.
@@ -319,13 +320,13 @@
 
 
 def _build_word_dict(vocab_list, min_vocab_freq=1, specials=None):
-    r"""Build word dictionary, modified from `torchtext.vocab.build-vocab-from-iterator`
+    r"""Build word dictionary, modified from `torchtext.vocab.build_vocab_from_iterator`
     (https://docs.pytorch.org/text/stable/vocab.html#build-vocab-from-iterator)
 
     Args:
         vocab_list: List of words.
         min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
-        specials: Special tokens (e.g., <unk>, <pad>) to add.
+        specials: Special tokens (e.g., <unk>, <pad>) to add. Defaults to None.
 
     Returns:
         dict: A dictionary which maps tokens to indices.
@@ -339,7 +340,7 @@ def _build_word_dict(vocab_list, min_vocab_freq=1, specials=None):
     sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
     ordered_dict = OrderedDict(sorted_by_freq_tuples)
 
-    # add special tokens at the beginning
+    # add special tokens at the beginning
     tokens = specials or []
    for token, freq in ordered_dict.items():
        if freq >= min_vocab_freq:
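To make the hunk above concrete, a self-contained sketch of the same build procedure on an invented mini-corpus, with min_vocab_freq=2 so the once-seen word is dropped:

from collections import Counter, OrderedDict

vocab_list = ["the", "cat", "the", "dog", "the", "cat"]
counter = Counter(vocab_list)

# Sort by descending frequency, then alphabetically, as in _build_word_dict.
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
ordered_dict = OrderedDict(sorted_by_freq_tuples)

tokens = ["<unk>", "<pad>"]  # special tokens go first
for token, freq in ordered_dict.items():
    if freq >= 2:  # min_vocab_freq=2 drops "dog"
        tokens.append(token)

word_dict = {token: i for i, token in enumerate(tokens)}
print(word_dict)  # {'<unk>': 0, '<pad>': 1, 'the': 2, 'cat': 3}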
@@ -388,27 +389,29 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False):
 
 
 def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache_dir=None):
-    """If the word exists in the embedding file, load the pretrained word embedding.
-    Otherwise, assign a zero vector to that word.
+    """Obtain the word embeddings from file. If a word exists in the embedding file,
+    load its pretrained word embedding. Otherwise, assign a zero vector to that word.
+    If the given `embed_file` is the name of a pretrained GloVe embedding, the function
+    will first download the corresponding file.
 
     Args:
         word_dict (dict): A dictionary for mapping tokens to indices.
-        embed_file (str): Path to a file holding pre-trained embeddings.
+        embed_file (str): Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding.
         silent (bool, optional): Enable silent mode. Defaults to False.
         cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
 
     Returns:
         torch.Tensor: Embedding weights (vocab_size, embed_size).
     """
 
-    if embed_file in PRETRAINED_ALIASES:
-        embed_file = _download_pretrained_embedding(embed_file, cache_dir=cache_dir)
+    if embed_file in GLOVE_WORD_EMBEDDING:
+        embed_file = _download_glove_embedding(embed_file, cache_dir=cache_dir)
     elif not os.path.isfile(embed_file):
         raise ValueError(
-            "Got embed_file {}, but allowed pretrained embeddings are {}".format(embed_file, PRETRAINED_ALIASES)
+            "Got embed_file {}, but allowed pretrained embeddings are {}".format(embed_file, GLOVE_WORD_EMBEDDING)
         )
 
-    logging.info(f"Load pretrained embedding from file: {embed_file}.")
+    logging.info(f"Load pretrained embeddings from {embed_file}.")
     with open(embed_file) as f:
         word_vectors = f.readlines()
     embed_size = len(word_vectors[0].split()) - 1
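For API users (the audience this commit targets), a hypothetical call sketch. The import path follows the file path above, and the word_dict contents are invented:

from libmultilabel.nn.data_utils import get_embedding_weights_from_file

word_dict = {"<unk>": 0, "<pad>": 1, "the": 2, "cat": 3}

# A name from GLOVE_WORD_EMBEDDING is downloaded on first use and cached.
weights = get_embedding_weights_from_file(word_dict, "glove.6B.50d", cache_dir=".vector_cache")
print(weights.shape)  # expected torch.Size([4, 50]): (vocab_size, embed_size)

# A local path is read directly; any other string raises ValueError.
# weights = get_embedding_weights_from_file(word_dict, "my_embeddings.txt")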
@@ -433,31 +436,36 @@ def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache_d
             embedding_weights[word_dict[word]] = vector_dict[word]
             vec_counts += 1
 
-    logging.info(f"loaded {vec_counts}/{len(word_dict)} word embeddings")
+    logging.info(f"Loaded {vec_counts}/{len(word_dict)} word embeddings.")
 
     return embedding_weights
 
 
-def _download_pretrained_embedding(embed_file, cache_dir=None):
+def _download_glove_embedding(embed_name, cache_dir=None):
     """Download pretrained GloVe embeddings from https://huggingface.co/stanfordnlp/glove/tree/main.
 
+    Args:
+        embed_name (str): The name of the pretrained GloVe embedding.
+        cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
+
     Returns:
-        str: Path to the cached or downloaded embedding file.
+        str: Path to the file that contains the cached embeddings.
     """
-    cached_embed_file = f"{cache_dir}/{embed_file}.txt"
+    cache_dir = ".vector_cache" if cache_dir is None else cache_dir
+    cached_embed_file = f"{cache_dir}/{embed_name}.txt"
     if os.path.isfile(cached_embed_file):
         return cached_embed_file
     os.makedirs(cache_dir, exist_ok=True)
 
-    remote_embed_file = re.sub(r"6B.*", "6B", embed_file) + ".zip"
+    remote_embed_file = re.sub(r"6B.*", "6B", embed_name) + ".zip"
     url = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{remote_embed_file}"
-    logging.info(f"Downloading pretrained embedding from {url}.")
+    logging.info(f"Downloading pretrained embeddings from {url}.")
     try:
         zip_file, _ = urlretrieve(url, f"{cache_dir}/{remote_embed_file}")
         with zipfile.ZipFile(zip_file, "r") as zf:
             zf.extractall(cache_dir)
     except Exception as e:
         os.remove(zip_file)
         raise e
-
+    logging.info(f"Downloaded pretrained embeddings {embed_name} to {cached_embed_file}.")
     return cached_embed_file
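One detail worth spelling out: on the Hugging Face mirror, every glove.6B.* variant ships in a single glove.6B.zip, which is what the re.sub call accounts for. A quick check of the name-to-archive mapping:

import re

for name in ["glove.6B.50d", "glove.6B.300d", "glove.42B.300d", "glove.840B.300d"]:
    print(re.sub(r"6B.*", "6B", name) + ".zip")
# glove.6B.zip
# glove.6B.zip
# glove.42B.300d.zip
# glove.840B.300d.zip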

main.py

Lines changed: 1 addition & 2 deletions
@@ -141,7 +141,7 @@ def add_all_arguments(parser):
     # pretrained vocab / embeddings
     parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabularies (default: %(default)s)")
     parser.add_argument(
-        "--embed_file", type=str, help="Path to a file holding pre-trained embeddings (default: %(default)s)"
+        "--embed_file", type=str, help="Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding (default: %(default)s)"
     )
     parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")
 

@@ -189,7 +189,6 @@ def add_all_arguments(parser):
     parser.add_argument(
         "--embed_cache_dir",
         type=str,
-        default=".vector_cache",
         help="For parameter search only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)",
     )
     parser.add_argument(
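Net effect of removing the argparse default: --embed_cache_dir now arrives as None unless explicitly set, and _download_glove_embedding supplies the ".vector_cache" fallback itself. CLI behavior should therefore be unchanged, while library users who call get_embedding_weights_from_file directly get the same caching default without going through main.py, which appears to be the "for API" part of the commit message.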
