diff --git a/gptc/compiler.py b/gptc/compiler.py
index c299a4b..8051c91 100755
--- a/gptc/compiler.py
+++ b/gptc/compiler.py
@@ -28,8 +28,8 @@ def compile(
     """
 
-    word_counts: Dict[int, Dict[str, int]] = {}
-    category_lengths: Dict[str, int] = {}
+    word_counts: Dict[int, Dict[int, int]] = {}
+    category_lengths: Dict[int, int] = {}
     names: List[str] = []
 
     for portion in raw_model:
@@ -37,10 +37,12 @@ def compile(
             gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
             hash_algorithm,
         )
-        category = portion["category"]
+        category_name = portion["category"]
 
-        if not category in names:
-            names.append(category)
+        if not category_name in names:
+            names.append(category_name)
+
+        category = names.index(category_name)
 
         category_lengths[category] = category_lengths.get(category, 0) + len(
             text
@@ -64,7 +66,7 @@ def compile(
         }
         total = sum(weights.values())
         new_weights: List[int] = []
-        for category in names:
+        for category in range(len(names)):
             new_weights.append(
                 round((weights.get(category, 0) / total) * 65535)
             )