Compare commits
3 Commits
822aa7d1fd
...
071656c2d2
Author | SHA1 | Date | |
---|---|---|---|
071656c2d2 | |||
aad590636a | |||
099e810a18 |
|
@ -70,7 +70,9 @@ class Model:
|
||||||
def get(self, token: str) -> Dict[str, float]:
|
def get(self, token: str) -> Dict[str, float]:
|
||||||
try:
|
try:
|
||||||
weights = self.weights[
|
weights = self.weights[
|
||||||
gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
|
gptc.tokenizer.hash_single(
|
||||||
|
gptc.tokenizer.normalize(token), self.hash_algorithm
|
||||||
|
)
|
||||||
]
|
]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return {}
|
return {}
|
||||||
|
@ -79,7 +81,7 @@ class Model:
|
||||||
for index, category in enumerate(self.names)
|
for index, category in enumerate(self.names)
|
||||||
}
|
}
|
||||||
|
|
||||||
def serialize(self, file: BinaryIO):
|
def serialize(self, file: BinaryIO) -> None:
|
||||||
file.write(b"GPTC model v5\n")
|
file.write(b"GPTC model v5\n")
|
||||||
file.write(
|
file.write(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
|
@ -132,7 +134,7 @@ def deserialize(encoded_model: BinaryIO) -> Model:
|
||||||
|
|
||||||
weight_code_length = 6 + 2 * len(names)
|
weight_code_length = 6 + 2 * len(names)
|
||||||
|
|
||||||
weights: Dict[int : List[int]] = {}
|
weights: Dict[int, List[int]] = {}
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
code = encoded_model.read(weight_code_length)
|
code = encoded_model.read(weight_code_length)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
from typing import List, Union, Callable
|
from typing import List, Union, Callable, Any, cast
|
||||||
import hashlib
|
import hashlib
|
||||||
import emoji
|
import emoji
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
@ -45,13 +45,13 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
||||||
return ngrams
|
return ngrams
|
||||||
|
|
||||||
|
|
||||||
def hash_single(token: str, hash_function: Callable) -> int:
|
def _hash_single(token: str, hash_function: type) -> int:
|
||||||
return int.from_bytes(
|
return int.from_bytes(
|
||||||
hash_function(token.encode("utf-8")).digest()[:6], "big"
|
hash_function(token.encode("utf-8")).digest()[:6], "big"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
|
def _get_hash_function(hash_algorithm: str) -> type:
|
||||||
if hash_algorithm in {
|
if hash_algorithm in {
|
||||||
"sha224",
|
"sha224",
|
||||||
"md5",
|
"md5",
|
||||||
|
@ -68,11 +68,19 @@ def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
|
||||||
"shake_128",
|
"shake_128",
|
||||||
"sha3_384",
|
"sha3_384",
|
||||||
}:
|
}:
|
||||||
hash_function = getattr(hashlib, hash_algorithm)
|
return cast(type, getattr(hashlib, hash_algorithm))
|
||||||
return [hash_single(token, hash_function) for token in tokens]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("not a valid hash function: " + hash_algorithm)
|
raise ValueError("not a valid hash function: " + hash_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_single(token: str, hash_algorithm: str) -> int:
|
||||||
|
return _hash_single(token, _get_hash_function(hash_algorithm))
|
||||||
|
|
||||||
|
|
||||||
|
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
|
||||||
|
hash_function = _get_hash_function(hash_algorithm)
|
||||||
|
return [_hash_single(token, hash_function) for token in tokens]
|
||||||
|
|
||||||
|
|
||||||
def normalize(text: str) -> str:
|
def normalize(text: str) -> str:
|
||||||
return " ".join(tokenize(text, 1))
|
return " ".join(tokenize(text, 1))
|
||||||
|
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "gptc"
|
name = "gptc"
|
||||||
version = "4.0.0"
|
version = "4.0.1"
|
||||||
description = "General-purpose text classifier"
|
description = "General-purpose text classifier"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [{ name = "Samuel Sloniker", email = "sam@kj7rrv.com"}]
|
authors = [{ name = "Samuel Sloniker", email = "sam@kj7rrv.com"}]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user