Allow hash algorithm selection

Closes #9
Samuel Sloniker 2022-12-24 11:18:05 -08:00
parent 6f21e0d4e9
commit f8dbc78b82
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
4 changed files with 56 additions and 10 deletions

README.md

@@ -69,7 +69,7 @@
 Return a confidence dict for the given token or ngram. This function is very
 similar to `Model.confidence()`, except it treats the input as a single token
 or ngram.

-### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
+### `gptc.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")`

 Compile a raw model (as a list, not JSON) and return the compiled model (as a
 `gptc.Model` object).
@@ -79,6 +79,26 @@
 For information about `max_ngram_length`, see section "Ngrams."

 Words or ngrams used less than `min_count` times throughout the input text are
 excluded from the model.

+The hash algorithm should normally be left as the default (which may change
+in a minor version update), but an application can override it if needed;
+see the usage sketch after this diff. Because the algorithm is stored in the
+model, changing it does not affect compatibility.
+
+The following algorithms are supported:
+
+* `md5`
+* `sha1`
+* `sha224`
+* `sha256`
+* `sha384`
+* `sha512`
+* `sha3_224`
+* `sha3_384`
+* `sha3_256`
+* `sha3_512`
+* `shake_128`
+* `shake_256`
+* `blake2b`
+* `blake2s`
+
 ### `gptc.pack(directory, print_exceptions=False)`

 Pack the model in `directory` and return a tuple of the format:
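
To illustrate the new parameter, a minimal usage sketch, assuming only what the documentation above states (the raw-model entries and category names here are invented):

```python
# Sketch: compiling a model with an explicit hash algorithm.
# The raw model follows the documented shape: a list of mappings
# with "category" and "text" keys.
import gptc

raw_model = [
    {"category": "positive", "text": "I love this"},
    {"category": "negative", "text": "I hate this"},
]

# hash_algorithm defaults to "sha256"; it is written out here only
# to show the new keyword argument.
model = gptc.compile(raw_model, max_ngram_length=1, hash_algorithm="sha256")

print(model.hash_algorithm)  # the chosen algorithm travels with the model
```

Since the algorithm is recorded in the model, a model compiled with `blake2b` loads and classifies the same way as one compiled with `sha256`.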

gptc/compiler.py

@@ -9,6 +9,7 @@ def compile(
     raw_model: Iterable[Mapping[str, str]],
     max_ngram_length: int = 1,
     min_count: int = 1,
+    hash_algorithm: str = "sha256",
 ) -> gptc.model.Model:
     """Compile a raw model.
@@ -33,7 +34,8 @@ def compile(
     for portion in raw_model:
         text = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
+            gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
+            hash_algorithm,
         )
         category = portion["category"]
@@ -68,4 +70,4 @@ def compile(
         )
         model[word] = new_weights

-    return gptc.model.Model(model, names, max_ngram_length)
+    return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)

gptc/model.py

@@ -13,10 +13,12 @@ class Model:
         weights: Dict[int, List[int]],
         names: List[str],
         max_ngram_length: int,
+        hash_algorithm: str,
     ):
         self.weights = weights
         self.names = names
         self.max_ngram_length = max_ngram_length
+        self.hash_algorithm = hash_algorithm

     def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
         """Classify text with confidence.
@@ -42,7 +44,8 @@ class Model:
         tokens = gptc.tokenizer.hash(
             gptc.tokenizer.tokenize(
                 text, min(max_ngram_length, self.max_ngram_length)
-            )
+            ),
+            self.hash_algorithm,
         )
         numbered_probs: Dict[int, float] = {}
         for word in tokens:
@@ -83,6 +86,7 @@ class Model:
                 {
                     "names": self.names,
                     "max_ngram_length": self.max_ngram_length,
+                    "hash_algorithm": self.hash_algorithm,
                 }
             ).encode("utf-8")
             + b"\n"
@@ -117,6 +121,7 @@ def deserialize(encoded_model: BinaryIO) -> Model:
     try:
         names = config["names"]
         max_ngram_length = config["max_ngram_length"]
+        hash_algorithm = config["hash_algorithm"]
     except KeyError:
         raise InvalidModelError()
@@ -141,4 +146,4 @@ def deserialize(encoded_model: BinaryIO) -> Model:
             for value in [code[x : x + 2] for x in range(6, len(code), 2)]
         ]

-    return Model(weights, names, max_ngram_length)
+    return Model(weights, names, max_ngram_length, hash_algorithm)
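
To make the serialization change concrete, a small sketch of the header round-trip implied by the `serialize` and `deserialize` hunks above (an assumption-laden reconstruction: only the one-line JSON header terminated by `b"\n"` is shown, not the binary weight records that follow it):

```python
import json
from io import BytesIO

# Header shape implied by the hunks above: a single JSON object on
# one line, then b"\n", then the packed weights.
config = {
    "names": ["positive", "negative"],
    "max_ngram_length": 1,
    "hash_algorithm": "sha256",
}
header = json.dumps(config).encode("utf-8") + b"\n"

# deserialize() reads the line back and requires all three keys,
# raising InvalidModelError if any is missing.
loaded = json.loads(BytesIO(header).readline())
assert loaded["hash_algorithm"] == "sha256"
```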

gptc/tokenizer.py

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

-from typing import List, Union
+from typing import List, Union, Callable
 import hashlib
 import emoji
 import unicodedata
@@ -45,14 +45,33 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
     return ngrams


-def hash_single(token: str) -> int:
+def hash_single(token: str, hash_function: Callable) -> int:
     return int.from_bytes(
-        hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
+        hash_function(token.encode("utf-8")).digest()[:6], "big"
     )


-def hash(tokens: List[str]) -> List[int]:
-    return [hash_single(token) for token in tokens]
+def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
+    if hash_algorithm in {
+        "sha224",
+        "md5",
+        "sha512",
+        "sha3_256",
+        "blake2s",
+        "sha3_224",
+        "sha1",
+        "sha256",
+        "sha384",
+        "shake_256",
+        "blake2b",
+        "sha3_512",
+        "shake_128",
+        "sha3_384",
+    }:
+        hash_function = getattr(hashlib, hash_algorithm)
+        return [hash_single(token, hash_function) for token in tokens]
+    else:
+        raise ValueError("not a valid hash function: " + hash_algorithm)


 def normalize(text: str) -> str:
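
One caveat about the code above (an observation, not part of this commit): `hashlib`'s `shake_128` and `shake_256` require an explicit length argument to `digest()`, so `hash_single()` as written raises `TypeError` for those two algorithms even though `hash()` accepts them. A hypothetical variant that handles both kinds:

```python
import hashlib
from typing import Callable


def hash_single_sketch(token: str, hash_function: Callable) -> int:
    # Same idea as hash_single(): the first six digest bytes become
    # a 48-bit integer token ID.
    hash_object = hash_function(token.encode("utf-8"))
    try:
        digest = hash_object.digest()
    except TypeError:
        # shake_128/shake_256 take an explicit output length
        digest = hash_object.digest(6)
    return int.from_bytes(digest[:6], "big")


print(hash_single_sketch("hello", hashlib.sha256))     # fixed-length digest
print(hash_single_sketch("hello", hashlib.shake_128))  # variable-length digest
```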