diff options
author | Jonne <jonne@jonnesaleva.com> | 2020-05-27 20:24:07 -0400 |
---|---|---|
committer | Jonne <jonne@jonnesaleva.com> | 2020-05-27 20:24:07 -0400 |
commit | dd8a0175a34cddd748cb4fdd9485c314b36e19cc (patch) | |
tree | c38246f066e7f7fb527b237ff7d8a24cbbe60f2e /src | |
parent | 5ad69bf8e1d1d5a359296613c8969a81ad743b7d (diff) | |
download | yi-word-clustering-dd8a0175a34cddd748cb4fdd9485c314b36e19cc.tar.gz |
coocurrence counts
Diffstat (limited to 'src')
-rw-r--r-- | src/__pycache__/compute_coocurrences.cpython-37.pyc | bin | 0 -> 3067 bytes | |||
-rw-r--r-- | src/compute_coocurrences.py | 112 |
2 files changed, 112 insertions, 0 deletions
diff --git a/src/__pycache__/compute_coocurrences.cpython-37.pyc b/src/__pycache__/compute_coocurrences.cpython-37.pyc new file mode 100644 index 0000000..d6fe8c6 --- /dev/null +++ b/src/__pycache__/compute_coocurrences.cpython-37.pyc Binary files differdiff --git a/src/compute_coocurrences.py b/src/compute_coocurrences.py new file mode 100644 index 0000000..6882563 --- /dev/null +++ b/src/compute_coocurrences.py @@ -0,0 +1,112 @@ +from collections import defaultdict +import scipy.sparse as sps +import math + +class CoocurrenceCalculator(object): + + def __init__(self, context_window_size, *args, **kwargs): + self.context_window_size = context_window_size + + def context_window(self, document, ix): + n = self.context_window_size + length = len(document) + for relative_ix in range(-n, n): + abs_ix = ix+relative_ix + if relative_ix == 0 or not (0 <= abs_ix < length): + continue + else: + yield document[abs_ix] + + def _normalize(self, d): + Z = sum(d.values()) + return {k: v/Z for k,v in d.items()} + + def compute_coocurrences(self, corpus, *args, **kwargs): + """ + We assume `corpus` is a list of sentences/documents, + and our desire is to simply compute the + """ + self.vocab = {} + last_word_ix = 0 + coocurrences = defaultdict(int) + for document in corpus: + for i, word in enumerate(document): + if word not in self.vocab: + self.vocab[word] = last_word_ix + last_word_ix += 1 + for context_word in self.context_window(document, i): + coocurrences[(word, context_word)] += 1 + self.coocurrences = coocurrences + return coocurrences + + def compute_distributions(self): + term_marginals = defaultdict(float) + ctx_marginals = defaultdict(float) + term_ctx_joint = defaultdict(float) + for (w,c), cnt in self.coocurrences.items(): + term_ctx_joint[(w,c)] += cnt + term_marginals[w] += cnt + ctx_marginals[c] += cnt + term_ctx_joint = self._normalize(term_ctx_joint) + term_marginals = self._normalize(term_marginals) + ctx_marginals = self._normalize(ctx_marginals) + + self.term_ctx_joint = term_ctx_joint + self.term_marginals = term_marginals + self.ctx_marginals = ctx_marginals + + return term_ctx_joint, term_marginals, ctx_marginals + + def compute_pmi(self, return_shifted=False): + k = self.context_window_size + + pmi_dict = defaultdict(float) + shifted_pmi_dict = defaultdict(float) + + for (w,c) in self.coocurrences: + obs_pmi = self._pmi(w,c) + pmi_dict[(w,c)] = obs_pmi + shifted_pmi_dict[(w,c)] = obs_pmi - math.log(k) + + self.pmi_dict = pmi_dict + self.shifted_pmi_dict = shifted_pmi_dict + + if return_shifted: + return shifted_pmi_dict + else: + return pmi_dict + + + def _pmi(self, w, c): + p_w = self.term_marginals[w] + p_c = self.ctx_marginals[c] + p_wc = self.term_ctx_joint[(w,c)] + + log_num = math.log(p_wc) + log_denom = math.log(p_w) + math.log(p_c) + pmi = log_num - log_denom + + return pmi + + def compute_pmi_matrix(self, corpus, return_shifted=True): + self.compute_coocurrences(corpus) + self.compute_distributions() + self.compute_pmi() + + vocab_size = len(self.vocab) + pmi_matrix = sps.dok_matrix((vocab_size, vocab_size)) + + if return_shifted: + iterator = self.shifted_pmi_dict.items() + else: + iterator = self.pmi_dict.items() + + for (w, c), _pmi in iterator: + term_ix = self.vocab[w] + ctx_ix = self.vocab[c] + pmi_matrix[(term_ix, ctx_ix)] = _pmi + + self.pmi_matrix = pmi_matrix + + return pmi_matrix + |