diff --git a/gptc/__init__.py b/gptc/__init__.py
index eaee24c..7d84238 100644
--- a/gptc/__init__.py
+++ b/gptc/__init__.py
@@ -3,7 +3,7 @@
 """General-Purpose Text Classifier"""
 
 from gptc.pack import pack
-from gptc.model import Model, deserialize
+from gptc.model import Model
 from gptc.tokenizer import normalize
 from gptc.exceptions import (
     GPTCError,
diff --git a/gptc/__main__.py b/gptc/__main__.py
index f035be8..3c3f441 100644
--- a/gptc/__main__.py
+++ b/gptc/__main__.py
@@ -68,7 +68,7 @@ def main() -> None:
         ).serialize(output_file)
     elif args.subparser_name == "classify":
         with open(args.model, "rb") as model_file:
-            model = gptc.deserialize(model_file)
+            model = gptc.Model.deserialize(model_file)
 
         if sys.stdin.isatty():
             text = input("Text to analyse: ")
@@ -78,7 +78,7 @@ def main() -> None:
         print(json.dumps(model.confidence(text, args.max_ngram_length)))
     elif args.subparser_name == "check":
         with open(args.model, "rb") as model_file:
-            model = gptc.deserialize(model_file)
+            model = gptc.Model.deserialize(model_file)
         print(json.dumps(model.get(args.token)))
     else:
         print(json.dumps(gptc.pack(args.model, True)[0]))
diff --git a/gptc/model.py b/gptc/model.py
index b3813b1..f7c5dba 100644
--- a/gptc/model.py
+++ b/gptc/model.py
@@ -269,54 +269,54 @@ class Model:
         model = _get_weights(min_count, word_counts, category_lengths, names)
         return Model(model, names, max_ngram_length, hash_algorithm)
 
-
-def deserialize(encoded_model: BinaryIO) -> Model:
-    prefix = encoded_model.read(14)
-    if prefix != b"GPTC model v6\n":
-        raise InvalidModelError()
-
-    config_json = b""
-    while True:
-        byte = encoded_model.read(1)
-        if byte == b"\n":
-            break
-
-        if byte == b"":
-            raise InvalidModelError()
-
-        config_json += byte
-
-    try:
-        config = json.loads(config_json.decode("utf-8"))
-    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
-        raise InvalidModelError() from exc
-
-    try:
-        names = config["names"]
-        max_ngram_length = config["max_ngram_length"]
-        hash_algorithm = config["hash_algorithm"]
-    except KeyError as exc:
-        raise InvalidModelError() from exc
-
-    if not (
-        isinstance(names, list) and isinstance(max_ngram_length, int)
-    ) or not all(isinstance(name, str) for name in names):
-        raise InvalidModelError()
-
-    weight_code_length = 6 + 2 * len(names)
-
-    weights: Dict[int, List[int]] = {}
-
-    while True:
-        code = encoded_model.read(weight_code_length)
-        if not code:
-            break
-        if len(code) != weight_code_length:
-            raise InvalidModelError()
-
-        weights[int.from_bytes(code[:6], "big")] = [
-            int.from_bytes(value, "big")
-            for value in [code[x : x + 2] for x in range(6, len(code), 2)]
-        ]
-
-    return Model(weights, names, max_ngram_length, hash_algorithm)
+    @staticmethod
+    def deserialize(encoded_model: BinaryIO) -> "Model":
+        prefix = encoded_model.read(14)
+        if prefix != b"GPTC model v6\n":
+            raise InvalidModelError()
+
+        config_json = b""
+        while True:
+            byte = encoded_model.read(1)
+            if byte == b"\n":
+                break
+
+            if byte == b"":
+                raise InvalidModelError()
+
+            config_json += byte
+
+        try:
+            config = json.loads(config_json.decode("utf-8"))
+        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
+            raise InvalidModelError() from exc
+
+        try:
+            names = config["names"]
+            max_ngram_length = config["max_ngram_length"]
+            hash_algorithm = config["hash_algorithm"]
+        except KeyError as exc:
+            raise InvalidModelError() from exc
+
+        if not (
+            isinstance(names, list) and isinstance(max_ngram_length, int)
+        ) or not all(isinstance(name, str) for name in names):
+            raise InvalidModelError()
+
+        weight_code_length = 6 + 2 * len(names)
+
+        weights: Dict[int, List[int]] = {}
+
+        while True:
+            code = encoded_model.read(weight_code_length)
+            if not code:
+                break
+            if len(code) != weight_code_length:
+                raise InvalidModelError()
+
+            weights[int.from_bytes(code[:6], "big")] = [
+                int.from_bytes(value, "big")
+                for value in [code[x : x + 2] for x in range(6, len(code), 2)]
+            ]
+
+        return Model(weights, names, max_ngram_length, hash_algorithm)