From 74b2ba81b9fe71b4e3bad11a7a4514528f8b897b Mon Sep 17 00:00:00 2001
From: Samuel Sloniker
Date: Fri, 23 Dec 2022 10:49:24 -0800
Subject: [PATCH] Deserialize from file

---
 README.md        |  2 +-
 gptc/__main__.py |  4 ++--
 gptc/model.py    | 36 ++++++++++++++++++++----------------
 3 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 51ab811..d25779f 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ Write binary data representing the model to `file`.
 
 ### `gptc.deserialize(encoded_model)`
 
-Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
+Deserialize a `Model` from a file containing data from `Model.serialize()`.
 
 ### `Model.confidence(text, max_ngram_length)`
 
diff --git a/gptc/__main__.py b/gptc/__main__.py
index eaa3143..c55bacb 100644
--- a/gptc/__main__.py
+++ b/gptc/__main__.py
@@ -69,7 +69,7 @@ def main() -> None:
         gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
     elif args.subparser_name == "classify":
         with open(args.model, "rb") as f:
-            model = gptc.deserialize(f.read())
+            model = gptc.deserialize(f)
 
         if sys.stdin.isatty():
             text = input("Text to analyse: ")
@@ -85,7 +85,7 @@
         print(json.dumps(probabilities))
     elif args.subparser_name == "check":
         with open(args.model, "rb") as f:
-            model = gptc.deserialize(f.read())
+            model = gptc.deserialize(f)
         print(json.dumps(model.get(args.token)))
     else:
         print(json.dumps(gptc.pack(args.model, True)[0]))
diff --git a/gptc/model.py b/gptc/model.py
index 53bcf01..faedaef 100644
--- a/gptc/model.py
+++ b/gptc/model.py
@@ -104,14 +104,20 @@ class Model:
         )
 
 
-def deserialize(encoded_model: bytes) -> Model:
-    try:
-        prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2)
-    except ValueError:
+def deserialize(encoded_model: BinaryIO) -> Model:
+    prefix = encoded_model.read(14)
+    if prefix != b"GPTC model v4\n":
         raise InvalidModelError()
 
-    if prefix != b"GPTC model v4":
-        raise InvalidModelError()
+    config_json = b""
+    while True:
+        byte = encoded_model.read(1)
+        if byte == b"\n":
+            break
+        elif byte == b"":
+            raise InvalidModelError()
+        else:
+            config_json += byte
 
     try:
         config = json.loads(config_json.decode("utf-8"))
@@ -131,20 +137,18 @@
 
     weight_code_length = 6 + 2 * len(names)
 
-    if len(encoded_weights) % weight_code_length != 0:
-        raise InvalidModelError()
+    weights: Dict[int, List[int]] = {}
 
-    weight_codes = [
-        encoded_weights[x : x + weight_code_length]
-        for x in range(0, len(encoded_weights), weight_code_length)
-    ]
+    while True:
+        code = encoded_model.read(weight_code_length)
+        if not code:
+            break
+        elif len(code) != weight_code_length:
+            raise InvalidModelError()
 
-    weights = {
-        int.from_bytes(code[:6], "big"): [
+        weights[int.from_bytes(code[:6], "big")] = [
             int.from_bytes(value, "big")
             for value in [code[x : x + 2] for x in range(6, len(code), 2)]
         ]
-        for code in weight_codes
-    }
 
     return Model(weights, names, max_ngram_length)
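
A minimal usage sketch of the changed API, not part of the patch: it round-trips a compiled model through a file and loads it with the new file-object deserializer. The raw-model structure and the `example.gptc` file name are illustrative assumptions; the `gptc.compile()`, `Model.serialize()`, `gptc.deserialize()`, and `Model.confidence()` calls follow the signatures visible in the diff above.

import gptc

# Assumed raw-model layout (a list of {"text", "category"} dicts); adjust to
# the format documented in README.md.
raw_model = [
    {"text": "a great wonderful product", "category": "good"},
    {"text": "a terrible awful product", "category": "bad"},
]

compiled = gptc.compile(raw_model, 1, 1)  # max_ngram_length=1, min_count=1

with open("example.gptc", "wb") as f:
    compiled.serialize(f)  # Model.serialize(file) writes the binary model

with open("example.gptc", "rb") as f:
    model = gptc.deserialize(f)  # after this patch: pass the file, not bytes

print(model.confidence("a wonderful product", 1))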