
Commit 00486b3

Merge pull request #405 from cmusphinx/arpabo-lm
Add a simple language model maker script
2 parents 69167fb + 07e206a

File tree

3 files changed, +388 -0 lines changed


cython/pocketsphinx/lm.py (+383 lines)
@@ -0,0 +1,383 @@
#!/usr/bin/env python

import argparse
import re
import sys
import unicodedata as ud
from collections import defaultdict
from datetime import date
from io import StringIO
from math import log
from typing import Any, Dict, Optional, TextIO

# Author: Kevin Lenzo
# Based on a Perl script by Alex Rudnicky


class ArpaBoLM:
    """
    A simple ARPA back-off language model builder
    """

    log10 = log(10.0)
    norm_exclude_categories = set(["P", "S", "C", "M", "Z"])

    def __init__(
        self,
        sentfile: Optional[TextIO] = None,
        text: Optional[str] = None,
        add_start: bool = False,
        word_file: Optional[str] = None,
        word_file_count: int = 1,
        discount_mass: Optional[float] = 0.5,
        case: Optional[str] = None,  # lower, upper
        norm: bool = False,
        verbose: bool = False,
    ):
        self.add_start = add_start
        self.word_file = word_file
        self.word_file_count = word_file_count
        self.discount_mass = discount_mass
        self.case = case
        self.norm = norm
        self.verbose = verbose

        # Log to stderr, so the model itself can go to stdout
        self.logfile = sys.stderr

        if self.verbose:
            print("Started", date.today(), file=self.logfile)

        if discount_mass is None:  # TODO: add other smoothing methods
            self.discount_mass = 0.5
        elif not 0.0 < discount_mass < 1.0:
            raise AttributeError(
                f"Discount value ({discount_mass}) out of range (0.0, 1.0)"
            )

        self.deflator: float = 1.0 - self.discount_mass

        self.sent_count = 0

        self.grams_1: Any = defaultdict(int)
        self.grams_2: Any = defaultdict(lambda: defaultdict(int))
        self.grams_3: Any = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

        self.sum_1: int = 0
        self.count_1: int = 0
        self.count_2: int = 0
        self.count_3: int = 0

        self.prob_1: Dict[str, float] = {}
        self.alpha_1: Dict[str, float] = {}
        self.prob_2: Any = defaultdict(lambda: defaultdict(float))
        self.alpha_2: Any = defaultdict(lambda: defaultdict(float))

        if sentfile is not None:
            self.read_corpus(sentfile)
        if text is not None:
            self.read_corpus(StringIO(text))

        if self.word_file is not None:
            self.read_word_file(self.word_file)

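    # Note that grams_1/2/3 are (nested) defaultdicts keyed by word:
    # self.grams_3["a"]["b"]["c"] is the count of the trigram "a b c",
    # and missing keys read as 0, so counting needs no existence checks.
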
    def read_word_file(self, path: str, count: Optional[int] = None) -> bool:
        """
        Read a file of words, one per line, and add any that are missing
        from the model, each with the given count (default 1)
        """
        if self.verbose:
            print("Reading word file:", path, file=self.logfile)

        if count is None:
            count = self.word_file_count

        new_word_count = token_count = 0
        with open(path) as words_file:
            for token in words_file:
                token = token.strip()
                if not token:
                    continue
                if self.case == "lower":
                    token = token.lower()
                elif self.case == "upper":
                    token = token.upper()
                if self.norm:
                    token = self.norm_token(token)
                    if not token:  # normalization can empty a token
                        continue
                token_count += 1
                # Here, we could just add one, bumping all the word counts;
                # or just add N for the missing ones. We do the latter.
                if token not in self.grams_1:
                    self.grams_1[token] = count
                    new_word_count += 1

        if self.verbose:
            print(
                f"{new_word_count} new unique words",
                f"from {token_count} tokens,",
                f"each with count {count}",
                file=self.logfile,
            )
        return True

    def norm_token(self, token: str) -> str:
        """
        Remove excluded leading and trailing character categories from a token
        """
        while (
            len(token) and ud.category(token[0])[0] in ArpaBoLM.norm_exclude_categories
        ):
            token = token[1:]
        while (
            len(token) and ud.category(token[-1])[0] in ArpaBoLM.norm_exclude_categories
        ):
            token = token[:-1]
        return token

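    # The excluded major categories are P (punctuation), S (symbols),
    # C (control/other), M (marks), and Z (separators); for example,
    # '"' and '!' fall under P, so norm_token('"hello!"') == 'hello'.
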
    def read_corpus(self, infile: TextIO) -> None:
        """
        Read in a text training corpus from a file handle
        """
        if self.verbose:
            print("Reading corpus file, breaking per newline.", file=self.logfile)

        sent_count = 0
        for line in infile:
            if self.case == "lower":
                line = line.lower()
            elif self.case == "upper":
                line = line.upper()
            line = line.strip()
            line = re.sub(
                r"(.+)\(.+\)$", r"\1", line
            )  # trailing file name in transcripts

            words = line.split()
            if self.add_start:
                words = ["<s>"] + words + ["</s>"]
            if self.norm:
                words = [self.norm_token(w) for w in words]
                words = [w for w in words if len(w)]
            if not words:
                continue
            sent_count += 1
            wc = len(words)
            for j in range(wc):
                w1 = words[j]
                self.grams_1[w1] += 1
                if j + 1 < wc:
                    w2 = words[j + 1]
                    self.grams_2[w1][w2] += 1
                    if j + 2 < wc:
                        w3 = words[j + 2]
                        self.grams_3[w1][w2][w3] += 1

        # Accumulate, so that several corpora can feed one model
        self.sent_count += sent_count
        if self.verbose:
            print(f"{sent_count} sentences", file=self.logfile)

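    # For example, a sphinxtrain-style transcript line
    #   hello world (utt001)
    # keeps only the words: the re.sub() above reduces it to
    # "hello world " and split() discards the leftover whitespace.
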
    def compute(self) -> bool:
        """
        Compute the derived values (probabilities and back-off weights).

        If an n-gram is not present, the back-off is

            P( word_N | word_{N-1}, word_{N-2}, ...., word_1 ) =
                P( word_N | word_{N-1}, word_{N-2}, ...., word_2 )
                    * backoff-weight( word_{N-1} | word_{N-2}, ...., word_1 )

        If the sequence

            ( word_{N-1}, word_{N-2}, ...., word_1 )

        is also not listed, then the term

            backoff-weight( word_{N-1} | word_{N-2}, ...., word_1 )

        gets replaced with 1.0 and the recursion continues.
        """
        if not self.grams_1:
            sys.exit("No input?")

        # token counts
        self.sum_1 = sum(self.grams_1.values())

        # type counts
        self.count_1 = len(self.grams_1)
        for w1, gram2 in self.grams_2.items():
            self.count_2 += len(gram2)
            for w2 in gram2:
                self.count_3 += len(self.grams_3[w1][w2])

        # unigram probabilities
        for gram1, count in self.grams_1.items():
            self.prob_1[gram1] = count * self.deflator / self.sum_1

        # unigram alphas (back-off weights)
        for w1 in self.grams_1:
            sum_denom = 0.0
            for w2, count in self.grams_2[w1].items():
                sum_denom += self.prob_1[w2]
            self.alpha_1[w1] = self.discount_mass / (1.0 - sum_denom)

        # bigram probabilities
        for w1, grams2 in self.grams_2.items():
            for w2, count in grams2.items():
                self.prob_2[w1][w2] = count * self.deflator / self.grams_1[w1]

        # bigram alphas (back-off weights)
        for w1, grams2 in self.grams_2.items():
            for w2, count in grams2.items():
                sum_denom = 0.0
                for w3 in self.grams_3[w1][w2]:
                    sum_denom += self.prob_2[w2][w3]
                self.alpha_2[w1][w2] = self.discount_mass / (1.0 - sum_denom)
        return True

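    # A worked example of the fixed discounting in compute(): with
    # discount_mass = 0.5 and a 4-token corpus in which count("hello") == 2,
    #   P(hello) = 2 * (1 - 0.5) / 4 = 0.25
    # and if the unigram probabilities of the observed successors of
    # "hello" sum to 0.4, then
    #   alpha(hello) = 0.5 / (1 - 0.4) = 0.8333...
    # so an unseen bigram "hello w" backs off to alpha(hello) * P(w).
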
    def write_file(self, out_path: str) -> bool:
        """
        Write out the ARPAbo model to a file path
        """
        try:
            with open(out_path, "w") as outfile:
                self.write(outfile)
        except Exception:
            return False
        return True

    def write(self, outfile: TextIO) -> bool:
        """
        Write the ARPAbo model to a file handle
        """
        if self.verbose:
            print("Writing output file", file=self.logfile)

        print(
            "Corpus:",
            f"{self.sent_count} sentences;",
            f"{self.sum_1} words,",
            f"{self.count_1} 1-grams,",
            f"{self.count_2} 2-grams,",
            f"{self.count_3} 3-grams,",
            f"with fixed discount mass {self.discount_mass}",
            "with simple normalization" if self.norm else "",
            file=outfile,
        )

        print(file=outfile)
        print("\\data\\", file=outfile)

        print(f"ngram 1={self.count_1}", file=outfile)
        if self.count_2:
            print(f"ngram 2={self.count_2}", file=outfile)
        if self.count_3:
            print(f"ngram 3={self.count_3}", file=outfile)
        print(file=outfile)

        print("\\1-grams:", file=outfile)
        for w1, prob in sorted(self.prob_1.items()):
            log_prob = log(prob) / ArpaBoLM.log10
            log_alpha = log(self.alpha_1[w1]) / ArpaBoLM.log10
            print(f"{log_prob:6.4f} {w1} {log_alpha:6.4f}", file=outfile)

        if self.count_2:
            print(file=outfile)
            print("\\2-grams:", file=outfile)
            for w1, grams2 in sorted(self.prob_2.items()):
                for w2, prob in sorted(grams2.items()):
                    log_prob = log(prob) / ArpaBoLM.log10
                    log_alpha = log(self.alpha_2[w1][w2]) / ArpaBoLM.log10
                    print(f"{log_prob:6.4f} {w1} {w2} {log_alpha:6.4f}", file=outfile)

        if self.count_3:
            print(file=outfile)
            print("\\3-grams:", file=outfile)
            for w1, grams2 in sorted(self.grams_3.items()):
                for w2, grams3 in sorted(grams2.items()):
                    for w3, count in sorted(grams3.items()):  # type: ignore
                        prob = count * self.deflator / self.grams_2[w1][w2]
                        log_prob = log(prob) / ArpaBoLM.log10
                        print(f"{log_prob:6.4f} {w1} {w2} {w3}", file=outfile)

        print(file=outfile)
        print("\\end\\", file=outfile)
        if self.verbose:
            print("Finished", date.today(), file=self.logfile)

        return True

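# For orientation, write() above produces the standard ARPA text layout,
# preceded by the one-line human-readable "Corpus: ..." summary:
#
#   \data\
#   ngram 1=<number of unigrams>
#   ngram 2=<number of bigrams>
#   ngram 3=<number of trigrams>
#
#   \1-grams:
#   <log10 prob> <word> <log10 backoff>
#   ...
#   \2-grams:
#   <log10 prob> <w1> <w2> <log10 backoff>
#   ...
#   \3-grams:
#   <log10 prob> <w1> <w2> <w3>
#   ...
#   \end\
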
def main() -> None:
    parser = argparse.ArgumentParser(description="Create a fixed-backoff ARPA LM")
    parser.add_argument(
        "-s",
        "--sentfile",
        type=argparse.FileType("rt"),
        help="sentence transcripts in sphinxtrain style, or one-per-line text",
    )
    parser.add_argument("-t", "--text", type=str, help="training text, given directly")
    parser.add_argument(
        "-w", "--word-file", type=str, help="add missing words from this file with count -C"
    )
    parser.add_argument(
        "-C",
        "--word-file-count",
        type=int,
        default=1,
        help="word count set for each word in --word-file (default 1)",
    )
    parser.add_argument(
        "-d", "--discount-mass", type=float, help="fixed discount mass in (0.0, 1.0)"
    )
    parser.add_argument(
        "-c", "--case", type=str, help="fold case (values: lower, upper)"
    )
    parser.add_argument(
        "-a",
        "--add-start",
        action="store_true",
        help="add <s> at the start and </s> at the end of each line, for -s or -t",
    )
    parser.add_argument(
        "-n",
        "--norm",
        action="store_true",
        help="do rudimentary token normalization / remove punctuation",
    )
    parser.add_argument(
        "-o", "--output", type=str, help="output to this file (default stdout)"
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="extra log info (to stderr)"
    )

    args = parser.parse_args()

    if args.case and args.case not in ["lower", "upper"]:
        parser.error("--case must be lower or upper (if given)")

    if args.sentfile is None and args.text is None:
        parser.error("Input must be specified with --sentfile and/or --text")

    lm = ArpaBoLM(
        sentfile=args.sentfile,
        text=args.text,
        word_file=args.word_file,
        word_file_count=args.word_file_count,
        discount_mass=args.discount_mass,
        case=args.case,
        add_start=args.add_start,
        norm=args.norm,
        verbose=args.verbose,
    )
    lm.compute()

    if args.output:
        with open(args.output, "w") as outfile:
            lm.write(outfile)
    else:
        lm.write(sys.stdout)


if __name__ == "__main__":
    main()
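For quick reference, here is a minimal sketch of driving the new builder from Python rather than the command line. It uses only the constructor arguments and methods added above; the import path is an assumption based on the file's location (cython/pocketsphinx/lm.py), and the two-line corpus is invented:

import sys

from pocketsphinx.lm import ArpaBoLM  # assumed module path for cython/pocketsphinx/lm.py

# Count n-grams over two literal sentences; add_start wraps each line in <s> ... </s>.
lm = ArpaBoLM(text="hello world\nhello there", add_start=True, discount_mass=0.5)
lm.compute()          # derive probabilities and back-off weights
lm.write(sys.stdout)  # emit the model in ARPA text format
# or write to a path instead: lm.write_file("tiny.lm")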
