@@ -1,22 +1,18 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import gptc.tokenizer
from gptc.exceptions import InvalidModelError
import gptc.weighting
from typing import (
    Iterable,
    Mapping,
    List,
    Dict,
    Union,
    cast,
    BinaryIO,
    Any,
    Tuple,
    TypedDict,
)
import json
import collections
import gptc.tokenizer
from gptc.exceptions import InvalidModelError
import gptc.weighting
import gptc.compiler


class ExplanationEntry(TypedDict):
@@ -33,6 +29,21 @@ Explanation = Dict[
Log = List[Tuple[str, float, List[float]]]


class Confidences(dict[str, float]):
    def __init__(self, probs: Dict[str, float]):
        dict.__init__(self, probs)


class TransparentConfidences(Confidences):
    def __init__(
        self,
        probs: Dict[str, float],
        explanation: Explanation,
    ):
        self.explanation = explanation
        Confidences.__init__(self, probs)


def convert_log(log: Log, names: List[str]) -> Explanation:
    explanation: Explanation = {}
    for word2, weight, word_probs in log:
@@ -49,33 +60,6 @@ def convert_log(log: Log, names: List[str]) -> Explanation:
    return explanation


class Confidences(collections.UserDict[str, float]):
    def __init__(
        self,
        probs: Dict[str, float],
        model: Model,
        text: str,
        max_ngram_length: int,
    ):
        collections.UserDict.__init__(self, probs)
        self.model = model
        self.text = text
        self.max_ngram_length = max_ngram_length


class TransparentConfidences(Confidences):
    def __init__(
        self,
        probs: Dict[str, float],
        explanation: Explanation,
        model: Model,
        text: str,
        max_ngram_length: int,
    ):
        Confidences.__init__(self, probs, model, text, max_ngram_length)
        self.explanation = explanation


class Model:
    def __init__(
        self,
@@ -117,7 +101,7 @@ class Model:
            text, min(max_ngram_length, self.max_ngram_length)
        )

        tokens = gptc.tokenizer.hash(
        tokens = gptc.tokenizer.hash_list(
            raw_tokens,
            self.hash_algorithm,
        )
@@ -163,12 +147,9 @@ class Model:
        if transparent:
            explanation = convert_log(log, self.names)
            return TransparentConfidences(probs, explanation)

            return TransparentConfidences(
                probs, explanation, self, text, max_ngram_length
            )
        else:
            return Confidences(probs, self, text, max_ngram_length)
        return Confidences(probs)

    def get(self, token: str) -> Dict[str, float]:
        try:
@@ -202,6 +183,8 @@ class Model:
+ b"".join([weight.to_bytes(2, "big") for weight in weights]) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
compile = staticmethod(gptc.compiler.compile_) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def deserialize(encoded_model: BinaryIO) -> Model: |
|
|
|
|
prefix = encoded_model.read(14) |
|
|
|
@@ -213,26 +196,27 @@ def deserialize(encoded_model: BinaryIO) -> Model:
        byte = encoded_model.read(1)
        if byte == b"\n":
            break
        elif byte == b"":

        if byte == b"":
            raise InvalidModelError()
        else:
            config_json += byte

        config_json += byte

    try:
        config = json.loads(config_json.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        raise InvalidModelError()
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        raise InvalidModelError() from exc

    try:
        names = config["names"]
        max_ngram_length = config["max_ngram_length"]
        hash_algorithm = config["hash_algorithm"]
    except KeyError:
        raise InvalidModelError()
    except KeyError as exc:
        raise InvalidModelError() from exc

    if not (
        isinstance(names, list) and isinstance(max_ngram_length, int)
    ) or not all([isinstance(name, str) for name in names]):
    ) or not all(isinstance(name, str) for name in names):
        raise InvalidModelError()

    weight_code_length = 6 + 2 * len(names)
|
|
|
|
|
        code = encoded_model.read(weight_code_length)
        if not code:
            break
        elif len(code) != weight_code_length:
        if len(code) != weight_code_length:
            raise InvalidModelError()

        weights[int.from_bytes(code[:6], "big")] = [