diff --git a/gptc/compiler.py b/gptc/compiler.py index f994897..9ce5ac4 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -41,39 +41,30 @@ def compile( word_counts: Dict[int, Dict[str, int]] = {} - names = [] + names = tuple(categories.keys()) for category, text in categories.items(): - if not category in names: - names.append(category) - for word in text: - try: - counts_for_word = word_counts[word] - except KeyError: - counts_for_word = {} - word_counts[word] = counts_for_word + if word in word_counts: + try: + word_counts[word][category] += 1 + except KeyError: + word_counts[word][category] = 1 + else: + word_counts[word] = {category: 1} - try: - word_counts[word][category] += 1 - except KeyError: - word_counts[word][category] = 1 - - word_counts = { - word: counts - for word, counts in word_counts.items() - if sum(counts.values()) >= min_count + category_lengths = { + category: len(text) for category, text in categories.items() } - word_weights: Dict[int, Dict[str, float]] = {} - for word, values in word_counts.items(): - for category, value in values.items(): - try: - word_weights[word][category] = value / len(categories[category]) - except KeyError: - word_weights[word] = { - category: value / len(categories[category]) - } + word_weights: Dict[int, Dict[str, float]] = { + word: { + category: value / category_lengths[category] + for category, value in values.items() + } + for word, values in word_counts.items() + if sum(values.values()) >= min_count + } model: Dict[int, List[int]] = {} for word, weights in word_weights.items():