From f8dbc78b8231ae551f6ec61514644518190e5e21 Mon Sep 17 00:00:00 2001 From: Samuel Sloniker Date: Sat, 24 Dec 2022 11:18:05 -0800 Subject: [PATCH] Allow hash algorithm selection Closes #9 --- README.md | 22 +++++++++++++++++++++- gptc/compiler.py | 6 ++++-- gptc/model.py | 9 +++++++-- gptc/tokenizer.py | 29 ++++++++++++++++++++++++----- 4 files changed, 56 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index d25779f..6e1a5a9 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ Return a confidence dict for the given token or ngram. This function is very similar to `Model.confidence()`, except it treats the input as a single token or ngram. -### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)` +### `gptc.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")` Compile a raw model (as a list, not JSON) and return the compiled model (as a `gptc.Model` object). @@ -79,6 +79,26 @@ For information about `max_ngram_length`, see section "Ngrams." Words or ngrams used less than `min_count` times throughout the input text are excluded from the model. +The hash algorithm should be left as the default, which may change with a minor +version update, but it can be changed by the application if needed. It is +stored in the model, so changing the algorithm does not affect compatibility. +The following algorithms are supported: + +* `md5` +* `sha1` +* `sha224` +* `sha256` +* `sha384` +* `sha512` +* `sha3_224` +* `sha3_384` +* `sha3_256` +* `sha3_512` +* `shake_128` +* `shake_256` +* `blake2b` +* `blake2s` + ### `gptc.pack(directory, print_exceptions=False)` Pack the model in `directory` and return a tuple of the format: diff --git a/gptc/compiler.py b/gptc/compiler.py index 7eedb15..c299a4b 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -9,6 +9,7 @@ def compile( raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1, min_count: int = 1, + hash_algorithm: str = "sha256", ) -> gptc.model.Model: """Compile a raw model. @@ -33,7 +34,8 @@ def compile( for portion in raw_model: text = gptc.tokenizer.hash( - gptc.tokenizer.tokenize(portion["text"], max_ngram_length) + gptc.tokenizer.tokenize(portion["text"], max_ngram_length), + hash_algorithm, ) category = portion["category"] @@ -68,4 +70,4 @@ def compile( ) model[word] = new_weights - return gptc.model.Model(model, names, max_ngram_length) + return gptc.model.Model(model, names, max_ngram_length, hash_algorithm) diff --git a/gptc/model.py b/gptc/model.py index aa00f5a..f2d6760 100644 --- a/gptc/model.py +++ b/gptc/model.py @@ -13,10 +13,12 @@ class Model: weights: Dict[int, List[int]], names: List[str], max_ngram_length: int, + hash_algorithm: str, ): self.weights = weights self.names = names self.max_ngram_length = max_ngram_length + self.hash_algorithm = hash_algorithm def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]: """Classify text with confidence. @@ -42,7 +44,8 @@ class Model: tokens = gptc.tokenizer.hash( gptc.tokenizer.tokenize( text, min(max_ngram_length, self.max_ngram_length) - ) + ), + self.hash_algorithm, ) numbered_probs: Dict[int, float] = {} for word in tokens: @@ -83,6 +86,7 @@ class Model: { "names": self.names, "max_ngram_length": self.max_ngram_length, + "hash_algorithm": self.hash_algorithm, } ).encode("utf-8") + b"\n" @@ -117,6 +121,7 @@ def deserialize(encoded_model: BinaryIO) -> Model: try: names = config["names"] max_ngram_length = config["max_ngram_length"] + hash_algorithm = config["hash_algorithm"] except KeyError: raise InvalidModelError() @@ -141,4 +146,4 @@ def deserialize(encoded_model: BinaryIO) -> Model: for value in [code[x : x + 2] for x in range(6, len(code), 2)] ] - return Model(weights, names, max_ngram_length) + return Model(weights, names, max_ngram_length, hash_algorithm) diff --git a/gptc/tokenizer.py b/gptc/tokenizer.py index 33a2744..bd75685 100644 --- a/gptc/tokenizer.py +++ b/gptc/tokenizer.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import List, Union +from typing import List, Union, Callable import hashlib import emoji import unicodedata @@ -45,14 +45,33 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]: return ngrams -def hash_single(token: str) -> int: +def hash_single(token: str, hash_function: Callable) -> int: return int.from_bytes( - hashlib.sha256(token.encode("utf-8")).digest()[:6], "big" + hash_function(token.encode("utf-8")).digest()[:6], "big" ) -def hash(tokens: List[str]) -> List[int]: - return [hash_single(token) for token in tokens] +def hash(tokens: List[str], hash_algorithm: str) -> List[int]: + if hash_algorithm in { + "sha224", + "md5", + "sha512", + "sha3_256", + "blake2s", + "sha3_224", + "sha1", + "sha256", + "sha384", + "shake_256", + "blake2b", + "sha3_512", + "shake_128", + "sha3_384", + }: + hash_function = getattr(hashlib, hash_algorithm) + return [hash_single(token, hash_function) for token in tokens] + else: + raise ValueError("not a valid hash function: " + hash_algorithm) def normalize(text: str) -> str: