Fix check
This commit is contained in:
parent
822aa7d1fd
commit
099e810a18
|
@ -70,7 +70,9 @@ class Model:
|
|||
def get(self, token: str) -> Dict[str, float]:
|
||||
try:
|
||||
weights = self.weights[
|
||||
gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
|
||||
gptc.tokenizer.hash_single(
|
||||
gptc.tokenizer.normalize(token), self.hash_algorithm
|
||||
)
|
||||
]
|
||||
except KeyError:
|
||||
return {}
|
||||
|
|
|
@ -45,13 +45,13 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
|||
return ngrams
|
||||
|
||||
|
||||
def hash_single(token: str, hash_function: Callable) -> int:
|
||||
def _hash_single(token: str, hash_function: Callable) -> int:
    """Hash one token down to a small integer.

    The token is UTF-8 encoded and passed to ``hash_function``; the first
    six bytes of the resulting digest are read as a big-endian unsigned
    integer.

    Args:
        token: Text to hash.
        hash_function: Callable that accepts ``bytes`` and returns an object
            with a ``digest`` method (for example a ``hashlib`` constructor).

    Returns:
        The integer built from at most the first six digest bytes.
    """
    hasher = hash_function(token.encode("utf-8"))
    # SHAKE hashers have no fixed output size: their digest() requires an
    # explicit length and would raise TypeError if called without one.
    if str(getattr(hasher, "name", "")).startswith("shake"):
        digest = hasher.digest(6)
    else:
        digest = hasher.digest()[:6]
    return int.from_bytes(digest, "big")
|
||||
|
||||
|
||||
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
|
||||
def _get_hash_function(hash_algorithm: str) -> Callable:
|
||||
if hash_algorithm in {
|
||||
"sha224",
|
||||
"md5",
|
||||
|
@ -68,11 +68,19 @@ def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
|
|||
"shake_128",
|
||||
"sha3_384",
|
||||
}:
|
||||
hash_function = getattr(hashlib, hash_algorithm)
|
||||
return [hash_single(token, hash_function) for token in tokens]
|
||||
return getattr(hashlib, hash_algorithm)
|
||||
else:
|
||||
raise ValueError("not a valid hash function: " + hash_algorithm)
|
||||
|
||||
|
||||
def hash_single(token: str, hash_algorithm: str) -> int:
    """Hash a single token using the algorithm named by ``hash_algorithm``.

    Args:
        token: Text to hash.
        hash_algorithm: Name of the hash algorithm; it is resolved by
            ``_get_hash_function``.

    Returns:
        The integer produced by ``_hash_single`` for this token (a single
        value, not a list).
    """
    return _hash_single(token, _get_hash_function(hash_algorithm))
|
||||
|
||||
|
||||
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
    """Hash every token in ``tokens`` with the named algorithm.

    Args:
        tokens: Tokens to hash, in order.
        hash_algorithm: Name of the hash algorithm; it is resolved by
            ``_get_hash_function``.

    Returns:
        One integer per input token, in the same order.
    """
    # Resolve the algorithm once, before touching any token, so the lookup
    # happens even when ``tokens`` is empty.
    hasher = _get_hash_function(hash_algorithm)
    hashed: List[int] = []
    for item in tokens:
        hashed.append(_hash_single(item, hasher))
    return hashed
|
||||
|
||||
|
||||
def normalize(text: str) -> str:
    """Return ``text`` rebuilt from its tokens, separated by single spaces.

    The text is run through ``tokenize`` with a maximum n-gram length of 1
    and the resulting tokens are joined with ``" "``.
    """
    unigrams = tokenize(text, 1)
    return " ".join(unigrams)
|
||||
|
|
Loading…
Reference in New Issue
Block a user