diff --git a/gptc/classifier.py b/gptc/classifier.py index 73a0077..8fdd1eb 100755 --- a/gptc/classifier.py +++ b/gptc/classifier.py @@ -22,7 +22,7 @@ class Classifier: except: model_version = 1 - if model_version == 1: + if model_version == 2: self.model = model else: # The model is an unsupported version @@ -56,16 +56,15 @@ class Classifier: probs = {} for word in text: try: - total = sum(model[word].values()) - for category, value in model[word].items(): + for category, value in enumerate(model[word]): try: - probs[category] += value / total + probs[category] += value except KeyError: - probs[category] = value / total + probs[category] = value except KeyError: pass total = sum(probs.values()) - probs = {category: value/total for category, value in probs.items()} + probs = {model['__names__'][category]: value/total for category, value in probs.items()} return probs def classify(self, text): diff --git a/gptc/compiler.py b/gptc/compiler.py index 865f719..54a1032 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -27,7 +27,12 @@ def compile(raw_model): categories_by_count = {} + names = [] + for category, text in categories.items(): + if not category in names: + names.append(category) + categories_by_count[category] = {} for word in text: try: @@ -45,7 +50,11 @@ def compile(raw_model): model = {} for word, weights in word_weights.items(): total = sum(weights.values()) - model[word] = {category: weight/total for category, weight in weights.items()} + model[word] = [] + for category in names: + model[word].append(weights.get(category, 0)/total) + + model['__names__'] = names model['__version__'] = 2 model['__raw__'] = raw_model