Move deserialize to Model object

This commit is contained in:
Samuel Sloniker 2023-04-17 21:35:38 -07:00
parent 457b569741
commit 97c4eef086
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
3 changed files with 44 additions and 44 deletions

View File

@@ -3,7 +3,7 @@
"""General-Purpose Text Classifier""" """General-Purpose Text Classifier"""
from gptc.pack import pack from gptc.pack import pack
from gptc.model import Model, deserialize from gptc.model import Model
from gptc.tokenizer import normalize from gptc.tokenizer import normalize
from gptc.exceptions import ( from gptc.exceptions import (
GPTCError, GPTCError,

View File

@@ -68,7 +68,7 @@ def main() -> None:
).serialize(output_file) ).serialize(output_file)
elif args.subparser_name == "classify": elif args.subparser_name == "classify":
with open(args.model, "rb") as model_file: with open(args.model, "rb") as model_file:
model = gptc.deserialize(model_file) model = gptc.Model.deserialize(model_file)
if sys.stdin.isatty(): if sys.stdin.isatty():
text = input("Text to analyse: ") text = input("Text to analyse: ")
@@ -78,7 +78,7 @@ def main() -> None:
print(json.dumps(model.confidence(text, args.max_ngram_length))) print(json.dumps(model.confidence(text, args.max_ngram_length)))
elif args.subparser_name == "check": elif args.subparser_name == "check":
with open(args.model, "rb") as model_file: with open(args.model, "rb") as model_file:
model = gptc.deserialize(model_file) model = gptc.Model.deserialize(model_file)
print(json.dumps(model.get(args.token))) print(json.dumps(model.get(args.token)))
else: else:
print(json.dumps(gptc.pack(args.model, True)[0])) print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@@ -269,8 +269,8 @@ class Model:
model = _get_weights(min_count, word_counts, category_lengths, names) model = _get_weights(min_count, word_counts, category_lengths, names)
return Model(model, names, max_ngram_length, hash_algorithm) return Model(model, names, max_ngram_length, hash_algorithm)
@staticmethod
def deserialize(encoded_model: BinaryIO) -> Model: def deserialize(encoded_model: BinaryIO) -> "Model":
prefix = encoded_model.read(14) prefix = encoded_model.read(14)
if prefix != b"GPTC model v6\n": if prefix != b"GPTC model v6\n":
raise InvalidModelError() raise InvalidModelError()