Compare commits


No commits in common. "071656c2d2498f188ab1c866c1a36c6a0b1a2f8d" and "822aa7d1fdd06da7475938d8b8edaf38a3ea84f0" have entirely different histories.

3 changed files with 9 additions and 19 deletions

gptc/model.py

@@ -70,9 +70,7 @@ class Model:
     def get(self, token: str) -> Dict[str, float]:
         try:
             weights = self.weights[
-                gptc.tokenizer.hash_single(
-                    gptc.tokenizer.normalize(token), self.hash_algorithm
-                )
+                gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
             ]
         except KeyError:
             return {}
@@ -81,7 +79,7 @@ class Model:
             for index, category in enumerate(self.names)
         }
 
-    def serialize(self, file: BinaryIO) -> None:
+    def serialize(self, file: BinaryIO):
         file.write(b"GPTC model v5\n")
         file.write(
             json.dumps(
@@ -134,7 +132,7 @@ def deserialize(encoded_model: BinaryIO) -> Model:
     weight_code_length = 6 + 2 * len(names)
 
-    weights: Dict[int, List[int]] = {}
+    weights: Dict[int : List[int]] = {}
 
     while True:
         code = encoded_model.read(weight_code_length)
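
A note on the records this loop reads: weight_code_length = 6 + 2 * len(names), together with the 6-byte token hashes produced by the tokenizer, suggests each record is a 6-byte token hash followed by one 2-byte weight per category. A minimal sketch under that assumption; parse_weight_code is a hypothetical helper (not part of gptc), and big-endian weights are an assumption:

from typing import List, Tuple

def parse_weight_code(code: bytes, n_categories: int) -> Tuple[int, List[int]]:
    # Assumed v5 record layout: 6-byte big-endian token hash,
    # then one 2-byte weight per category.
    assert len(code) == 6 + 2 * n_categories
    token_hash = int.from_bytes(code[:6], "big")
    weights = [
        int.from_bytes(code[6 + 2 * i : 8 + 2 * i], "big")
        for i in range(n_categories)
    ]
    return token_hash, weights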

gptc/tokenizer.py

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
-from typing import List, Union, Callable, Any, cast
+from typing import List, Union, Callable
 import hashlib
 import emoji
 import unicodedata
 
@@ -45,13 +45,13 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
     return ngrams
 
 
-def _hash_single(token: str, hash_function: type) -> int:
+def hash_single(token: str, hash_function: Callable) -> int:
     return int.from_bytes(
         hash_function(token.encode("utf-8")).digest()[:6], "big"
     )
 
 
-def _get_hash_function(hash_algorithm: str) -> type:
+def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
     if hash_algorithm in {
         "sha224",
         "md5",
@@ -68,19 +68,11 @@ def _get_hash_function(hash_algorithm: str) -> type:
         "shake_128",
         "sha3_384",
     }:
-        return cast(type, getattr(hashlib, hash_algorithm))
+        hash_function = getattr(hashlib, hash_algorithm)
+        return [hash_single(token, hash_function) for token in tokens]
     else:
         raise ValueError("not a valid hash function: " + hash_algorithm)
 
 
-def hash_single(token: str, hash_algorithm: str) -> int:
-    return _hash_single(token, _get_hash_function(hash_algorithm))
-
-
-def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
-    hash_function = _get_hash_function(hash_algorithm)
-    return [_hash_single(token, hash_function) for token in tokens]
-
-
 def normalize(text: str) -> str:
     return " ".join(tokenize(text, 1))

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "gptc"
-version = "4.0.1"
+version = "4.0.0"
 description = "General-purpose text classifier"
 readme = "README.md"
 authors = [{ name = "Samuel Sloniker", email = "sam@kj7rrv.com"}]
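
Only the version field changes in this file. As a quick check, with either commit installed as a normal package, the version can be read back through the standard library; the expected values here are taken from the diff, not verified output:

from importlib.metadata import version

print(version("gptc"))  # "4.0.0" on the right-hand commit, "4.0.1" on the left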