Move deserialize to Model object

This commit is contained in:
Samuel Sloniker 2023-04-17 21:35:38 -07:00
parent 457b569741
commit 97c4eef086
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
3 changed files with 44 additions and 44 deletions

View File

@@ -3,7 +3,7 @@
"""General-Purpose Text Classifier"""
from gptc.pack import pack
from gptc.model import Model, deserialize
from gptc.model import Model
from gptc.tokenizer import normalize
from gptc.exceptions import (
GPTCError,

View File

@@ -68,7 +68,7 @@ def main() -> None:
).serialize(output_file)
elif args.subparser_name == "classify":
with open(args.model, "rb") as model_file:
model = gptc.deserialize(model_file)
model = gptc.Model.deserialize(model_file)
if sys.stdin.isatty():
text = input("Text to analyse: ")
@@ -78,7 +78,7 @@ def main() -> None:
print(json.dumps(model.confidence(text, args.max_ngram_length)))
elif args.subparser_name == "check":
with open(args.model, "rb") as model_file:
model = gptc.deserialize(model_file)
model = gptc.Model.deserialize(model_file)
print(json.dumps(model.get(args.token)))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@@ -269,54 +269,54 @@ class Model:
model = _get_weights(min_count, word_counts, category_lengths, names)
return Model(model, names, max_ngram_length, hash_algorithm)
def deserialize(encoded_model: BinaryIO) -> Model:
prefix = encoded_model.read(14)
if prefix != b"GPTC model v6\n":
raise InvalidModelError()
config_json = b""
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
if byte == b"":
@staticmethod
def deserialize(encoded_model: BinaryIO) -> "Model":
prefix = encoded_model.read(14)
if prefix != b"GPTC model v6\n":
raise InvalidModelError()
config_json += byte
config_json = b""
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
try:
config = json.loads(config_json.decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
raise InvalidModelError() from exc
if byte == b"":
raise InvalidModelError()
try:
names = config["names"]
max_ngram_length = config["max_ngram_length"]
hash_algorithm = config["hash_algorithm"]
except KeyError as exc:
raise InvalidModelError() from exc
config_json += byte
if not (
isinstance(names, list) and isinstance(max_ngram_length, int)
) or not all(isinstance(name, str) for name in names):
raise InvalidModelError()
try:
config = json.loads(config_json.decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
raise InvalidModelError() from exc
weight_code_length = 6 + 2 * len(names)
try:
names = config["names"]
max_ngram_length = config["max_ngram_length"]
hash_algorithm = config["hash_algorithm"]
except KeyError as exc:
raise InvalidModelError() from exc
weights: Dict[int, List[int]] = {}
while True:
code = encoded_model.read(weight_code_length)
if not code:
break
if len(code) != weight_code_length:
if not (
isinstance(names, list) and isinstance(max_ngram_length, int)
) or not all(isinstance(name, str) for name in names):
raise InvalidModelError()
weights[int.from_bytes(code[:6], "big")] = [
int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
]
weight_code_length = 6 + 2 * len(names)
return Model(weights, names, max_ngram_length, hash_algorithm)
weights: Dict[int, List[int]] = {}
while True:
code = encoded_model.read(weight_code_length)
if not code:
break
if len(code) != weight_code_length:
raise InvalidModelError()
weights[int.from_bytes(code[:6], "big")] = [
int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
]
return Model(weights, names, max_ngram_length, hash_algorithm)