Deserialize from file
This commit is contained in:
parent
9916744801
commit
74b2ba81b9
|
@ -49,7 +49,7 @@ Write binary data representing the model to `file`.
|
||||||
|
|
||||||
### `gptc.deserialize(encoded_model)`
|
### `gptc.deserialize(encoded_model)`
|
||||||
|
|
||||||
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
|
Deserialize a `Model` from a file containing data from `Model.serialize()`.
|
||||||
|
|
||||||
### `Model.confidence(text, max_ngram_length)`
|
### `Model.confidence(text, max_ngram_length)`
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ def main() -> None:
|
||||||
gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
|
gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
|
||||||
elif args.subparser_name == "classify":
|
elif args.subparser_name == "classify":
|
||||||
with open(args.model, "rb") as f:
|
with open(args.model, "rb") as f:
|
||||||
model = gptc.deserialize(f.read())
|
model = gptc.deserialize(f)
|
||||||
|
|
||||||
if sys.stdin.isatty():
|
if sys.stdin.isatty():
|
||||||
text = input("Text to analyse: ")
|
text = input("Text to analyse: ")
|
||||||
|
@ -85,7 +85,7 @@ def main() -> None:
|
||||||
print(json.dumps(probabilities))
|
print(json.dumps(probabilities))
|
||||||
elif args.subparser_name == "check":
|
elif args.subparser_name == "check":
|
||||||
with open(args.model, "rb") as f:
|
with open(args.model, "rb") as f:
|
||||||
model = gptc.deserialize(f.read())
|
model = gptc.deserialize(f)
|
||||||
print(json.dumps(model.get(args.token)))
|
print(json.dumps(model.get(args.token)))
|
||||||
else:
|
else:
|
||||||
print(json.dumps(gptc.pack(args.model, True)[0]))
|
print(json.dumps(gptc.pack(args.model, True)[0]))
|
||||||
|
|
|
@ -104,14 +104,20 @@ class Model:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def deserialize(encoded_model: bytes) -> Model:
|
def deserialize(encoded_model: BinaryIO) -> Model:
|
||||||
try:
|
prefix = encoded_model.read(14)
|
||||||
prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2)
|
if prefix != b"GPTC model v4\n":
|
||||||
except ValueError:
|
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
if prefix != b"GPTC model v4":
|
config_json = b""
|
||||||
|
while True:
|
||||||
|
byte = encoded_model.read(1)
|
||||||
|
if byte == b"\n":
|
||||||
|
break
|
||||||
|
elif byte == b"":
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
else:
|
||||||
|
config_json += byte
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = json.loads(config_json.decode("utf-8"))
|
config = json.loads(config_json.decode("utf-8"))
|
||||||
|
@ -131,20 +137,18 @@ def deserialize(encoded_model: bytes) -> Model:
|
||||||
|
|
||||||
weight_code_length = 6 + 2 * len(names)
|
weight_code_length = 6 + 2 * len(names)
|
||||||
|
|
||||||
if len(encoded_weights) % weight_code_length != 0:
|
weights: Dict[int : List[int]] = {}
|
||||||
|
|
||||||
|
while True:
|
||||||
|
code = encoded_model.read(weight_code_length)
|
||||||
|
if not code:
|
||||||
|
break
|
||||||
|
elif len(code) != weight_code_length:
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
weight_codes = [
|
weights[int.from_bytes(code[:6], "big")] = [
|
||||||
encoded_weights[x : x + weight_code_length]
|
|
||||||
for x in range(0, len(encoded_weights), weight_code_length)
|
|
||||||
]
|
|
||||||
|
|
||||||
weights = {
|
|
||||||
int.from_bytes(code[:6], "big"): [
|
|
||||||
int.from_bytes(value, "big")
|
int.from_bytes(value, "big")
|
||||||
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
||||||
]
|
]
|
||||||
for code in weight_codes
|
|
||||||
}
|
|
||||||
|
|
||||||
return Model(weights, names, max_ngram_length)
|
return Model(weights, names, max_ngram_length)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user