Skip to content

Commit

Permalink
Add Vocabulary class for word to index mapping (and reverse).
Browse files Browse the repository at this point in the history
  • Loading branch information
Karan Desai committed Dec 22, 2018
1 parent dfe0590 commit 184c92e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ MANIFEST

# Datasets, pretrained models and preprocessed files
data/
!visdialch/data/

# Checkpoints directory
checkpoints/
Expand Down
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

FROM pytorch/pytorch:0.4.1-cuda9-cudnn7-devel

RUN apt-get install libglib2.0-0 libsm6 libxrender1 libxext6
RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxrender1 libxext6

RUN pip install --upgrade pip
RUN pip install --upgrade pip && pip install cython
RUN git clone --depth 1 https://www.github.com/batra-mlp-lab/visdial-challenge-starter-pytorch /workspace && \
pip install -r /workspace/requirements.txt

Expand Down
84 changes: 84 additions & 0 deletions visdialch/data/vocabulary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
A Vocabulary maintains a mapping between words and corresponding unique integers, holds special
integers (tokens) for indicating start and end of sequence, and offers functionality to map
out-of-vocabulary words to the corresponding token.
"""
import json
import os


class Vocabulary(object):
"""
A simple Vocabulary class which maintains a mapping between words and integer tokens. Can be
initialized either by word counts from the VisDial v1.0 train dataset, or a pre-saved
vocabulary mapping.
Parameters
----------
word_counts_path: str
Path to a json file containing counts of each word across captions, questions and answers
of the VisDial v1.0 train dataset.
min_count : int, optional (default=0)
When initializing the vocabulary from word counts, you can specify a minimum count, and
every token with a count less than this will be excluded from vocabulary.
"""

PAD_TOKEN = "<PAD>"
SOS_TOKEN = "<S>"
EOS_TOKEN = "</S>"
UNK_TOKEN = "<UNK>"

PAD_INDEX = 0
SOS_INDEX = 1
EOS_INDEX = 2
UNK_INDEX = 3

def __init__(self, word_counts_path: str, min_count: int = 5):
if not os.path.exists(word_counts_path):
raise FileNotFoundError(f"Word counts do not exist at {word_counts_path}")

with open(word_counts_path, "r") as word_counts_file:
word_counts = json.load(word_counts_file)

# form a list of (word, count) tuples and apply min_count threshold
word_counts = [(word, count) for word, count in word_counts.items()
if count >= min_count]
# sort in descending order of word counts
word_counts = sorted(word_counts, key=lambda wc: -wc[1])
words = [w[0] for w in word_counts]

self.word2index = {}
self.word2index[self.PAD_TOKEN] = self.PAD_INDEX
self.word2index[self.SOS_TOKEN] = self.SOS_INDEX
self.word2index[self.EOS_TOKEN] = self.EOS_INDEX
self.word2index[self.UNK_TOKEN] = self.UNK_INDEX
for index, word in enumerate(words):
self.word2index[word] = index + 4

self.index2word = {index: word for word, index in self.word2index.items()}

@classmethod
def from_saved(cls, saved_vocabulary_path: str) -> 'Vocabulary':
"""Build the vocabulary from a json file saved by ``save`` method.
Parameters
----------
saved_vocabulary_path : str
Path to a json file containing word to integer mappings (saved vocabulary).
"""
with open(saved_vocabulary_path, "r") as saved_vocabulary_file:
cls.word2index = json.load(saved_vocabulary_file)
cls.index2word = {index: word for word, index in cls.word2index.items()}

def get_index_by_word(self, word: str) -> int:
return self.word2index.get(word, self.UNK_INDEX)

def get_word_by_index(self, index: int) -> str:
return self.index2word.get(index, self.UNK_TOKEN)

def save(self, save_vocabulary_path: str) -> None:
with open(save_vocabulary_path, "w") as save_vocabulary_file:
json.dump(self.word2index, saved_vocabulary_file)

def __len__(self):
return len(self.index2word)

0 comments on commit 184c92e

Please sign in to comment.