Deserialize from file

This commit is contained in:
Samuel Sloniker 2022-12-23 10:49:24 -08:00
parent 9916744801
commit 74b2ba81b9
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
3 changed files with 23 additions and 19 deletions

View File

@ -49,7 +49,7 @@ Write binary data representing the model to `file`.
### `gptc.deserialize(encoded_model)`
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
Deserialize a `Model` from a file containing data from `Model.serialize()`.
### `Model.confidence(text, max_ngram_length)`

View File

@ -69,7 +69,7 @@ def main() -> None:
gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
elif args.subparser_name == "classify":
with open(args.model, "rb") as f:
model = gptc.deserialize(f.read())
model = gptc.deserialize(f)
if sys.stdin.isatty():
text = input("Text to analyse: ")
@ -85,7 +85,7 @@ def main() -> None:
print(json.dumps(probabilities))
elif args.subparser_name == "check":
with open(args.model, "rb") as f:
model = gptc.deserialize(f.read())
model = gptc.deserialize(f)
print(json.dumps(model.get(args.token)))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@ -104,14 +104,20 @@ class Model:
)
def deserialize(encoded_model: bytes) -> Model:
try:
prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2)
except ValueError:
def deserialize(encoded_model: BinaryIO) -> Model:
prefix = encoded_model.read(14)
if prefix != b"GPTC model v4\n":
raise InvalidModelError()
if prefix != b"GPTC model v4":
raise InvalidModelError()
config_json = b""
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
elif byte == b"":
raise InvalidModelError()
else:
config_json += byte
try:
config = json.loads(config_json.decode("utf-8"))
@ -131,20 +137,18 @@ def deserialize(encoded_model: bytes) -> Model:
weight_code_length = 6 + 2 * len(names)
if len(encoded_weights) % weight_code_length != 0:
raise InvalidModelError()
weights: Dict[int : List[int]] = {}
weight_codes = [
encoded_weights[x : x + weight_code_length]
for x in range(0, len(encoded_weights), weight_code_length)
]
while True:
code = encoded_model.read(weight_code_length)
if not code:
break
elif len(code) != weight_code_length:
raise InvalidModelError()
weights = {
int.from_bytes(code[:6], "big"): [
weights[int.from_bytes(code[:6], "big")] = [
int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
]
for code in weight_codes
}
return Model(weights, names, max_ngram_length)