Fix check

Samuel Sloniker 2022-12-24 12:44:09 -08:00
parent 822aa7d1fd
commit 099e810a18
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
2 changed files with 15 additions and 5 deletions

gptc/model.py

@@ -70,7 +70,9 @@ class Model:
     def get(self, token: str) -> Dict[str, float]:
         try:
             weights = self.weights[
-                gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
+                gptc.tokenizer.hash_single(
+                    gptc.tokenizer.normalize(token), self.hash_algorithm
+                )
             ]
         except KeyError:
             return {}

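Before this commit, the call in Model.get passed only the normalized token, which no longer matched the two-argument signature of hash_single after the hash-algorithm refactor in the tokenizer (below); the fix threads self.hash_algorithm through so weight lookups hash tokens the same way the model was compiled. A minimal sketch of why the algorithm must match, reusing the _hash_single body from the tokenizer diff (the two algorithm names are just illustrative picks from the accepted set):

    import hashlib

    def _hash_single(token, hash_function):
        # 48-bit key: the first 6 bytes of the digest, big-endian.
        return int.from_bytes(
            hash_function(token.encode("utf-8")).digest()[:6], "big"
        )

    # The same token maps to different keys under different algorithms,
    # so a model compiled with one algorithm cannot be queried with another.
    print(_hash_single("hello", hashlib.sha224))
    print(_hash_single("hello", hashlib.md5))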
gptc/tokenizer.py

@@ -45,13 +45,13 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
     return ngrams


-def hash_single(token: str, hash_function: Callable) -> int:
+def _hash_single(token: str, hash_function: Callable) -> int:
     return int.from_bytes(
         hash_function(token.encode("utf-8")).digest()[:6], "big"
     )


-def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
+def _get_hash_function(hash_algorithm: str) -> Callable:
     if hash_algorithm in {
         "sha224",
         "md5",
@@ -68,11 +68,19 @@ def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
         "shake_128",
         "sha3_384",
     }:
-        hash_function = getattr(hashlib, hash_algorithm)
-        return [hash_single(token, hash_function) for token in tokens]
+        return getattr(hashlib, hash_algorithm)
     else:
         raise ValueError("not a valid hash function: " + hash_algorithm)


+def hash_single(token: str, hash_algorithm: str) -> int:
+    return _hash_single(token, _get_hash_function(hash_algorithm))
+
+
+def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
+    hash_function = _get_hash_function(hash_algorithm)
+    return [_hash_single(token, hash_function) for token in tokens]
+
+
 def normalize(text: str) -> str:
     return " ".join(tokenize(text, 1))
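With the refactor, the tokenizer exposes two public entry points keyed by algorithm name: hash_single for one token (as Model.get now uses) and hash for a batch, which resolves the hash function once via _get_hash_function rather than repeating the getattr lookup per token. A hedged usage sketch; the sample tokens and the choice of "sha224" (one of the names visible in the accepted set) are assumptions, not from the source:

    import gptc.tokenizer

    # Batch: one _get_hash_function lookup, reused for every token.
    keys = gptc.tokenizer.hash(["hello", "world"], "sha224")

    # Single token, normalized first, mirroring Model.get.
    key = gptc.tokenizer.hash_single(
        gptc.tokenizer.normalize("Hello, world!"), "sha224"
    )

    # An unrecognized name raises
    # ValueError("not a valid hash function: ...").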