From 67ac3a45917759760629c39cfdf3e19c93c8a998 Mon Sep 17 00:00:00 2001
From: Samuel Sloniker
Date: Sun, 17 Jul 2022 18:22:19 -0700
Subject: [PATCH] Working type checks

---
 gptc/__main__.py   |  2 +-
 gptc/classifier.py | 35 ++++++++++++++++++-----------------
 gptc/compiler.py   | 10 +++++++---
 gptc/tokenizer.py  |  9 +++++----
 gptc/weighting.py  |  9 +++++----
 5 files changed, 36 insertions(+), 29 deletions(-)

diff --git a/gptc/__main__.py b/gptc/__main__.py
index 560ba4c..ff9ed52 100644
--- a/gptc/__main__.py
+++ b/gptc/__main__.py
@@ -7,7 +7,7 @@ import sys
 import gptc
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="General Purpose Text Classifier", prog="gptc"
     )
diff --git a/gptc/classifier.py b/gptc/classifier.py
index 3648a86..ac2fde7 100755
--- a/gptc/classifier.py
+++ b/gptc/classifier.py
@@ -2,6 +2,7 @@
 
 import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
 import warnings
+from typing import Dict, Union, cast, List
 
 
 class Classifier:
@@ -24,17 +25,18 @@ class Classifier:
 
     """
 
-    def __init__(self, model, max_ngram_length=1):
+    def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
         if model.get("__version__", 0) != 3:
             raise gptc.exceptions.UnsupportedModelError(
                 f"unsupported model version"
             )
         self.model = model
+        model_ngrams = cast(int, model.get("__ngrams__", 1))
         self.max_ngram_length = min(
-            max_ngram_length, model.get("__ngrams__", 1)
+            max_ngram_length, model_ngrams
         )
 
-    def confidence(self, text):
+    def confidence(self, text: str) -> Dict[str, float]:
         """Classify text with confidence.
 
         Parameters
@@ -52,29 +54,28 @@ class Classifier:
 
         model = self.model
 
-        text = gptc.tokenizer.tokenize(text, self.max_ngram_length)
-        probs = {}
-        for word in text:
+        tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length)
+        numbered_probs: Dict[int, float] = {}
+        for word in tokens:
             try:
-                weight, weighted_numbers = gptc.weighting.weight(
-                    [i / 65535 for i in model[word]]
+                weighted_numbers = gptc.weighting.weight(
+                    [i / 65535 for i in cast(List[float], model[word])]
                 )
                 for category, value in enumerate(weighted_numbers):
                     try:
-                        probs[category] += value
+                        numbered_probs[category] += value
                     except KeyError:
-                        probs[category] = value
+                        numbered_probs[category] = value
             except KeyError:
                 pass
-        probs = {
-            model["__names__"][category]: value
-            for category, value in probs.items()
+        total = sum(numbered_probs.values())
+        probs: Dict[str, float] = {
+            cast(List[str], model["__names__"])[category]: value / total
+            for category, value in numbered_probs.items()
         }
-        total = sum(probs.values())
-        probs = {category: value / total for category, value in probs.items()}
         return probs
 
-    def classify(self, text):
+    def classify(self, text: str) -> Union[str, None]:
         """Classify text.
 
         Parameters
@@ -89,7 +90,7 @@ class Classifier:
         category in the model were found.
 
         """
-        probs = self.confidence(text)
+        probs: Dict[str, float] = self.confidence(text)
         try:
             return sorted(probs.items(), key=lambda x: x[1])[-1][0]
         except IndexError:
diff --git a/gptc/compiler.py b/gptc/compiler.py
index ba3ffda..05b793b 100755
--- a/gptc/compiler.py
+++ b/gptc/compiler.py
@@ -3,10 +3,14 @@
 import gptc.tokenizer
 from typing import Iterable, Mapping, List, Dict, Union
 
+WEIGHTS_T = List[int]
+CONFIG_T = Union[List[str], int, str]
+MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
+
 
 def compile(
     raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
-) -> Dict[str, Union[str, int, List[int], List[str]]]:
+) -> MODEL:
     """Compile a raw model.
 
     Parameters
@@ -24,7 +28,7 @@ def compile(
 
     """
 
-    categories: Dict[str, str] = {}
+    categories: Dict[str, List[str]] = {}
 
     for portion in raw_model:
         text = gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
@@ -60,7 +64,7 @@ def compile(
             except KeyError:
                 word_weights[word] = {category: value}
 
-    model: Dict[str, Union[str, int, List[int], List[str]]] = {}
+    model: MODEL = {}
     for word, weights in word_weights.items():
         total = sum(weights.values())
         new_weights: List[int] = []
diff --git a/gptc/tokenizer.py b/gptc/tokenizer.py
index c9e9fd0..25c97aa 100644
--- a/gptc/tokenizer.py
+++ b/gptc/tokenizer.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
+from typing import List, Union
 
 try:
     import emoji
@@ -8,9 +9,9 @@ except ImportError:
     has_emoji = False
 
 
-def tokenize(text, max_ngram_length=1):
+def tokenize(text: str, max_ngram_length: int=1) -> List[str]:
     """Convert a string to a list of lemmas."""
-    text = text.lower()
+    converted_text: Union[str, List[str]] = text.lower()
 
     if has_emoji:
         parts = []
@@ -20,11 +21,11 @@
             parts.append(emoji_part["emoji"])
             highest_end = emoji_part["match_end"]
         parts += list(text[highest_end:])
-        text = [part for part in parts if part]
+        converted_text = [part for part in parts if part]
 
     tokens = [""]
 
-    for char in text:
+    for char in converted_text:
         if char.isalpha() or char == "'":
             tokens[-1] += char
         elif has_emoji and emoji.is_emoji(char):
diff --git a/gptc/weighting.py b/gptc/weighting.py
index 047de6a..5060e47 100755
--- a/gptc/weighting.py
+++ b/gptc/weighting.py
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
 import math
+from typing import Sequence, Union, Tuple, List
 
 
-def _mean(numbers):
+def _mean(numbers: Sequence[float]) -> float:
     """Calculate the mean of a group of numbers
 
     Parameters
@@ -19,7 +20,7 @@
     return sum(numbers) / len(numbers)
 
 
-def _standard_deviation(numbers):
+def _standard_deviation(numbers: Sequence[float]) -> float:
     """Calculate the standard deviation of a group of numbers
 
     Parameters
@@ -38,8 +39,8 @@
     return math.sqrt(_mean(squared_deviations))
 
 
-def weight(numbers):
+def weight(numbers: Sequence[float]) -> List[float]:
     standard_deviation = _standard_deviation(numbers)
     weight = standard_deviation * 2
     weighted_numbers = [i * weight for i in numbers]
-    return weight, weighted_numbers
+    return weighted_numbers
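
A minimal sketch (not part of the patch) of the narrowing problem the new
cast() calls solve: under the MODEL alias every value has the union type
Union[WEIGHTS_T, CONFIG_T], so a checker cannot index model["__names__"] as a
list or treat model.get("__ngrams__", 1) as an int until the value is
narrowed. The aliases are copied from gptc/compiler.py above; the model
literal is hypothetical.

    from typing import Dict, List, Union, cast

    # Aliases as introduced in gptc/compiler.py by this patch.
    WEIGHTS_T = List[int]
    CONFIG_T = Union[List[str], int, str]
    MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]

    # Hypothetical compiled model using the config keys the patch relies on.
    model: MODEL = {
        "__version__": 3,
        "__names__": ["spam", "ham"],
        "__ngrams__": 1,
        "hello": [40000, 25535],
    }

    # To the type checker, model["__names__"] is the full union, so indexing
    # it is rejected; cast() narrows it, exactly as classifier.py now does.
    names = cast(List[str], model["__names__"])
    ngrams = cast(int, model.get("__ngrams__", 1))
    print(names[0], ngrams)  # -> spam 1

The subject line suggests the goal is a clean run of a static checker such as
mypy, though the checker itself is not named anywhere in the patch.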
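A short usage sketch of the API these annotations describe. Module paths and
signatures are taken from the diff above; the "category" key in the raw model
is an assumption, since only portion["text"] is visible in this patch.

    import gptc.compiler
    import gptc.classifier

    # Raw-model shape is partly assumed: the diff shows portion["text"],
    # but the "category" key is a guess not confirmed by this patch.
    raw_model = [
        {"category": "spam", "text": "win a free prize now"},
        {"category": "ham", "text": "meeting notes attached"},
    ]

    # compile() returns a MODEL; Classifier's __init__ accepts it per the
    # new annotation model: gptc.compiler.MODEL.
    model: gptc.compiler.MODEL = gptc.compiler.compile(raw_model)

    clf = gptc.classifier.Classifier(model, max_ngram_length=1)
    print(clf.confidence("free prize"))  # Dict[str, float] per the annotation
    print(clf.classify("free prize"))    # str, or None when no known words match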