diff --git a/gptc/model.py b/gptc/model.py index 1f2840d..3b04827 100644 --- a/gptc/model.py +++ b/gptc/model.py @@ -81,7 +81,7 @@ class Model: for index, category in enumerate(self.names) } - def serialize(self, file: BinaryIO): + def serialize(self, file: BinaryIO) -> None: file.write(b"GPTC model v5\n") file.write( json.dumps( @@ -134,7 +134,7 @@ def deserialize(encoded_model: BinaryIO) -> Model: weight_code_length = 6 + 2 * len(names) - weights: Dict[int : List[int]] = {} + weights: Dict[int, List[int]] = {} while True: code = encoded_model.read(weight_code_length) diff --git a/gptc/tokenizer.py b/gptc/tokenizer.py index 26fa205..337779d 100644 --- a/gptc/tokenizer.py +++ b/gptc/tokenizer.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import List, Union, Callable +from typing import List, Union, Callable, Any, cast import hashlib import emoji import unicodedata @@ -45,13 +45,13 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]: return ngrams -def _hash_single(token: str, hash_function: Callable) -> int: +def _hash_single(token: str, hash_function: type) -> int: return int.from_bytes( hash_function(token.encode("utf-8")).digest()[:6], "big" ) -def _get_hash_function(hash_algorithm: str) -> Callable: +def _get_hash_function(hash_algorithm: str) -> type: if hash_algorithm in { "sha224", "md5", @@ -68,12 +68,12 @@ def _get_hash_function(hash_algorithm: str) -> Callable: "shake_128", "sha3_384", }: - return getattr(hashlib, hash_algorithm) + return cast(type, getattr(hashlib, hash_algorithm)) else: raise ValueError("not a valid hash function: " + hash_algorithm) -def hash_single(token: str, hash_algorithm: str) -> List[int]: +def hash_single(token: str, hash_algorithm: str) -> int: return _hash_single(token, _get_hash_function(hash_algorithm))