diff --git a/gptc/model.py b/gptc/model.py
index f2d6760..1f2840d 100644
--- a/gptc/model.py
+++ b/gptc/model.py
@@ -70,7 +70,9 @@ class Model:
     def get(self, token: str) -> Dict[str, float]:
         try:
             weights = self.weights[
-                gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
+                gptc.tokenizer.hash_single(
+                    gptc.tokenizer.normalize(token), self.hash_algorithm
+                )
             ]
         except KeyError:
             return {}
diff --git a/gptc/tokenizer.py b/gptc/tokenizer.py
index bd75685..26fa205 100644
--- a/gptc/tokenizer.py
+++ b/gptc/tokenizer.py
@@ -45,13 +45,13 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
     return ngrams


-def hash_single(token: str, hash_function: Callable) -> int:
+def _hash_single(token: str, hash_function: Callable) -> int:
     return int.from_bytes(
         hash_function(token.encode("utf-8")).digest()[:6], "big"
     )


-def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
+def _get_hash_function(hash_algorithm: str) -> Callable:
     if hash_algorithm in {
         "sha224",
         "md5",
@@ -68,11 +68,19 @@ def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
         "shake_128",
         "sha3_384",
     }:
-        hash_function = getattr(hashlib, hash_algorithm)
-        return [hash_single(token, hash_function) for token in tokens]
+        return getattr(hashlib, hash_algorithm)
     else:
         raise ValueError("not a valid hash function: " + hash_algorithm)


+def hash_single(token: str, hash_algorithm: str) -> int:
+    return _hash_single(token, _get_hash_function(hash_algorithm))
+
+
+def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
+    hash_function = _get_hash_function(hash_algorithm)
+    return [_hash_single(token, hash_function) for token in tokens]
+
+
 def normalize(text: str) -> str:
     return " ".join(tokenize(text, 1))
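
For context, a minimal usage sketch of the reorganized tokenizer API after this patch. The sample text and the choice of "sha224" (one of the algorithm names visible in the accepted set) are only for illustration and are not part of the change:

    import gptc.tokenizer

    # hash() and hash_single() now take the algorithm name as a string;
    # resolving it to a hashlib constructor is handled by _get_hash_function().
    tokens = gptc.tokenizer.tokenize("an example sentence", max_ngram_length=1)
    hashes = gptc.tokenizer.hash(tokens, "sha224")  # one 48-bit int per token
    single = gptc.tokenizer.hash_single(
        gptc.tokenizer.normalize("example"), "sha224"
    )  # mirrors the lookup Model.get() now performs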