parent
6f21e0d4e9
commit
f8dbc78b82
22
README.md
22
README.md
|
@ -69,7 +69,7 @@ Return a confidence dict for the given token or ngram. This function is very
|
||||||
similar to `Model.confidence()`, except it treats the input as a single token
|
similar to `Model.confidence()`, except it treats the input as a single token
|
||||||
or ngram.
|
or ngram.
|
||||||
|
|
||||||
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
|
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")`
|
||||||
|
|
||||||
Compile a raw model (as a list, not JSON) and return the compiled model (as a
|
Compile a raw model (as a list, not JSON) and return the compiled model (as a
|
||||||
`gptc.Model` object).
|
`gptc.Model` object).
|
||||||
|
@ -79,6 +79,26 @@ For information about `max_ngram_length`, see section "Ngrams."
|
||||||
Words or ngrams used fewer than `min_count` times throughout the input text are
|
Words or ngrams used fewer than `min_count` times throughout the input text are
|
||||||
excluded from the model.
|
excluded from the model.
|
||||||
|
|
||||||
|
Leave the hash algorithm at its default unless the application needs a specific
|
||||||
|
one; note that the default may change with a minor version update. The algorithm is
|
||||||
|
stored in the model, so changing the algorithm does not affect compatibility.
|
||||||
|
The following algorithms are supported:
|
||||||
|
|
||||||
|
* `md5`
|
||||||
|
* `sha1`
|
||||||
|
* `sha224`
|
||||||
|
* `sha256`
|
||||||
|
* `sha384`
|
||||||
|
* `sha512`
|
||||||
|
* `sha3_224`
|
||||||
|
* `sha3_384`
|
||||||
|
* `sha3_256`
|
||||||
|
* `sha3_512`
|
||||||
|
* `shake_128`
|
||||||
|
* `shake_256`
|
||||||
|
* `blake2b`
|
||||||
|
* `blake2s`
|
||||||
|
|
||||||
### `gptc.pack(directory, print_exceptions=False)`
|
### `gptc.pack(directory, print_exceptions=False)`
|
||||||
|
|
||||||
Pack the model in `directory` and return a tuple of the format:
|
Pack the model in `directory` and return a tuple of the format:
|
||||||
|
|
|
@ -9,6 +9,7 @@ def compile(
|
||||||
raw_model: Iterable[Mapping[str, str]],
|
raw_model: Iterable[Mapping[str, str]],
|
||||||
max_ngram_length: int = 1,
|
max_ngram_length: int = 1,
|
||||||
min_count: int = 1,
|
min_count: int = 1,
|
||||||
|
hash_algorithm: str = "sha256",
|
||||||
) -> gptc.model.Model:
|
) -> gptc.model.Model:
|
||||||
"""Compile a raw model.
|
"""Compile a raw model.
|
||||||
|
|
||||||
|
@ -33,7 +34,8 @@ def compile(
|
||||||
|
|
||||||
for portion in raw_model:
|
for portion in raw_model:
|
||||||
text = gptc.tokenizer.hash(
|
text = gptc.tokenizer.hash(
|
||||||
gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
|
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
||||||
|
hash_algorithm,
|
||||||
)
|
)
|
||||||
category = portion["category"]
|
category = portion["category"]
|
||||||
|
|
||||||
|
@ -68,4 +70,4 @@ def compile(
|
||||||
)
|
)
|
||||||
model[word] = new_weights
|
model[word] = new_weights
|
||||||
|
|
||||||
return gptc.model.Model(model, names, max_ngram_length)
|
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
|
||||||
|
|
|
@ -13,10 +13,12 @@ class Model:
|
||||||
weights: Dict[int, List[int]],
|
weights: Dict[int, List[int]],
|
||||||
names: List[str],
|
names: List[str],
|
||||||
max_ngram_length: int,
|
max_ngram_length: int,
|
||||||
|
hash_algorithm: str,
|
||||||
):
|
):
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.names = names
|
self.names = names
|
||||||
self.max_ngram_length = max_ngram_length
|
self.max_ngram_length = max_ngram_length
|
||||||
|
self.hash_algorithm = hash_algorithm
|
||||||
|
|
||||||
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
|
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
|
||||||
"""Classify text with confidence.
|
"""Classify text with confidence.
|
||||||
|
@ -42,7 +44,8 @@ class Model:
|
||||||
tokens = gptc.tokenizer.hash(
|
tokens = gptc.tokenizer.hash(
|
||||||
gptc.tokenizer.tokenize(
|
gptc.tokenizer.tokenize(
|
||||||
text, min(max_ngram_length, self.max_ngram_length)
|
text, min(max_ngram_length, self.max_ngram_length)
|
||||||
)
|
),
|
||||||
|
self.hash_algorithm,
|
||||||
)
|
)
|
||||||
numbered_probs: Dict[int, float] = {}
|
numbered_probs: Dict[int, float] = {}
|
||||||
for word in tokens:
|
for word in tokens:
|
||||||
|
@ -83,6 +86,7 @@ class Model:
|
||||||
{
|
{
|
||||||
"names": self.names,
|
"names": self.names,
|
||||||
"max_ngram_length": self.max_ngram_length,
|
"max_ngram_length": self.max_ngram_length,
|
||||||
|
"hash_algorithm": self.hash_algorithm,
|
||||||
}
|
}
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
+ b"\n"
|
+ b"\n"
|
||||||
|
@ -117,6 +121,7 @@ def deserialize(encoded_model: BinaryIO) -> Model:
|
||||||
try:
|
try:
|
||||||
names = config["names"]
|
names = config["names"]
|
||||||
max_ngram_length = config["max_ngram_length"]
|
max_ngram_length = config["max_ngram_length"]
|
||||||
|
hash_algorithm = config["hash_algorithm"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
|
@ -141,4 +146,4 @@ def deserialize(encoded_model: BinaryIO) -> Model:
|
||||||
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
||||||
]
|
]
|
||||||
|
|
||||||
return Model(weights, names, max_ngram_length)
|
return Model(weights, names, max_ngram_length, hash_algorithm)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Union, Callable
|
||||||
import hashlib
|
import hashlib
|
||||||
import emoji
|
import emoji
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
@ -45,14 +45,33 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
||||||
return ngrams
|
return ngrams
|
||||||
|
|
||||||
|
|
||||||
def hash_single(token: str) -> int:
|
def hash_single(token: str, hash_function: Callable) -> int:
|
||||||
return int.from_bytes(
|
return int.from_bytes(
|
||||||
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
|
hash_function(token.encode("utf-8")).digest()[:6], "big"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def hash(tokens: List[str]) -> List[int]:
|
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
|
||||||
return [hash_single(token) for token in tokens]
|
if hash_algorithm in {
|
||||||
|
"sha224",
|
||||||
|
"md5",
|
||||||
|
"sha512",
|
||||||
|
"sha3_256",
|
||||||
|
"blake2s",
|
||||||
|
"sha3_224",
|
||||||
|
"sha1",
|
||||||
|
"sha256",
|
||||||
|
"sha384",
|
||||||
|
"shake_256",
|
||||||
|
"blake2b",
|
||||||
|
"sha3_512",
|
||||||
|
"shake_128",
|
||||||
|
"sha3_384",
|
||||||
|
}:
|
||||||
|
hash_function = getattr(hashlib, hash_algorithm)
|
||||||
|
return [hash_single(token, hash_function) for token in tokens]
|
||||||
|
else:
|
||||||
|
raise ValueError("not a valid hash function: " + hash_algorithm)
|
||||||
|
|
||||||
|
|
||||||
def normalize(text: str) -> str:
|
def normalize(text: str) -> str:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user