Deserialize from file

This commit is contained in:
Samuel Sloniker 2022-12-23 10:49:24 -08:00
parent 9916744801
commit 74b2ba81b9
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
3 changed files with 23 additions and 19 deletions

View File

@@ -49,7 +49,7 @@ Write binary data representing the model to `file`.
### `gptc.deserialize(encoded_model)` ### `gptc.deserialize(encoded_model)`
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`. Deserialize a `Model` from a file containing data from `Model.serialize()`.
### `Model.confidence(text, max_ngram_length)` ### `Model.confidence(text, max_ngram_length)`

View File

@@ -69,7 +69,7 @@ def main() -> None:
gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer) gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
elif args.subparser_name == "classify": elif args.subparser_name == "classify":
with open(args.model, "rb") as f: with open(args.model, "rb") as f:
model = gptc.deserialize(f.read()) model = gptc.deserialize(f)
if sys.stdin.isatty(): if sys.stdin.isatty():
text = input("Text to analyse: ") text = input("Text to analyse: ")
@@ -85,7 +85,7 @@ def main() -> None:
print(json.dumps(probabilities)) print(json.dumps(probabilities))
elif args.subparser_name == "check": elif args.subparser_name == "check":
with open(args.model, "rb") as f: with open(args.model, "rb") as f:
model = gptc.deserialize(f.read()) model = gptc.deserialize(f)
print(json.dumps(model.get(args.token))) print(json.dumps(model.get(args.token)))
else: else:
print(json.dumps(gptc.pack(args.model, True)[0])) print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@@ -104,14 +104,20 @@ class Model:
) )
def deserialize(encoded_model: bytes) -> Model: def deserialize(encoded_model: BinaryIO) -> Model:
try: prefix = encoded_model.read(14)
prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2) if prefix != b"GPTC model v4\n":
except ValueError:
raise InvalidModelError() raise InvalidModelError()
if prefix != b"GPTC model v4": config_json = b""
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
elif byte == b"":
raise InvalidModelError() raise InvalidModelError()
else:
config_json += byte
try: try:
config = json.loads(config_json.decode("utf-8")) config = json.loads(config_json.decode("utf-8"))
@@ -131,20 +137,18 @@ def deserialize(encoded_model: bytes) -> Model:
weight_code_length = 6 + 2 * len(names) weight_code_length = 6 + 2 * len(names)
if len(encoded_weights) % weight_code_length != 0: weights: Dict[int : List[int]] = {}
while True:
code = encoded_model.read(weight_code_length)
if not code:
break
elif len(code) != weight_code_length:
raise InvalidModelError() raise InvalidModelError()
weight_codes = [ weights[int.from_bytes(code[:6], "big")] = [
encoded_weights[x : x + weight_code_length]
for x in range(0, len(encoded_weights), weight_code_length)
]
weights = {
int.from_bytes(code[:6], "big"): [
int.from_bytes(value, "big") int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)] for value in [code[x : x + 2] for x in range(6, len(code), 2)]
] ]
for code in weight_codes
}
return Model(weights, names, max_ngram_length) return Model(weights, names, max_ngram_length)