Skip to content

Commit

Permalink
feat(learning-stan): add LDA model
Browse files Browse the repository at this point in the history
  • Loading branch information
nojima committed Mar 7, 2018
1 parent d5a0eed commit 0d89160
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 0 deletions.
37 changes: 37 additions & 0 deletions learning-pystan/lda.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Latent Dirichlet Allocation (LDA)
//
// Each document d has a topic distribution theta[d] ~ Dirichlet(alpha) and
// each topic t has a word distribution phi[t] ~ Dirichlet(beta).  The
// per-word latent topic assignment is marginalized out analytically in the
// model block via log_sum_exp, so it never appears as a parameter.

data {
  int<lower=2> n_topics;
  int<lower=2> n_vocab;
  int<lower=1> n_words;
  int<lower=1> n_docs;
  int<lower=1, upper=n_vocab> words[n_words]; // vocabulary ID of the i-th token
  int<lower=1, upper=n_docs> doc_of[n_words]; // document ID the i-th token belongs to
  vector<lower=0>[n_topics] alpha; // Dirichlet prior over each document's topic distribution
  vector<lower=0>[n_vocab] beta; // Dirichlet prior over each topic's word distribution
}

parameters {
  simplex[n_topics] theta[n_docs]; // topic distribution of the i-th document
  simplex[n_vocab] phi[n_topics]; // word distribution of the i-th topic
}

model {
  // Sample the parameters from their Dirichlet priors.
  for (i in 1:n_docs) {
    theta[i] ~ dirichlet(alpha);
  }
  for (i in 1:n_topics) {
    phi[i] ~ dirichlet(beta);
  }

  // Accumulate the log likelihood of every observed token, summing out the
  // latent topic assignment.
  for (w in 1:n_words) {
    real gamma[n_topics];
    for (t in 1:n_topics) {
      // log P(topic = t | document of token w) + log P(word w | topic t)
      gamma[t] = log(theta[doc_of[w], t]) + log(phi[t, words[w]]);
    }
    target += log_sum_exp(gamma);
  }
}
70 changes: 70 additions & 0 deletions learning-pystan/stan_lda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pystan
import numpy as np
import matplotlib.pyplot as plt


class Vocabulary:
    """Bijective mapping between words and 1-based integer IDs.

    IDs start at 1 (not 0) because the Stan model indexes arrays from 1.
    """

    def __init__(self):
        self._word2id = {}
        self._id2word = {}

    def intern(self, word):
        """Return the ID of *word*, assigning the next free ID on first sight."""
        wid = self._word2id.get(word)
        if wid is None:
            wid = len(self._word2id) + 1
            self._word2id[word] = wid
            self._id2word[wid] = word
        return wid

    def word(self, wid):
        """Return the word registered under ID *wid*."""
        return self._id2word[wid]

    @property
    def size(self):
        """Number of distinct words interned so far."""
        return len(self._word2id)


def read_corpus(filename, max_lines):
    """Read up to *max_lines* lines from *filename*, one document per line.

    Returns ``(word_ids, doc_ids, vocab)`` where ``word_ids[i]`` is the
    vocabulary ID of the i-th token and ``doc_ids[i]`` is its 1-based
    document (line) number — matching the Stan model's ``words`` and
    ``doc_of`` data inputs.
    """
    vocab = Vocabulary()
    # Reserve an ID for unknown words so it is always present in the vocabulary.
    vocab.intern("<unk>")

    word_ids = []
    doc_ids = []

    # Open with an explicit encoding, consistent with run_stan's UTF-8 handling.
    with open(filename, encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= max_lines:
                break
            # str.split() with no separator drops empty tokens, so blank lines
            # and runs of spaces no longer intern "" as a word (which the
            # original `line.split(" ")` did).
            for word in line.split():
                word_ids.append(vocab.intern(word))
                doc_ids.append(i + 1)

    return (word_ids, doc_ids, vocab)


def run_stan(word_ids, doc_ids, vocab, n_topics=10):
    """Compile the LDA Stan model and draw posterior samples with PyStan.

    word_ids: 1-based vocabulary IDs of all tokens (Stan ``words`` input).
    doc_ids:  1-based document number of each token (Stan ``doc_of`` input).
    vocab:    Vocabulary instance; ``vocab.size`` supplies ``n_vocab``.
    n_topics: number of latent topics.
    Returns the PyStan fit object.
    """
    # Common rule-of-thumb LDA hyperparameters (alpha = 50/K, beta = 0.1); see:
    # https://stats.stackexchange.com/questions/59684/what-are-typical-values-to-use-for-alpha-and-beta-in-latent-dirichlet-allocation
    alpha = np.full(n_topics, 50 / n_topics)
    beta = np.full(vocab.size, 0.1)

    data = {
        "n_topics": n_topics,
        "n_vocab": vocab.size,
        "n_words": len(word_ids),
        "n_docs": max(doc_ids),
        "words": word_ids,
        "doc_of": doc_ids,
        "alpha": alpha,
        "beta": beta,
    }

    with open("lda.stan", encoding="utf-8") as f:
        model_code = f.read()
        # pystan raises an exception when the model code contains non-ASCII
        # characters, so strip them (drops the Japanese comments in lda.stan).
        model_code = model_code.encode("ascii", errors="ignore").decode("ascii")

    fit = pystan.stan(model_code=model_code, data=data, iter=300)
    return fit

0 comments on commit 0d89160

Please sign in to comment.