stan_lda.py
import pystan
import numpy as np
import matplotlib.pyplot as plt


class Vocabulary:
    """Bidirectional mapping between words and 1-based integer IDs (Stan arrays are 1-indexed)."""

    def __init__(self):
        self._word2id = {}
        self._id2word = {}

    def intern(self, word):
        if word in self._word2id:
            return self._word2id[word]
        new_id = len(self._word2id) + 1
        self._word2id[word] = new_id
        self._id2word[new_id] = word
        return new_id

    def word(self, wid):
        return self._id2word[wid]

    @property
    def size(self):
        return len(self._word2id)


def read_corpus(filename, max_lines):
    """Read up to max_lines whitespace-tokenized documents (one document per line)."""
    vocab = Vocabulary()
    vocab.intern("<unk>")
    word_ids = []
    doc_ids = []
    with open(filename) as f:
        for i, line in enumerate(f):
            if i >= max_lines:
                break
            line = line.strip()
            words = line.split(" ")
            for word in words:
                wid = vocab.intern(word)
                word_ids.append(wid)
                doc_ids.append(i + 1)  # documents are 1-indexed for Stan
    return (word_ids, doc_ids, vocab)


def run_stan(word_ids, doc_ids, vocab, n_topics=10):
    # Symmetric Dirichlet hyperparameters; see
    # https://stats.stackexchange.com/questions/59684/what-are-typical-values-to-use-for-alpha-and-beta-in-latent-dirichlet-allocation
    alpha = np.full(n_topics, 50 / n_topics)
    beta = np.full(vocab.size, 0.1)
    data = {
        "n_topics": n_topics,
        "n_vocab": vocab.size,
        "n_words": len(word_ids),
        "n_docs": max(doc_ids),
        "words": word_ids,
        "doc_of": doc_ids,
        "alpha": alpha,
        "beta": beta,
    }
    with open("lda.stan", encoding="utf-8") as f:
        model_code = f.read()
    # pystan throws an exception when the model code contains non-ASCII
    # characters, so strip them out before compiling.
    model_code = model_code.encode("ascii", errors="ignore").decode("ascii")
    fit = pystan.stan(model_code=model_code, data=data, iter=300)
    return fit