Move deserialize to Model object

Samuel Sloniker 2023-04-17 21:35:38 -07:00
parent 457b569741
commit 97c4eef086
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
3 changed files with 44 additions and 44 deletions

@@ -3,7 +3,7 @@
 """General-Purpose Text Classifier"""
 from gptc.pack import pack
-from gptc.model import Model, deserialize
+from gptc.model import Model
 from gptc.tokenizer import normalize
 from gptc.exceptions import (
     GPTCError,

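With this change the module-level deserialize is no longer re-exported from the package; callers go through the Model class instead. A minimal before/after sketch of the caller-side migration, assuming a model file named model.gptc that was previously written with Model.serialize():

import gptc

with open("model.gptc", "rb") as model_file:  # model.gptc is an assumed example path
    # Before this commit: model = gptc.deserialize(model_file)
    model = gptc.Model.deserialize(model_file)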

@@ -68,7 +68,7 @@ def main() -> None:
         ).serialize(output_file)
     elif args.subparser_name == "classify":
         with open(args.model, "rb") as model_file:
-            model = gptc.deserialize(model_file)
+            model = gptc.Model.deserialize(model_file)

         if sys.stdin.isatty():
             text = input("Text to analyse: ")
@@ -78,7 +78,7 @@ def main() -> None:
         print(json.dumps(model.confidence(text, args.max_ngram_length)))
     elif args.subparser_name == "check":
         with open(args.model, "rb") as model_file:
-            model = gptc.deserialize(model_file)
+            model = gptc.Model.deserialize(model_file)
         print(json.dumps(model.get(args.token)))
     else:
         print(json.dumps(gptc.pack(args.model, True)[0]))

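For context, the two CLI branches above map directly onto library calls. A sketch of the equivalent programmatic usage with the new API, where model.gptc, the sample text, the max-ngram length of 1, and the token "example" are illustrative stand-ins for the argparse values:

import json

import gptc

with open("model.gptc", "rb") as model_file:
    model = gptc.Model.deserialize(model_file)

# Equivalent of the "classify" subcommand: per-category confidence scores.
print(json.dumps(model.confidence("Text to analyse", 1)))

# Equivalent of the "check" subcommand: look up a single token in the model.
print(json.dumps(model.get("example")))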

@@ -269,8 +269,8 @@ class Model:
         model = _get_weights(min_count, word_counts, category_lengths, names)
         return Model(model, names, max_ngram_length, hash_algorithm)

-def deserialize(encoded_model: BinaryIO) -> Model:
+    @staticmethod
+    def deserialize(encoded_model: BinaryIO) -> "Model":
         prefix = encoded_model.read(14)

         if prefix != b"GPTC model v6\n":
             raise InvalidModelError()
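The visible prefix check is unchanged; what changes is that the function is now reached through the class. A simplified, self-contained sketch of this staticmethod pattern (not the actual gptc implementation; the JSON payload and its fields below are hypothetical):

import json
from typing import BinaryIO


class InvalidModelError(Exception):
    """Raised when the input does not look like a GPTC model."""


class Model:
    def __init__(self, weights, names):
        self.weights = weights
        self.names = names

    @staticmethod
    def deserialize(encoded_model: BinaryIO) -> "Model":
        # Validate the 14-byte magic prefix before reading anything else.
        prefix = encoded_model.read(14)
        if prefix != b"GPTC model v6\n":
            raise InvalidModelError()
        # Hypothetical payload: the remainder of the stream is a JSON document.
        payload = json.loads(encoded_model.read().decode("utf-8"))
        return Model(payload["weights"], payload["names"])

The return annotation is written as the string "Model" rather than the bare name because the class is not yet bound while its own body is being defined; quoting it avoids a NameError at class-creation time.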