Move deserialize to Model object
This commit is contained in:
parent
457b569741
commit
97c4eef086
|
@ -3,7 +3,7 @@
|
|||
"""General-Purpose Text Classifier"""
|
||||
|
||||
from gptc.pack import pack
|
||||
from gptc.model import Model, deserialize
|
||||
from gptc.model import Model
|
||||
from gptc.tokenizer import normalize
|
||||
from gptc.exceptions import (
|
||||
GPTCError,
|
||||
|
|
|
@ -68,7 +68,7 @@ def main() -> None:
|
|||
).serialize(output_file)
|
||||
elif args.subparser_name == "classify":
|
||||
with open(args.model, "rb") as model_file:
|
||||
model = gptc.deserialize(model_file)
|
||||
model = gptc.Model.deserialize(model_file)
|
||||
|
||||
if sys.stdin.isatty():
|
||||
text = input("Text to analyse: ")
|
||||
|
@ -78,7 +78,7 @@ def main() -> None:
|
|||
print(json.dumps(model.confidence(text, args.max_ngram_length)))
|
||||
elif args.subparser_name == "check":
|
||||
with open(args.model, "rb") as model_file:
|
||||
model = gptc.deserialize(model_file)
|
||||
model = gptc.Model.deserialize(model_file)
|
||||
print(json.dumps(model.get(args.token)))
|
||||
else:
|
||||
print(json.dumps(gptc.pack(args.model, True)[0]))
|
||||
|
|
|
@ -269,8 +269,8 @@ class Model:
|
|||
model = _get_weights(min_count, word_counts, category_lengths, names)
|
||||
return Model(model, names, max_ngram_length, hash_algorithm)
|
||||
|
||||
|
||||
def deserialize(encoded_model: BinaryIO) -> Model:
|
||||
@staticmethod
|
||||
def deserialize(encoded_model: BinaryIO) -> "Model":
|
||||
prefix = encoded_model.read(14)
|
||||
if prefix != b"GPTC model v6\n":
|
||||
raise InvalidModelError()
|
||||
|
|
Loading…
Reference in New Issue
Block a user