from collections import defaultdict
import math

import scipy.sparse as sps


class CoocurrenceCalculator(object):
    """Compute word co-occurrence counts, probability distributions, and
    (shifted) PMI over a tokenized corpus.

    Typical pipeline: `compute_coocurrences` -> `compute_distributions`
    -> `compute_pmi`; or call `compute_pmi_matrix` to run all three and
    obtain a sparse |V| x |V| PMI matrix indexed by `self.vocab` ids.
    """

    def __init__(self, context_window_size, *args, **kwargs):
        # Number of context positions considered on EACH side of a target word.
        self.context_window_size = context_window_size

    def context_window(self, document, ix):
        """Yield the context words within `context_window_size` positions of
        `document[ix]`, on both sides.

        The target word itself (relative position 0) is excluded, and
        positions falling outside the document are skipped.
        """
        n = self.context_window_size
        length = len(document)
        # BUG FIX: the original `range(-n, n)` omitted relative position +n,
        # yielding n words to the left but only n-1 to the right (an
        # asymmetric window). `range(-n, n + 1)` makes the window symmetric.
        for relative_ix in range(-n, n + 1):
            abs_ix = ix + relative_ix
            if relative_ix == 0 or not (0 <= abs_ix < length):
                continue
            yield document[abs_ix]

    def _normalize(self, d):
        """Return a copy of dict `d` with values scaled to sum to 1."""
        Z = sum(d.values())
        return {k: v / Z for k, v in d.items()}

    def compute_coocurrences(self, corpus, *args, **kwargs):
        """Count (word, context_word) co-occurrence pairs over `corpus`.

        `corpus` is an iterable of documents, each a sequence of tokens.
        As a side effect, builds `self.vocab` mapping each word to a dense
        integer id in order of first appearance.  Stores the pair counts on
        `self.coocurrences` and returns them (a defaultdict(int)).
        """
        self.vocab = {}
        last_word_ix = 0
        coocurrences = defaultdict(int)
        for document in corpus:
            for i, word in enumerate(document):
                if word not in self.vocab:
                    self.vocab[word] = last_word_ix
                    last_word_ix += 1
                for context_word in self.context_window(document, i):
                    coocurrences[(word, context_word)] += 1
        self.coocurrences = coocurrences
        return coocurrences

    def compute_distributions(self):
        """Turn the raw co-occurrence counts into probability distributions.

        Returns (term_ctx_joint, term_marginals, ctx_marginals) — the joint
        p(w, c) and the two marginals p(w), p(c), each normalized to sum
        to 1 — and stores all three on `self`.  Requires
        `compute_coocurrences` to have been called first.
        """
        term_marginals = defaultdict(float)
        ctx_marginals = defaultdict(float)
        term_ctx_joint = defaultdict(float)
        for (w, c), cnt in self.coocurrences.items():
            term_ctx_joint[(w, c)] += cnt
            term_marginals[w] += cnt
            ctx_marginals[c] += cnt
        term_ctx_joint = self._normalize(term_ctx_joint)
        term_marginals = self._normalize(term_marginals)
        ctx_marginals = self._normalize(ctx_marginals)
        self.term_ctx_joint = term_ctx_joint
        self.term_marginals = term_marginals
        self.ctx_marginals = ctx_marginals
        return term_ctx_joint, term_marginals, ctx_marginals

    def compute_pmi(self, return_shifted=False):
        """Compute PMI and shifted PMI (PMI - log k) for every observed pair.

        Both dicts are stored on `self`; `return_shifted` selects which one
        is returned.  Only observed pairs are computed, so every probability
        involved is strictly positive and the logs are well-defined.

        NOTE(review): the shift constant `k` is taken from
        `context_window_size`; in the Levy & Goldberg shifted-PMI
        formulation `k` is the number of negative samples — confirm this
        choice is intentional.
        """
        k = self.context_window_size
        pmi_dict = defaultdict(float)
        shifted_pmi_dict = defaultdict(float)
        for (w, c) in self.coocurrences:
            obs_pmi = self._pmi(w, c)
            pmi_dict[(w, c)] = obs_pmi
            shifted_pmi_dict[(w, c)] = obs_pmi - math.log(k)
        self.pmi_dict = pmi_dict
        self.shifted_pmi_dict = shifted_pmi_dict
        if return_shifted:
            return shifted_pmi_dict
        return pmi_dict

    def _pmi(self, w, c):
        """PMI(w, c) = log p(w, c) - (log p(w) + log p(c))."""
        p_w = self.term_marginals[w]
        p_c = self.ctx_marginals[c]
        p_wc = self.term_ctx_joint[(w, c)]
        log_num = math.log(p_wc)
        log_denom = math.log(p_w) + math.log(p_c)
        return log_num - log_denom

    def compute_pmi_matrix(self, corpus, return_shifted=True):
        """Run the full pipeline on `corpus` and return a sparse PMI matrix.

        The result is a |V| x |V| scipy DOK matrix whose rows/columns are
        indexed by `self.vocab` ids; entries hold shifted PMI when
        `return_shifted` is True, plain PMI otherwise.  Unobserved pairs
        remain implicitly zero.
        """
        self.compute_coocurrences(corpus)
        self.compute_distributions()
        self.compute_pmi()
        vocab_size = len(self.vocab)
        pmi_matrix = sps.dok_matrix((vocab_size, vocab_size))
        if return_shifted:
            iterator = self.shifted_pmi_dict.items()
        else:
            iterator = self.pmi_dict.items()
        for (w, c), _pmi in iterator:
            term_ix = self.vocab[w]
            ctx_ix = self.vocab[c]
            pmi_matrix[(term_ix, ctx_ix)] = _pmi
        self.pmi_matrix = pmi_matrix
        return pmi_matrix