From a252a15e9dcdcf2ce66f755a22521a5076504cd3 Mon Sep 17 00:00:00 2001 From: Samuel Sloniker Date: Mon, 17 Apr 2023 20:59:39 -0700 Subject: [PATCH] Clean up code --- gptc/__init__.py | 13 ++++---- gptc/__main__.py | 18 +++++----- gptc/compiler.py | 6 ++-- gptc/model.py | 84 +++++++++++++++++++---------------------------- gptc/pack.py | 20 ++++++----- gptc/tokenizer.py | 22 ++++++------- gptc/weighting.py | 8 ++--- 7 files changed, 78 insertions(+), 93 deletions(-) diff --git a/gptc/__init__.py b/gptc/__init__.py index 05f5dc5..eaee24c 100644 --- a/gptc/__init__.py +++ b/gptc/__init__.py @@ -2,12 +2,11 @@ """General-Purpose Text Classifier""" -from gptc.compiler import compile as compile -from gptc.pack import pack as pack -from gptc.model import Model as Model, deserialize as deserialize -from gptc.tokenizer import normalize as normalize +from gptc.pack import pack +from gptc.model import Model, deserialize +from gptc.tokenizer import normalize from gptc.exceptions import ( - GPTCError as GPTCError, - ModelError as ModelError, - InvalidModelError as InvalidModelError, + GPTCError, + ModelError, + InvalidModelError, ) diff --git a/gptc/__main__.py b/gptc/__main__.py index 9c87536..f035be8 100644 --- a/gptc/__main__.py +++ b/gptc/__main__.py @@ -59,16 +59,16 @@ def main() -> None: args = parser.parse_args() if args.subparser_name == "compile": - with open(args.model, "r") as f: - model = json.load(f) + with open(args.model, "r", encoding="utf-8") as input_file: + model = json.load(input_file) - with open(args.out, "wb+") as f: - gptc.compile( + with open(args.out, "wb+") as output_file: + gptc.Model.compile( model, args.max_ngram_length, args.min_count - ).serialize(f) + ).serialize(output_file) elif args.subparser_name == "classify": - with open(args.model, "rb") as f: - model = gptc.deserialize(f) + with open(args.model, "rb") as model_file: + model = gptc.deserialize(model_file) if sys.stdin.isatty(): text = input("Text to analyse: ") @@ -77,8 +77,8 @@ def main() -> None: print(json.dumps(model.confidence(text, args.max_ngram_length))) elif args.subparser_name == "check": - with open(args.model, "rb") as f: - model = gptc.deserialize(f) + with open(args.model, "rb") as model_file: + model = gptc.deserialize(model_file) print(json.dumps(model.get(args.token))) else: print(json.dumps(gptc.pack(args.model, True)[0])) diff --git a/gptc/compiler.py b/gptc/compiler.py index 130f79b..5a1166b 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later +from typing import Iterable, Mapping, List, Dict, Tuple import gptc.tokenizer import gptc.model -from typing import Iterable, Mapping, List, Dict, Union, Tuple def _count_words( @@ -15,7 +15,7 @@ def _count_words( names: List[str] = [] for portion in raw_model: - text = gptc.tokenizer.hash( + text = gptc.tokenizer.hash_list( gptc.tokenizer.tokenize(portion["text"], max_ngram_length), hash_algorithm, ) @@ -63,7 +63,7 @@ def _get_weights( return model -def compile( +def compile_( raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1, min_count: int = 1, diff --git a/gptc/model.py b/gptc/model.py index 8270b74..b17197c 100644 --- a/gptc/model.py +++ b/gptc/model.py @@ -1,22 +1,18 @@ # SPDX-License-Identifier: GPL-3.0-or-later -import gptc.tokenizer -from gptc.exceptions import InvalidModelError -import gptc.weighting from typing import ( - Iterable, - Mapping, List, Dict, - Union, cast, BinaryIO, - Any, Tuple, TypedDict, ) import json -import collections +import gptc.tokenizer +from gptc.exceptions import InvalidModelError +import gptc.weighting +import gptc.compiler class ExplanationEntry(TypedDict): @@ -33,6 +29,21 @@ Explanation = Dict[ Log = List[Tuple[str, float, List[float]]] +class Confidences(dict[str, float]): + def __init__(self, probs: Dict[str, float]): + dict.__init__(self, probs) + + +class TransparentConfidences(Confidences): + def __init__( + self, + probs: Dict[str, float], + explanation: Explanation, + ): + self.explanation = explanation + Confidences.__init__(self, probs) + + def convert_log(log: Log, names: List[str]) -> Explanation: explanation: Explanation = {} for word2, weight, word_probs in log: @@ -49,33 +60,6 @@ def convert_log(log: Log, names: List[str]) -> Explanation: return explanation -class Confidences(collections.UserDict[str, float]): - def __init__( - self, - probs: Dict[str, float], - model: Model, - text: str, - max_ngram_length: int, - ): - collections.UserDict.__init__(self, probs) - self.model = model - self.text = text - self.max_ngram_length = max_ngram_length - - -class TransparentConfidences(Confidences): - def __init__( - self, - probs: Dict[str, float], - explanation: Explanation, - model: Model, - text: str, - max_ngram_length: int, - ): - Confidences.__init__(self, probs, model, text, max_ngram_length) - self.explanation = explanation - - class Model: def __init__( self, @@ -117,7 +101,7 @@ class Model: text, min(max_ngram_length, self.max_ngram_length) ) - tokens = gptc.tokenizer.hash( + tokens = gptc.tokenizer.hash_list( raw_tokens, self.hash_algorithm, ) @@ -163,12 +147,9 @@ class Model: if transparent: explanation = convert_log(log, self.names) + return TransparentConfidences(probs, explanation) - return TransparentConfidences( - probs, explanation, self, text, max_ngram_length - ) - else: - return Confidences(probs, self, text, max_ngram_length) + return Confidences(probs) def get(self, token: str) -> Dict[str, float]: try: @@ -202,6 +183,8 @@ class Model: + b"".join([weight.to_bytes(2, "big") for weight in weights]) ) + compile = staticmethod(gptc.compiler.compile_) + def deserialize(encoded_model: BinaryIO) -> Model: prefix = encoded_model.read(14) @@ -213,26 +196,27 @@ def deserialize(encoded_model: BinaryIO) -> Model: byte = encoded_model.read(1) if byte == b"\n": break - elif byte == b"": + + if byte == b"": raise InvalidModelError() - else: - config_json += byte + + config_json += byte try: config = json.loads(config_json.decode("utf-8")) - except (UnicodeDecodeError, json.JSONDecodeError): - raise InvalidModelError() + except (UnicodeDecodeError, json.JSONDecodeError) as exc: + raise InvalidModelError() from exc try: names = config["names"] max_ngram_length = config["max_ngram_length"] hash_algorithm = config["hash_algorithm"] - except KeyError: - raise InvalidModelError() + except KeyError as exc: + raise InvalidModelError() from exc if not ( isinstance(names, list) and isinstance(max_ngram_length, int) - ) or not all([isinstance(name, str) for name in names]): + ) or not all(isinstance(name, str) for name in names): raise InvalidModelError() weight_code_length = 6 + 2 * len(names) @@ -243,7 +227,7 @@ def deserialize(encoded_model: BinaryIO) -> Model: code = encoded_model.read(weight_code_length) if not code: break - elif len(code) != weight_code_length: + if len(code) != weight_code_length: raise InvalidModelError() weights[int.from_bytes(code[:6], "big")] = [ diff --git a/gptc/pack.py b/gptc/pack.py index 22f3a9f..7ff12ce 100644 --- a/gptc/pack.py +++ b/gptc/pack.py @@ -7,7 +7,7 @@ from typing import List, Dict, Tuple def pack( directory: str, print_exceptions: bool = False -) -> Tuple[List[Dict[str, str]], List[Tuple[Exception]]]: +) -> Tuple[List[Dict[str, str]], List[Tuple[OSError]]]: paths = os.listdir(directory) texts: Dict[str, List[str]] = {} exceptions = [] @@ -17,16 +17,18 @@ def pack( try: for file in os.listdir(os.path.join(directory, path)): try: - with open(os.path.join(directory, path, file)) as f: - texts[path].append(f.read()) - except Exception as e: - exceptions.append((e,)) + with open( + os.path.join(directory, path, file), encoding="utf-8" + ) as input_file: + texts[path].append(input_file.read()) + except OSError as error: + exceptions.append((error,)) if print_exceptions: - print(e, file=sys.stderr) - except Exception as e: - exceptions.append((e,)) + print(error, file=sys.stderr) + except OSError as error: + exceptions.append((error,)) if print_exceptions: - print(e, file=sys.stderr) + print(error, file=sys.stderr) raw_model = [] diff --git a/gptc/tokenizer.py b/gptc/tokenizer.py index a24adf7..abc3287 100644 --- a/gptc/tokenizer.py +++ b/gptc/tokenizer.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import List, Union, Callable, Any, cast +import unicodedata +from typing import List, cast import hashlib import emoji -import unicodedata def tokenize(text: str, max_ngram_length: int = 1) -> List[str]: @@ -37,12 +37,12 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]: if max_ngram_length == 1: return tokens - else: - ngrams = [] - for ngram_length in range(1, max_ngram_length + 1): - for index in range(len(tokens) + 1 - ngram_length): - ngrams.append(" ".join(tokens[index : index + ngram_length])) - return ngrams + + ngrams = [] + for ngram_length in range(1, max_ngram_length + 1): + for index in range(len(tokens) + 1 - ngram_length): + ngrams.append(" ".join(tokens[index : index + ngram_length])) + return ngrams def _hash_single(token: str, hash_function: type) -> int: @@ -69,15 +69,15 @@ def _get_hash_function(hash_algorithm: str) -> type: "sha3_384", }: return cast(type, getattr(hashlib, hash_algorithm)) - else: - raise ValueError("not a valid hash function: " + hash_algorithm) + + raise ValueError("not a valid hash function: " + hash_algorithm) def hash_single(token: str, hash_algorithm: str) -> int: return _hash_single(token, _get_hash_function(hash_algorithm)) -def hash(tokens: List[str], hash_algorithm: str) -> List[int]: +def hash_list(tokens: List[str], hash_algorithm: str) -> List[int]: hash_function = _get_hash_function(hash_algorithm) return [_hash_single(token, hash_function) for token in tokens] diff --git a/gptc/weighting.py b/gptc/weighting.py index d033af1..4d93743 100755 --- a/gptc/weighting.py +++ b/gptc/weighting.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import math -from typing import Sequence, Union, Tuple, List +from typing import Sequence, Tuple, List def _mean(numbers: Sequence[float]) -> float: @@ -41,6 +41,6 @@ def _standard_deviation(numbers: Sequence[float]) -> float: def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]: standard_deviation = _standard_deviation(numbers) - weight = standard_deviation * 2 - weighted_numbers = [i * weight for i in numbers] - return weight, weighted_numbers + weight_assigned = standard_deviation * 2 + weighted_numbers = [i * weight_assigned for i in numbers] + return weight_assigned, weighted_numbers