-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Vocabulary class for word to index mapping (and reverse).
- Loading branch information
Karan Desai
committed
Dec 22, 2018
1 parent
dfe0590
commit 184c92e
Showing
3 changed files
with
87 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
""" | ||
A Vocabulary maintains a mapping between words and corresponding unique integers, holds special | ||
integers (tokens) for indicating start and end of sequence, and offers functionality to map | ||
out-of-vocabulary words to the corresponding token. | ||
""" | ||
import json | ||
import os | ||
|
||
|
||
class Vocabulary(object): | ||
""" | ||
A simple Vocabulary class which maintains a mapping between words and integer tokens. Can be | ||
initialized either by word counts from the VisDial v1.0 train dataset, or a pre-saved | ||
vocabulary mapping. | ||
Parameters | ||
---------- | ||
word_counts_path: str | ||
Path to a json file containing counts of each word across captions, questions and answers | ||
of the VisDial v1.0 train dataset. | ||
min_count : int, optional (default=0) | ||
When initializing the vocabulary from word counts, you can specify a minimum count, and | ||
every token with a count less than this will be excluded from vocabulary. | ||
""" | ||
|
||
PAD_TOKEN = "<PAD>" | ||
SOS_TOKEN = "<S>" | ||
EOS_TOKEN = "</S>" | ||
UNK_TOKEN = "<UNK>" | ||
|
||
PAD_INDEX = 0 | ||
SOS_INDEX = 1 | ||
EOS_INDEX = 2 | ||
UNK_INDEX = 3 | ||
|
||
def __init__(self, word_counts_path: str, min_count: int = 5): | ||
if not os.path.exists(word_counts_path): | ||
raise FileNotFoundError(f"Word counts do not exist at {word_counts_path}") | ||
|
||
with open(word_counts_path, "r") as word_counts_file: | ||
word_counts = json.load(word_counts_file) | ||
|
||
# form a list of (word, count) tuples and apply min_count threshold | ||
word_counts = [(word, count) for word, count in word_counts.items() | ||
if count >= min_count] | ||
# sort in descending order of word counts | ||
word_counts = sorted(word_counts, key=lambda wc: -wc[1]) | ||
words = [w[0] for w in word_counts] | ||
|
||
self.word2index = {} | ||
self.word2index[self.PAD_TOKEN] = self.PAD_INDEX | ||
self.word2index[self.SOS_TOKEN] = self.SOS_INDEX | ||
self.word2index[self.EOS_TOKEN] = self.EOS_INDEX | ||
self.word2index[self.UNK_TOKEN] = self.UNK_INDEX | ||
for index, word in enumerate(words): | ||
self.word2index[word] = index + 4 | ||
|
||
self.index2word = {index: word for word, index in self.word2index.items()} | ||
|
||
@classmethod | ||
def from_saved(cls, saved_vocabulary_path: str) -> 'Vocabulary': | ||
"""Build the vocabulary from a json file saved by ``save`` method. | ||
Parameters | ||
---------- | ||
saved_vocabulary_path : str | ||
Path to a json file containing word to integer mappings (saved vocabulary). | ||
""" | ||
with open(saved_vocabulary_path, "r") as saved_vocabulary_file: | ||
cls.word2index = json.load(saved_vocabulary_file) | ||
cls.index2word = {index: word for word, index in cls.word2index.items()} | ||
|
||
def get_index_by_word(self, word: str) -> int: | ||
return self.word2index.get(word, self.UNK_INDEX) | ||
|
||
def get_word_by_index(self, index: int) -> str: | ||
return self.index2word.get(index, self.UNK_TOKEN) | ||
|
||
def save(self, save_vocabulary_path: str) -> None: | ||
with open(save_vocabulary_path, "w") as save_vocabulary_file: | ||
json.dump(self.word2index, saved_vocabulary_file) | ||
|
||
def __len__(self): | ||
return len(self.index2word) |