diff --git a/gptc/__init__.py b/gptc/__init__.py index 3f08eb2..f4d36d5 100644 --- a/gptc/__init__.py +++ b/gptc/__init__.py @@ -2,6 +2,10 @@ """General-Purpose Text Classifier""" -from gptc.compiler import compile -from gptc.classifier import Classifier -from gptc.exceptions import * +from gptc.compiler import compile as compile +from gptc.classifier import Classifier as Classifier +from gptc.exceptions import ( + GPTCError as GPTCError, + ModelError as ModelError, + UnsupportedModelError as UnsupportedModelError, +) diff --git a/gptc/__main__.py b/gptc/__main__.py index ff9ed52..dc14d18 100644 --- a/gptc/__main__.py +++ b/gptc/__main__.py @@ -13,9 +13,7 @@ def main() -> None: ) subparsers = parser.add_subparsers(dest="subparser_name", required=True) - compile_parser = subparsers.add_parser( - "compile", help="compile a raw model" - ) + compile_parser = subparsers.add_parser("compile", help="compile a raw model") compile_parser.add_argument("model", help="raw model to compile") compile_parser.add_argument( "--max-ngram-length", diff --git a/gptc/classifier.py b/gptc/classifier.py index ac2fde7..2469ff3 100755 --- a/gptc/classifier.py +++ b/gptc/classifier.py @@ -27,14 +27,10 @@ class Classifier: def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1): if model.get("__version__", 0) != 3: - raise gptc.exceptions.UnsupportedModelError( - f"unsupported model version" - ) + raise gptc.exceptions.UnsupportedModelError(f"unsupported model version") self.model = model model_ngrams = cast(int, model.get("__ngrams__", 1)) - self.max_ngram_length = min( - max_ngram_length, model_ngrams - ) + self.max_ngram_length = min(max_ngram_length, model_ngrams) def confidence(self, text: str) -> Dict[str, float]: """Classify text with confidence. diff --git a/gptc/compiler.py b/gptc/compiler.py index 05b793b..fd4c12e 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -8,9 +8,7 @@ CONFIG_T = Union[List[str], int, str] MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]] -def compile( - raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1 -) -> MODEL: +def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -> MODEL: """Compile a raw model. Parameters @@ -49,13 +47,9 @@ def compile( categories_by_count[category] = {} for word in text: try: - categories_by_count[category][word] += 1 / len( - categories[category] - ) + categories_by_count[category][word] += 1 / len(categories[category]) except KeyError: - categories_by_count[category][word] = 1 / len( - categories[category] - ) + categories_by_count[category][word] = 1 / len(categories[category]) word_weights: Dict[str, Dict[str, float]] = {} for category, words in categories_by_count.items(): for word, value in words.items(): @@ -69,9 +63,7 @@ def compile( total = sum(weights.values()) new_weights: List[int] = [] for category in names: - new_weights.append( - round((weights.get(category, 0) / total) * 65535) - ) + new_weights.append(round((weights.get(category, 0) / total) * 65535)) model[word] = new_weights model["__names__"] = names diff --git a/gptc/tokenizer.py b/gptc/tokenizer.py index 25c97aa..7763e3c 100644 --- a/gptc/tokenizer.py +++ b/gptc/tokenizer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import List, Union + try: import emoji @@ -9,7 +10,7 @@ except ImportError: has_emoji = False -def tokenize(text: str, max_ngram_length: int=1) -> List[str]: +def tokenize(text: str, max_ngram_length: int = 1) -> List[str]: """Convert a string to a list of lemmas.""" converted_text: Union[str, List[str]] = text.lower()