Add type checks to all functions that need them

This commit is contained in:
Samuel Sloniker 2022-07-17 18:42:38 -07:00
parent 67ac3a4591
commit e711767d24
5 changed files with 16 additions and 25 deletions

View File

@ -2,6 +2,10 @@
"""General-Purpose Text Classifier""" """General-Purpose Text Classifier"""
from gptc.compiler import compile from gptc.compiler import compile as compile
from gptc.classifier import Classifier from gptc.classifier import Classifier as Classifier
from gptc.exceptions import * from gptc.exceptions import (
GPTCError as GPTCError,
ModelError as ModelError,
UnsupportedModelError as UnsupportedModelError,
)

View File

@ -13,9 +13,7 @@ def main() -> None:
) )
subparsers = parser.add_subparsers(dest="subparser_name", required=True) subparsers = parser.add_subparsers(dest="subparser_name", required=True)
compile_parser = subparsers.add_parser( compile_parser = subparsers.add_parser("compile", help="compile a raw model")
"compile", help="compile a raw model"
)
compile_parser.add_argument("model", help="raw model to compile") compile_parser.add_argument("model", help="raw model to compile")
compile_parser.add_argument( compile_parser.add_argument(
"--max-ngram-length", "--max-ngram-length",

View File

@ -27,14 +27,10 @@ class Classifier:
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1): def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
if model.get("__version__", 0) != 3: if model.get("__version__", 0) != 3:
raise gptc.exceptions.UnsupportedModelError( raise gptc.exceptions.UnsupportedModelError(f"unsupported model version")
f"unsupported model version"
)
self.model = model self.model = model
model_ngrams = cast(int, model.get("__ngrams__", 1)) model_ngrams = cast(int, model.get("__ngrams__", 1))
self.max_ngram_length = min( self.max_ngram_length = min(max_ngram_length, model_ngrams)
max_ngram_length, model_ngrams
)
def confidence(self, text: str) -> Dict[str, float]: def confidence(self, text: str) -> Dict[str, float]:
"""Classify text with confidence. """Classify text with confidence.

View File

@ -8,9 +8,7 @@ CONFIG_T = Union[List[str], int, str]
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]] MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
def compile( def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -> MODEL:
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
) -> MODEL:
"""Compile a raw model. """Compile a raw model.
Parameters Parameters
@ -49,13 +47,9 @@ def compile(
categories_by_count[category] = {} categories_by_count[category] = {}
for word in text: for word in text:
try: try:
categories_by_count[category][word] += 1 / len( categories_by_count[category][word] += 1 / len(categories[category])
categories[category]
)
except KeyError: except KeyError:
categories_by_count[category][word] = 1 / len( categories_by_count[category][word] = 1 / len(categories[category])
categories[category]
)
word_weights: Dict[str, Dict[str, float]] = {} word_weights: Dict[str, Dict[str, float]] = {}
for category, words in categories_by_count.items(): for category, words in categories_by_count.items():
for word, value in words.items(): for word, value in words.items():
@ -69,9 +63,7 @@ def compile(
total = sum(weights.values()) total = sum(weights.values())
new_weights: List[int] = [] new_weights: List[int] = []
for category in names: for category in names:
new_weights.append( new_weights.append(round((weights.get(category, 0) / total) * 65535))
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights model[word] = new_weights
model["__names__"] = names model["__names__"] = names

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
from typing import List, Union from typing import List, Union
try: try:
import emoji import emoji