Showing 2 changed files with 107 additions and 0 deletions.
lda.stan
@@ -0,0 +1,37 @@
// Latent Dirichlet Allocation (LDA)

data {
  int<lower=2> n_topics;
  int<lower=2> n_vocab;
  int<lower=1> n_words;
  int<lower=1> n_docs;
  int<lower=1, upper=n_vocab> words[n_words];  // ID of the i-th word
  int<lower=1, upper=n_docs> doc_of[n_words];  // ID of the document the i-th word belongs to
  vector<lower=0>[n_topics] alpha;             // prior for the i-th topic
  vector<lower=0>[n_vocab] beta;               // prior for the word whose ID is i
}

parameters {
  simplex[n_topics] theta[n_docs];  // topic distribution of the i-th document
  simplex[n_vocab] phi[n_topics];   // word distribution of the i-th topic
}

model {
  // Sample the parameters from their priors
  for (i in 1:n_docs) {
    theta[i] ~ dirichlet(alpha);
  }
  for (i in 1:n_topics) {
    phi[i] ~ dirichlet(beta);
  }

  // Compute the log likelihood
  for (w in 1:n_words) {
    real gamma[n_topics];
    for (t in 1:n_topics) {
      // log(probability that this document's topic is t)
      //   + log(probability that word w appears under topic t)
      gamma[t] = log(theta[doc_of[w], t]) + log(phi[t, words[w]]);
    }
    target += log_sum_exp(gamma);
  }
}
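Since Stan cannot sample discrete parameters, the model marginalizes out each word's topic assignment rather than sampling it. For the i-th word w_i in document d_i, the gamma loop plus log_sum_exp adds the log of the collapsed mixture likelihood to the target:

\log p(w_i \mid \theta, \phi)
  = \log \sum_{t=1}^{T} \exp\bigl( \log \theta_{d_i, t} + \log \phi_{t, w_i} \bigr)
  = \log \sum_{t=1}^{T} \theta_{d_i, t} \, \phi_{t, w_i}

where T = n_topics, \theta_{d, t} is the probability of topic t in document d, and \phi_{t, v} is the probability of word v under topic t.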
@@ -0,0 +1,70 @@
import pystan
import numpy as np
import matplotlib.pyplot as plt


class Vocabulary:
    def __init__(self):
        self._word2id = {}
        self._id2word = {}

    def intern(self, word):
        # Return the existing ID for a known word, or assign the next one.
        # IDs start at 1 because Stan arrays are 1-indexed.
        if word in self._word2id:
            return self._word2id[word]
        new_id = len(self._word2id) + 1
        self._word2id[word] = new_id
        self._id2word[new_id] = word
        return new_id

    def word(self, wid):
        return self._id2word[wid]

    @property
    def size(self):
        return len(self._word2id)


def read_corpus(filename, max_lines):
    vocab = Vocabulary()
    vocab.intern("<unk>")

    word_ids = []
    doc_ids = []

    with open(filename) as f:
        for i, line in enumerate(f):
            if i >= max_lines:
                break
            line = line.strip()
            words = line.split(" ")
            for word in words:
                wid = vocab.intern(word)
                word_ids.append(wid)
                doc_ids.append(i + 1)  # each line is one document; IDs are 1-based

    return (word_ids, doc_ids, vocab)


def run_stan(word_ids, doc_ids, vocab, n_topics=10):
    # https://stats.stackexchange.com/questions/59684/what-are-typical-values-to-use-for-alpha-and-beta-in-latent-dirichlet-allocation
    alpha = np.full(n_topics, 50 / n_topics)
    beta = np.full(vocab.size, 0.1)

    data = {
        "n_topics": n_topics,
        "n_vocab": vocab.size,
        "n_words": len(word_ids),
        "n_docs": max(doc_ids),
        "words": word_ids,
        "doc_of": doc_ids,
        "alpha": alpha,
        "beta": beta,
    }

    with open("lda.stan", encoding="utf-8") as f:
        model_code = f.read()
    # pystan raises an exception when the model code contains non-ASCII
    # characters, so strip them (this only removes the Japanese comments).
    model_code = model_code.encode("ascii", errors="ignore").decode("ascii")

    fit = pystan.stan(model_code=model_code, data=data, iter=300)
    return fit
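A minimal usage sketch under the assumptions the script itself implies: the corpus is a plain-text file with one whitespace-tokenized document per line, and the path "corpus.txt" is hypothetical. fit.extract() is the standard PyStan 2 accessor for posterior draws.

word_ids, doc_ids, vocab = read_corpus("corpus.txt", max_lines=100)  # hypothetical input file
fit = run_stan(word_ids, doc_ids, vocab, n_topics=10)

# Posterior mean of the topic-word distributions: shape (n_topics, n_vocab).
phi = fit.extract()["phi"].mean(axis=0)
for t in range(phi.shape[0]):
    top = np.argsort(phi[t])[::-1][:10]             # ten highest-probability word indices
    print(t + 1, [vocab.word(i + 1) for i in top])  # +1: Stan/vocab word IDs are 1-based

One caveat: topics in a mixture model are only identified up to relabeling, so averaging draws across chains can blur topics whose labels switched between chains; inspecting a single chain or a single draw is a common workaround.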