Diffstat (limited to 'src/compute_coocurrences.py')
-rw-r--r--  src/compute_coocurrences.py  123
1 file changed, 123 insertions, 0 deletions
diff --git a/src/compute_coocurrences.py b/src/compute_coocurrences.py
new file mode 100644
index 0000000..6882563
--- /dev/null
+++ b/src/compute_coocurrences.py
@@ -0,0 +1,123 @@
+from collections import defaultdict
+import scipy.sparse as sps
+import math
+
+class CoocurrenceCalculator(object):
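+    """Compute word-context co-occurrence counts and derive (shifted) PMI values."""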
+
+    def __init__(self, context_window_size, *args, **kwargs):
+        self.context_window_size = context_window_size
+
+    def context_window(self, document, ix):
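+        # Yield the words within `context_window_size` positions on either
+        # side of `ix`, skipping the center word and out-of-range positions.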
+        n = self.context_window_size
+        length = len(document)
+        for relative_ix in range(-n, n+1):
+            abs_ix = ix+relative_ix
+            if relative_ix == 0 or not (0 <= abs_ix < length):
+                continue
+            else:
+                yield document[abs_ix]
+
+    def _normalize(self, d):
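+        # Normalize counts so the values sum to one (a probability distribution).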
+        Z = sum(d.values())
+        return {k: v/Z for k,v in d.items()}
+
+    def compute_coocurrences(self, corpus, *args, **kwargs):
+        """
+        We assume `corpus` is a list of sentences/documents,
+        and our desire is to simply compute the 
+        """
+        self.vocab = {}
+        last_word_ix = 0
+        coocurrences = defaultdict(int)
+        for document in corpus:
+            for i, word in enumerate(document):
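+                # Assign each previously unseen word the next integer index.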
+                if word not in self.vocab:
+                    self.vocab[word] = last_word_ix
+                    last_word_ix += 1
+                for context_word in self.context_window(document, i):
+                    coocurrences[(word, context_word)] += 1
+        self.coocurrences = coocurrences
+        return coocurrences
+
+    def compute_distributions(self):
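+        # Accumulate joint and marginal counts from the co-occurrence table,
+        # then normalize each into a probability distribution.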
+        term_marginals = defaultdict(float)
+        ctx_marginals = defaultdict(float)
+        term_ctx_joint = defaultdict(float)
+        for (w,c), cnt in self.coocurrences.items():
+            term_ctx_joint[(w,c)] += cnt
+            term_marginals[w] += cnt
+            ctx_marginals[c] += cnt
+        term_ctx_joint = self._normalize(term_ctx_joint)
+        term_marginals = self._normalize(term_marginals)
+        ctx_marginals = self._normalize(ctx_marginals)
+
+        self.term_ctx_joint = term_ctx_joint
+        self.term_marginals = term_marginals
+        self.ctx_marginals = ctx_marginals
+
+        return term_ctx_joint, term_marginals, ctx_marginals
+
+    def compute_pmi(self, return_shifted=False):
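+        # Shifted PMI is PMI(w, c) - log(k); here the shift constant k is
+        # taken to be the context window size.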
+        k = self.context_window_size
+
+        pmi_dict = defaultdict(float)
+        shifted_pmi_dict = defaultdict(float)
+
+        for (w,c) in self.coocurrences:
+            obs_pmi = self._pmi(w,c)
+            pmi_dict[(w,c)] = obs_pmi
+            shifted_pmi_dict[(w,c)] = obs_pmi - math.log(k)
+
+        self.pmi_dict = pmi_dict
+        self.shifted_pmi_dict = shifted_pmi_dict
+
+        if return_shifted:
+            return shifted_pmi_dict
+        else:
+            return pmi_dict
+
+    def _pmi(self, w, c):
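+        # PMI(w, c) = log p(w, c) - log p(w) - log p(c)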
+        p_w = self.term_marginals[w]
+        p_c = self.ctx_marginals[c]
+        p_wc = self.term_ctx_joint[(w,c)]
+        
+        log_num = math.log(p_wc)
+        log_denom = math.log(p_w) + math.log(p_c)
+        pmi = log_num - log_denom
+        
+        return pmi
+
+    def compute_pmi_matrix(self, corpus, return_shifted=True):
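+        # End-to-end pipeline: count co-occurrences, estimate distributions,
+        # compute PMI, and fill a sparse |V| x |V| matrix.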
+        self.compute_coocurrences(corpus)
+        self.compute_distributions()
+        self.compute_pmi()
+        
+        vocab_size = len(self.vocab)
+        pmi_matrix = sps.dok_matrix((vocab_size, vocab_size))
+
+        if return_shifted:
+            iterator = self.shifted_pmi_dict.items()
+        else:
+            iterator = self.pmi_dict.items()
+
+        for (w, c), _pmi in iterator:
+            term_ix = self.vocab[w]
+            ctx_ix = self.vocab[c]
+            pmi_matrix[(term_ix, ctx_ix)] = _pmi
+
+        self.pmi_matrix = pmi_matrix
+
+        return pmi_matrix
+
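
A minimal usage sketch (the toy corpus and variable names below are
illustrative, not part of the commit):

    from compute_coocurrences import CoocurrenceCalculator

    corpus = [
        ["the", "cat", "sat", "on", "the", "mat"],
        ["the", "dog", "sat", "on", "the", "log"],
    ]

    calc = CoocurrenceCalculator(context_window_size=2)
    pmi_matrix = calc.compute_pmi_matrix(corpus, return_shifted=True)
    print(pmi_matrix.shape)  # (vocab_size, vocab_size)
    print(calc.vocab)        # word -> row/column index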