Working type checks

This commit is contained in:
Samuel Sloniker 2022-07-17 18:22:19 -07:00
parent b36d8e6081
commit 67ac3a4591
5 changed files with 36 additions and 29 deletions

View File

@@ -7,7 +7,7 @@ import sys
import gptc import gptc
def main(): def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="General Purpose Text Classifier", prog="gptc" description="General Purpose Text Classifier", prog="gptc"
) )

View File

@@ -2,6 +2,7 @@
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
import warnings import warnings
from typing import Dict, Union, cast, List
class Classifier: class Classifier:
@@ -24,17 +25,18 @@ class Classifier:
""" """
def __init__(self, model, max_ngram_length=1): def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
if model.get("__version__", 0) != 3: if model.get("__version__", 0) != 3:
raise gptc.exceptions.UnsupportedModelError( raise gptc.exceptions.UnsupportedModelError(
f"unsupported model version" f"unsupported model version"
) )
self.model = model self.model = model
model_ngrams = cast(int, model.get("__ngrams__", 1))
self.max_ngram_length = min( self.max_ngram_length = min(
max_ngram_length, model.get("__ngrams__", 1) max_ngram_length, model_ngrams
) )
def confidence(self, text): def confidence(self, text: str) -> Dict[str, float]:
"""Classify text with confidence. """Classify text with confidence.
Parameters Parameters
@@ -52,29 +54,28 @@ class Classifier:
model = self.model model = self.model
text = gptc.tokenizer.tokenize(text, self.max_ngram_length) tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length)
probs = {} numbered_probs: Dict[int, float] = {}
for word in text: for word in tokens:
try: try:
weight, weighted_numbers = gptc.weighting.weight( weighted_numbers = gptc.weighting.weight(
[i / 65535 for i in model[word]] [i / 65535 for i in cast(List[float], model[word])]
) )
for category, value in enumerate(weighted_numbers): for category, value in enumerate(weighted_numbers):
try: try:
probs[category] += value numbered_probs[category] += value
except KeyError: except KeyError:
probs[category] = value numbered_probs[category] = value
except KeyError: except KeyError:
pass pass
probs = { total = sum(numbered_probs.values())
model["__names__"][category]: value probs: Dict[str, float] = {
for category, value in probs.items() cast(List[str], model["__names__"])[category]: value / total
for category, value in numbered_probs.items()
} }
total = sum(probs.values())
probs = {category: value / total for category, value in probs.items()}
return probs return probs
def classify(self, text): def classify(self, text: str) -> Union[str, None]:
"""Classify text. """Classify text.
Parameters Parameters
@@ -89,7 +90,7 @@ class Classifier:
category in the model were found. category in the model were found.
""" """
probs = self.confidence(text) probs: Dict[str, float] = self.confidence(text)
try: try:
return sorted(probs.items(), key=lambda x: x[1])[-1][0] return sorted(probs.items(), key=lambda x: x[1])[-1][0]
except IndexError: except IndexError:

View File

@@ -3,10 +3,14 @@
import gptc.tokenizer import gptc.tokenizer
from typing import Iterable, Mapping, List, Dict, Union from typing import Iterable, Mapping, List, Dict, Union
WEIGHTS_T = List[int]
CONFIG_T = Union[List[str], int, str]
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
def compile( def compile(
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1 raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
) -> Dict[str, Union[str, int, List[int], List[str]]]: ) -> MODEL:
"""Compile a raw model. """Compile a raw model.
Parameters Parameters
@@ -24,7 +28,7 @@ def compile(
""" """
categories: Dict[str, str] = {} categories: Dict[str, List[str]] = {}
for portion in raw_model: for portion in raw_model:
text = gptc.tokenizer.tokenize(portion["text"], max_ngram_length) text = gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
@@ -60,7 +64,7 @@ def compile(
except KeyError: except KeyError:
word_weights[word] = {category: value} word_weights[word] = {category: value}
model: Dict[str, Union[str, int, List[int], List[str]]] = {} model: MODEL = {}
for word, weights in word_weights.items(): for word, weights in word_weights.items():
total = sum(weights.values()) total = sum(weights.values())
new_weights: List[int] = [] new_weights: List[int] = []

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
from typing import List, Union
try: try:
import emoji import emoji
@@ -8,9 +9,9 @@ except ImportError:
has_emoji = False has_emoji = False
def tokenize(text, max_ngram_length=1): def tokenize(text: str, max_ngram_length: int=1) -> List[str]:
"""Convert a string to a list of lemmas.""" """Convert a string to a list of lemmas."""
text = text.lower() converted_text: Union[str, List[str]] = text.lower()
if has_emoji: if has_emoji:
parts = [] parts = []
@@ -20,11 +21,11 @@ def tokenize(text, max_ngram_length=1):
parts.append(emoji_part["emoji"]) parts.append(emoji_part["emoji"])
highest_end = emoji_part["match_end"] highest_end = emoji_part["match_end"]
parts += list(text[highest_end:]) parts += list(text[highest_end:])
text = [part for part in parts if part] converted_text = [part for part in parts if part]
tokens = [""] tokens = [""]
for char in text: for char in converted_text:
if char.isalpha() or char == "'": if char.isalpha() or char == "'":
tokens[-1] += char tokens[-1] += char
elif has_emoji and emoji.is_emoji(char): elif has_emoji and emoji.is_emoji(char):

View File

@@ -1,9 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
import math import math
from typing import Sequence, Union, Tuple, List
def _mean(numbers): def _mean(numbers: Sequence[float]) -> float:
"""Calculate the mean of a group of numbers """Calculate the mean of a group of numbers
Parameters Parameters
@@ -19,7 +20,7 @@ def _mean(numbers):
return sum(numbers) / len(numbers) return sum(numbers) / len(numbers)
def _standard_deviation(numbers): def _standard_deviation(numbers: Sequence[float]) -> float:
"""Calculate the standard deviation of a group of numbers """Calculate the standard deviation of a group of numbers
Parameters Parameters
@@ -38,8 +39,8 @@ def _standard_deviation(numbers):
return math.sqrt(_mean(squared_deviations)) return math.sqrt(_mean(squared_deviations))
def weight(numbers): def weight(numbers: Sequence[float]) -> List[float]:
standard_deviation = _standard_deviation(numbers) standard_deviation = _standard_deviation(numbers)
weight = standard_deviation * 2 weight = standard_deviation * 2
weighted_numbers = [i * weight for i in numbers] weighted_numbers = [i * weight for i in numbers]
return weight, weighted_numbers return weighted_numbers