Add type checks to all functions that need them
This commit is contained in:
parent
67ac3a4591
commit
e711767d24
|
@ -2,6 +2,10 @@
|
|||
|
||||
"""General-Purpose Text Classifier"""
|
||||
|
||||
from gptc.compiler import compile
|
||||
from gptc.classifier import Classifier
|
||||
from gptc.exceptions import *
|
||||
from gptc.compiler import compile as compile
|
||||
from gptc.classifier import Classifier as Classifier
|
||||
from gptc.exceptions import (
|
||||
GPTCError as GPTCError,
|
||||
ModelError as ModelError,
|
||||
UnsupportedModelError as UnsupportedModelError,
|
||||
)
|
||||
|
|
|
@ -13,9 +13,7 @@ def main() -> None:
|
|||
)
|
||||
subparsers = parser.add_subparsers(dest="subparser_name", required=True)
|
||||
|
||||
compile_parser = subparsers.add_parser(
|
||||
"compile", help="compile a raw model"
|
||||
)
|
||||
compile_parser = subparsers.add_parser("compile", help="compile a raw model")
|
||||
compile_parser.add_argument("model", help="raw model to compile")
|
||||
compile_parser.add_argument(
|
||||
"--max-ngram-length",
|
||||
|
|
|
@ -27,14 +27,10 @@ class Classifier:
|
|||
|
||||
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
|
||||
if model.get("__version__", 0) != 3:
|
||||
raise gptc.exceptions.UnsupportedModelError(
|
||||
f"unsupported model version"
|
||||
)
|
||||
raise gptc.exceptions.UnsupportedModelError(f"unsupported model version")
|
||||
self.model = model
|
||||
model_ngrams = cast(int, model.get("__ngrams__", 1))
|
||||
self.max_ngram_length = min(
|
||||
max_ngram_length, model_ngrams
|
||||
)
|
||||
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
||||
|
||||
def confidence(self, text: str) -> Dict[str, float]:
|
||||
"""Classify text with confidence.
|
||||
|
|
|
@ -8,9 +8,7 @@ CONFIG_T = Union[List[str], int, str]
|
|||
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
|
||||
|
||||
|
||||
def compile(
|
||||
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
|
||||
) -> MODEL:
|
||||
def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -> MODEL:
|
||||
"""Compile a raw model.
|
||||
|
||||
Parameters
|
||||
|
@ -49,13 +47,9 @@ def compile(
|
|||
categories_by_count[category] = {}
|
||||
for word in text:
|
||||
try:
|
||||
categories_by_count[category][word] += 1 / len(
|
||||
categories[category]
|
||||
)
|
||||
categories_by_count[category][word] += 1 / len(categories[category])
|
||||
except KeyError:
|
||||
categories_by_count[category][word] = 1 / len(
|
||||
categories[category]
|
||||
)
|
||||
categories_by_count[category][word] = 1 / len(categories[category])
|
||||
word_weights: Dict[str, Dict[str, float]] = {}
|
||||
for category, words in categories_by_count.items():
|
||||
for word, value in words.items():
|
||||
|
@ -69,9 +63,7 @@ def compile(
|
|||
total = sum(weights.values())
|
||||
new_weights: List[int] = []
|
||||
for category in names:
|
||||
new_weights.append(
|
||||
round((weights.get(category, 0) / total) * 65535)
|
||||
)
|
||||
new_weights.append(round((weights.get(category, 0) / total) * 65535))
|
||||
model[word] = new_weights
|
||||
|
||||
model["__names__"] = names
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# SPDX-License-Identifier: LGPL-3.0-or-later
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
try:
|
||||
import emoji
|
||||
|
||||
|
@ -9,7 +10,7 @@ except ImportError:
|
|||
has_emoji = False
|
||||
|
||||
|
||||
def tokenize(text: str, max_ngram_length: int=1) -> List[str]:
|
||||
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
||||
"""Convert a string to a list of lemmas."""
|
||||
converted_text: Union[str, List[str]] = text.lower()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user