Move deserialize to Model object
This commit is contained in:
parent
457b569741
commit
97c4eef086
|
@ -3,7 +3,7 @@
|
||||||
"""General-Purpose Text Classifier"""
|
"""General-Purpose Text Classifier"""
|
||||||
|
|
||||||
from gptc.pack import pack
|
from gptc.pack import pack
|
||||||
from gptc.model import Model, deserialize
|
from gptc.model import Model
|
||||||
from gptc.tokenizer import normalize
|
from gptc.tokenizer import normalize
|
||||||
from gptc.exceptions import (
|
from gptc.exceptions import (
|
||||||
GPTCError,
|
GPTCError,
|
||||||
|
|
|
@ -68,7 +68,7 @@ def main() -> None:
|
||||||
).serialize(output_file)
|
).serialize(output_file)
|
||||||
elif args.subparser_name == "classify":
|
elif args.subparser_name == "classify":
|
||||||
with open(args.model, "rb") as model_file:
|
with open(args.model, "rb") as model_file:
|
||||||
model = gptc.deserialize(model_file)
|
model = gptc.Model.deserialize(model_file)
|
||||||
|
|
||||||
if sys.stdin.isatty():
|
if sys.stdin.isatty():
|
||||||
text = input("Text to analyse: ")
|
text = input("Text to analyse: ")
|
||||||
|
@ -78,7 +78,7 @@ def main() -> None:
|
||||||
print(json.dumps(model.confidence(text, args.max_ngram_length)))
|
print(json.dumps(model.confidence(text, args.max_ngram_length)))
|
||||||
elif args.subparser_name == "check":
|
elif args.subparser_name == "check":
|
||||||
with open(args.model, "rb") as model_file:
|
with open(args.model, "rb") as model_file:
|
||||||
model = gptc.deserialize(model_file)
|
model = gptc.Model.deserialize(model_file)
|
||||||
print(json.dumps(model.get(args.token)))
|
print(json.dumps(model.get(args.token)))
|
||||||
else:
|
else:
|
||||||
print(json.dumps(gptc.pack(args.model, True)[0]))
|
print(json.dumps(gptc.pack(args.model, True)[0]))
|
||||||
|
|
|
@ -269,54 +269,54 @@ class Model:
|
||||||
model = _get_weights(min_count, word_counts, category_lengths, names)
|
model = _get_weights(min_count, word_counts, category_lengths, names)
|
||||||
return Model(model, names, max_ngram_length, hash_algorithm)
|
return Model(model, names, max_ngram_length, hash_algorithm)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def deserialize(encoded_model: BinaryIO) -> Model:
|
def deserialize(encoded_model: BinaryIO) -> "Model":
|
||||||
prefix = encoded_model.read(14)
|
prefix = encoded_model.read(14)
|
||||||
if prefix != b"GPTC model v6\n":
|
if prefix != b"GPTC model v6\n":
|
||||||
raise InvalidModelError()
|
|
||||||
|
|
||||||
config_json = b""
|
|
||||||
while True:
|
|
||||||
byte = encoded_model.read(1)
|
|
||||||
if byte == b"\n":
|
|
||||||
break
|
|
||||||
|
|
||||||
if byte == b"":
|
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
config_json += byte
|
config_json = b""
|
||||||
|
while True:
|
||||||
|
byte = encoded_model.read(1)
|
||||||
|
if byte == b"\n":
|
||||||
|
break
|
||||||
|
|
||||||
try:
|
if byte == b"":
|
||||||
config = json.loads(config_json.decode("utf-8"))
|
raise InvalidModelError()
|
||||||
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
|
||||||
raise InvalidModelError() from exc
|
|
||||||
|
|
||||||
try:
|
config_json += byte
|
||||||
names = config["names"]
|
|
||||||
max_ngram_length = config["max_ngram_length"]
|
|
||||||
hash_algorithm = config["hash_algorithm"]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise InvalidModelError() from exc
|
|
||||||
|
|
||||||
if not (
|
try:
|
||||||
isinstance(names, list) and isinstance(max_ngram_length, int)
|
config = json.loads(config_json.decode("utf-8"))
|
||||||
) or not all(isinstance(name, str) for name in names):
|
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
||||||
raise InvalidModelError()
|
raise InvalidModelError() from exc
|
||||||
|
|
||||||
weight_code_length = 6 + 2 * len(names)
|
try:
|
||||||
|
names = config["names"]
|
||||||
|
max_ngram_length = config["max_ngram_length"]
|
||||||
|
hash_algorithm = config["hash_algorithm"]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise InvalidModelError() from exc
|
||||||
|
|
||||||
weights: Dict[int, List[int]] = {}
|
if not (
|
||||||
|
isinstance(names, list) and isinstance(max_ngram_length, int)
|
||||||
while True:
|
) or not all(isinstance(name, str) for name in names):
|
||||||
code = encoded_model.read(weight_code_length)
|
|
||||||
if not code:
|
|
||||||
break
|
|
||||||
if len(code) != weight_code_length:
|
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
weights[int.from_bytes(code[:6], "big")] = [
|
weight_code_length = 6 + 2 * len(names)
|
||||||
int.from_bytes(value, "big")
|
|
||||||
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
|
||||||
]
|
|
||||||
|
|
||||||
return Model(weights, names, max_ngram_length, hash_algorithm)
|
weights: Dict[int, List[int]] = {}
|
||||||
|
|
||||||
|
while True:
|
||||||
|
code = encoded_model.read(weight_code_length)
|
||||||
|
if not code:
|
||||||
|
break
|
||||||
|
if len(code) != weight_code_length:
|
||||||
|
raise InvalidModelError()
|
||||||
|
|
||||||
|
weights[int.from_bytes(code[:6], "big")] = [
|
||||||
|
int.from_bytes(value, "big")
|
||||||
|
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
||||||
|
]
|
||||||
|
|
||||||
|
return Model(weights, names, max_ngram_length, hash_algorithm)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user