diff --git a/gptc/compiler.py b/gptc/compiler.py index 3372e70..b54b85e 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import gptc.tokenizer +from typing import Iterable, Mapping, List, Dict, Union -def compile(raw_model, max_ngram_length=1): +def compile( + raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1 +) -> Dict[str, Union[int, List[Union[str, int]]]]: """Compile a raw model. Parameters @@ -21,7 +24,7 @@ def compile(raw_model, max_ngram_length=1): """ - categories = {} + categories: Dict[str, str] = {} for portion in raw_model: text = gptc.tokenizer.tokenize(portion["text"], max_ngram_length) @@ -31,7 +34,7 @@ def compile(raw_model, max_ngram_length=1): except KeyError: categories[category] = text - categories_by_count = {} + categories_by_count: Dict[str, Dict[str, float]] = {} names = [] @@ -49,7 +52,7 @@ def compile(raw_model, max_ngram_length=1): categories_by_count[category][word] = 1 / len( categories[category] ) - word_weights = {} + word_weights: Dict[str, Dict[str, float]] = {} for category, words in categories_by_count.items(): for word, value in words.items(): try: @@ -57,7 +60,7 @@ def compile(raw_model, max_ngram_length=1): except KeyError: word_weights[word] = {category: value} - model = {} + model: Dict[str, Union[int, List[Union[str, int]]]] = {} for word, weights in word_weights.items(): total = sum(weights.values()) model[word] = []