diff --git a/gptc/compiler.py b/gptc/compiler.py index b54b85e..ba3ffda 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -6,7 +6,7 @@ from typing import Iterable, Mapping, List, Dict, Union def compile( raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1 -) -> Dict[str, Union[int, List[Union[str, int]]]]: +) -> Dict[str, Union[str, int, List[int], List[str]]]: """Compile a raw model. Parameters @@ -60,14 +60,15 @@ def compile( except KeyError: word_weights[word] = {category: value} - model: Dict[str, Union[int, List[Union[str, int]]]] = {} + model: Dict[str, Union[str, int, List[int], List[str]]] = {} for word, weights in word_weights.items(): total = sum(weights.values()) - model[word] = [] + new_weights: List[int] = [] for category in names: - model[word].append( + new_weights.append( round((weights.get(category, 0) / total) * 65535) ) + model[word] = new_weights model["__names__"] = names model["__ngrams__"] = max_ngram_length