Fix type annotations

This commit is contained in:
Samuel Sloniker 2022-12-24 12:48:43 -08:00
parent 099e810a18
commit aad590636a
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
2 changed files with 7 additions and 7 deletions

View File

@ -81,7 +81,7 @@ class Model:
for index, category in enumerate(self.names) for index, category in enumerate(self.names)
} }
def serialize(self, file: BinaryIO): def serialize(self, file: BinaryIO) -> None:
file.write(b"GPTC model v5\n") file.write(b"GPTC model v5\n")
file.write( file.write(
json.dumps( json.dumps(
@ -134,7 +134,7 @@ def deserialize(encoded_model: BinaryIO) -> Model:
weight_code_length = 6 + 2 * len(names) weight_code_length = 6 + 2 * len(names)
weights: Dict[int : List[int]] = {} weights: Dict[int, List[int]] = {}
while True: while True:
code = encoded_model.read(weight_code_length) code = encoded_model.read(weight_code_length)

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
from typing import List, Union, Callable from typing import List, Union, Callable, Any, cast
import hashlib import hashlib
import emoji import emoji
import unicodedata import unicodedata
@ -45,13 +45,13 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
return ngrams return ngrams
def _hash_single(token: str, hash_function: Callable) -> int: def _hash_single(token: str, hash_function: type) -> int:
return int.from_bytes( return int.from_bytes(
hash_function(token.encode("utf-8")).digest()[:6], "big" hash_function(token.encode("utf-8")).digest()[:6], "big"
) )
def _get_hash_function(hash_algorithm: str) -> Callable: def _get_hash_function(hash_algorithm: str) -> type:
if hash_algorithm in { if hash_algorithm in {
"sha224", "sha224",
"md5", "md5",
@ -68,12 +68,12 @@ def _get_hash_function(hash_algorithm: str) -> Callable:
"shake_128", "shake_128",
"sha3_384", "sha3_384",
}: }:
return getattr(hashlib, hash_algorithm) return cast(type, getattr(hashlib, hash_algorithm))
else: else:
raise ValueError("not a valid hash function: " + hash_algorithm) raise ValueError("not a valid hash function: " + hash_algorithm)
def hash_single(token: str, hash_algorithm: str) -> List[int]: def hash_single(token: str, hash_algorithm: str) -> int:
return _hash_single(token, _get_hash_function(hash_algorithm)) return _hash_single(token, _get_hash_function(hash_algorithm))