Move deserialize to Model object

This commit is contained in:
Samuel Sloniker 2023-04-17 21:35:38 -07:00
parent 457b569741
commit 97c4eef086
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
3 changed files with 44 additions and 44 deletions

View File

@@ -3,7 +3,7 @@
"""General-Purpose Text Classifier""" """General-Purpose Text Classifier"""
from gptc.pack import pack from gptc.pack import pack
from gptc.model import Model, deserialize from gptc.model import Model
from gptc.tokenizer import normalize from gptc.tokenizer import normalize
from gptc.exceptions import ( from gptc.exceptions import (
GPTCError, GPTCError,

View File

@@ -68,7 +68,7 @@ def main() -> None:
).serialize(output_file) ).serialize(output_file)
elif args.subparser_name == "classify": elif args.subparser_name == "classify":
with open(args.model, "rb") as model_file: with open(args.model, "rb") as model_file:
model = gptc.deserialize(model_file) model = gptc.Model.deserialize(model_file)
if sys.stdin.isatty(): if sys.stdin.isatty():
text = input("Text to analyse: ") text = input("Text to analyse: ")
@@ -78,7 +78,7 @@ def main() -> None:
print(json.dumps(model.confidence(text, args.max_ngram_length))) print(json.dumps(model.confidence(text, args.max_ngram_length)))
elif args.subparser_name == "check": elif args.subparser_name == "check":
with open(args.model, "rb") as model_file: with open(args.model, "rb") as model_file:
model = gptc.deserialize(model_file) model = gptc.Model.deserialize(model_file)
print(json.dumps(model.get(args.token))) print(json.dumps(model.get(args.token)))
else: else:
print(json.dumps(gptc.pack(args.model, True)[0])) print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@@ -269,54 +269,54 @@ class Model:
model = _get_weights(min_count, word_counts, category_lengths, names) model = _get_weights(min_count, word_counts, category_lengths, names)
return Model(model, names, max_ngram_length, hash_algorithm) return Model(model, names, max_ngram_length, hash_algorithm)
@staticmethod
def deserialize(encoded_model: BinaryIO) -> Model: def deserialize(encoded_model: BinaryIO) -> "Model":
prefix = encoded_model.read(14) prefix = encoded_model.read(14)
if prefix != b"GPTC model v6\n": if prefix != b"GPTC model v6\n":
raise InvalidModelError()
config_json = b""
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
if byte == b"":
raise InvalidModelError() raise InvalidModelError()
config_json += byte config_json = b""
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
try: if byte == b"":
config = json.loads(config_json.decode("utf-8")) raise InvalidModelError()
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
raise InvalidModelError() from exc
try: config_json += byte
names = config["names"]
max_ngram_length = config["max_ngram_length"]
hash_algorithm = config["hash_algorithm"]
except KeyError as exc:
raise InvalidModelError() from exc
if not ( try:
isinstance(names, list) and isinstance(max_ngram_length, int) config = json.loads(config_json.decode("utf-8"))
) or not all(isinstance(name, str) for name in names): except (UnicodeDecodeError, json.JSONDecodeError) as exc:
raise InvalidModelError() raise InvalidModelError() from exc
weight_code_length = 6 + 2 * len(names) try:
names = config["names"]
max_ngram_length = config["max_ngram_length"]
hash_algorithm = config["hash_algorithm"]
except KeyError as exc:
raise InvalidModelError() from exc
weights: Dict[int, List[int]] = {} if not (
isinstance(names, list) and isinstance(max_ngram_length, int)
while True: ) or not all(isinstance(name, str) for name in names):
code = encoded_model.read(weight_code_length)
if not code:
break
if len(code) != weight_code_length:
raise InvalidModelError() raise InvalidModelError()
weights[int.from_bytes(code[:6], "big")] = [ weight_code_length = 6 + 2 * len(names)
int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
]
return Model(weights, names, max_ngram_length, hash_algorithm) weights: Dict[int, List[int]] = {}
while True:
code = encoded_model.read(weight_code_length)
if not code:
break
if len(code) != weight_code_length:
raise InvalidModelError()
weights[int.from_bytes(code[:6], "big")] = [
int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
]
return Model(weights, names, max_ngram_length, hash_algorithm)