New model format
Use Model objects and binary serialization format
This commit is contained in:
parent
f4ae5f851d
commit
a10569b5ab
79
README.md
79
README.md
|
@ -6,9 +6,7 @@ GPTC provides both a CLI tool and a Python library.
|
|||
|
||||
## Installation
|
||||
|
||||
pip install gptc[emoji] # handles emojis! (see section "Emoji")
|
||||
# Or, if you don't need emoji support,
|
||||
pip install gptc # no dependencies!
|
||||
pip install gptc
|
||||
|
||||
## CLI Tool
|
||||
|
||||
|
@ -31,7 +29,7 @@ stdout (or "None" if it cannot determine anything).
|
|||
|
||||
gptc compile [-n <max_ngram_length>] [-c <min_count>] <raw model file>
|
||||
|
||||
This will print the compiled model in JSON to stdout.
|
||||
This will print the compiled model encoded in binary format to stdout.
|
||||
|
||||
If `-c` is specified, words and ngrams used less than `min_count` times will be
|
||||
excluded from the compiled model.
|
||||
|
@ -47,8 +45,8 @@ example of the format. Any exceptions will be printed to stderr.
|
|||
|
||||
### `gptc.Classifier(model, max_ngram_length=1)`
|
||||
|
||||
Create a `Classifier` object using the given *compiled* model (as a dict, not
|
||||
JSON).
|
||||
Create a `Classifier` object using the given compiled model (as a `gptc.Model`
|
||||
object, not as a serialized byte string).
|
||||
|
||||
For information about `max_ngram_length`, see section "Ngrams."
|
||||
|
||||
|
@ -57,6 +55,11 @@ For information about `max_ngram_length`, see section "Ngrams."
|
|||
Classify `text`. Returns a dict of the format `{category: probability,
|
||||
category:probability, ...}`
|
||||
|
||||
Note that this may not include values for all categories. If there are no
|
||||
common words between the input and the training data (likely, for example, with
|
||||
input in a different language from the training data), an empty dict will be
|
||||
returned.
|
||||
|
||||
#### `Classifier.classify(text)`
|
||||
|
||||
Classify `text`. Returns the category into which the text is placed (as a
|
||||
|
@ -66,21 +69,24 @@ string), or `None` when it cannot classify the text.
|
|||
|
||||
The classifier's model.
|
||||
|
||||
#### `Classifier.has_emoji`
|
||||
|
||||
Check whether emojis are supported by the `Classifier`. (See section "Emoji.")
|
||||
Equivalent to `gptc.has_emoji and gptc.model_has_emoji(model)`.
|
||||
|
||||
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
|
||||
|
||||
Compile a raw model (as a list, not JSON) and return the compiled model (as a
|
||||
dict).
|
||||
`gptc.Model` object).
|
||||
|
||||
For information about `max_ngram_length`, see section "Ngrams."
|
||||
|
||||
Words or ngrams used less than `min_count` times throughout the input text are
|
||||
excluded from the model.
|
||||
|
||||
### `gptc.Model.serialize()`
|
||||
|
||||
Returns a `bytes` representing the model.
|
||||
|
||||
### `gptc.deserialize(encoded_model)`
|
||||
|
||||
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
|
||||
|
||||
### `gptc.pack(directory, print_exceptions=False)
|
||||
|
||||
Pack the model in `directory` and return a tuple of the format:
|
||||
|
@ -93,50 +99,26 @@ GPTC.
|
|||
|
||||
See `models/unpacked/` for an example of the format.
|
||||
|
||||
### `gptc.has_emoji`
|
||||
|
||||
`True` if the `emoji` package is installed (see section "Emoji"), `False`
|
||||
otherwise.
|
||||
|
||||
### `gptc.model_has_emoji(compiled_model)`
|
||||
|
||||
Returns `True` if `compiled_model` was compiled with emoji support, `False`
|
||||
otherwise.
|
||||
|
||||
## Ngrams
|
||||
|
||||
GPTC optionally supports using ngrams to improve classification accuracy. They
|
||||
are disabled by default (maximum length set to 1) for performance and
|
||||
compatibility reasons. Enabling them significantly increases the time required
|
||||
both for compilation and classification. The effect seems more significant for
|
||||
compilation than for classification. Compiled models are also much larger when
|
||||
ngrams are enabled. Larger maximum ngram lengths will result in slower
|
||||
performance and larger files. It is a good idea to experiment with different
|
||||
values and use the highest one at which GPTC is fast enough and models are
|
||||
small enough for your needs.
|
||||
are disabled by default (maximum length set to 1) for performance reasons.
|
||||
Enabling them significantly increases the time required both for compilation
|
||||
and classification. The effect seems more significant for compilation than for
|
||||
classification. Compiled models are also much larger when ngrams are enabled.
|
||||
Larger maximum ngram lengths will result in slower performance and larger
|
||||
files. It is a good idea to experiment with different values and use the
|
||||
highest one at which GPTC is fast enough and models are small enough for your
|
||||
needs.
|
||||
|
||||
Once a model is compiled at a certain maximum ngram length, it cannot be used
|
||||
for classification with a higher value. If you instantiate a `Classifier` with
|
||||
a model compiled with a lower `max_ngram_length`, the value will be silently
|
||||
reduced to the one used when compiling the model.
|
||||
|
||||
Models compiled with older versions of GPTC which did not support ngrams are
|
||||
handled the same way as models compiled with `max_ngram_length=1`.
|
||||
|
||||
## Emoji
|
||||
|
||||
If the [`emoji`](https://pypi.org/project/emoji/) package is installed, GPTC
|
||||
will automatically handle emojis the same way as words. If it is not installed,
|
||||
GPTC will still work but will ignore emojis.
|
||||
|
||||
`emoji` must be installed on both the system used to compile the model and the
|
||||
system used to classify text. Emojis are ignored if it is missing on either
|
||||
system.
|
||||
|
||||
## Model format
|
||||
|
||||
This section explains the raw model format, which is how you should create and
|
||||
edit models.
|
||||
This section explains the raw model format, which is how models are created and edited.
|
||||
|
||||
Raw models are formatted as a list of dicts. See below for the format:
|
||||
|
||||
|
@ -147,11 +129,14 @@ Raw models are formatted as a list of dicts. See below for the format:
|
|||
}
|
||||
]
|
||||
|
||||
GPTC handles models as Python `list`s of `dict`s of `str`s (for raw models) or
|
||||
`dict`s of `str`s and `float`s (for compiled models), and they can be stored
|
||||
GPTC handles raw models as `list`s of `dict`s of `str`s (`List[Dict[str, str]]`), and they can be stored
|
||||
in any way these Python objects can be. However, it is recommended to store
|
||||
them in JSON format for compatibility with the command-line tool.
|
||||
|
||||
## Emoji
|
||||
|
||||
GPTC treats individual emoji as words.
|
||||
|
||||
## Example model
|
||||
|
||||
An example model, which is designed to distinguish between texts written by
|
||||
|
|
|
@ -6,9 +6,9 @@ from gptc.compiler import compile as compile
|
|||
from gptc.classifier import Classifier as Classifier
|
||||
from gptc.pack import pack as pack
|
||||
from gptc.tokenizer import has_emoji as has_emoji
|
||||
from gptc.model_info import model_has_emoji as model_has_emoji
|
||||
from gptc.model import Model as Model, deserialize as deserialize
|
||||
from gptc.exceptions import (
|
||||
GPTCError as GPTCError,
|
||||
ModelError as ModelError,
|
||||
UnsupportedModelError as UnsupportedModelError,
|
||||
InvalidModelError as InvalidModelError,
|
||||
)
|
||||
|
|
|
@ -66,7 +66,11 @@ def main() -> None:
|
|||
with open(args.model, "r") as f:
|
||||
model = json.load(f)
|
||||
|
||||
print(json.dumps(gptc.compile(model, args.max_ngram_length, args.min_count)))
|
||||
sys.stdout.buffer.write(
|
||||
gptc.compile(
|
||||
model, args.max_ngram_length, args.min_count
|
||||
).serialize()
|
||||
)
|
||||
elif args.subparser_name == "classify":
|
||||
with open(args.model, "r") as f:
|
||||
model = json.load(f)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting, gptc.model_info
|
||||
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
|
||||
import warnings
|
||||
from typing import Dict, Union, cast, List
|
||||
|
||||
|
@ -25,17 +25,11 @@ class Classifier:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
|
||||
if model.get("__version__", 0) != 3:
|
||||
raise gptc.exceptions.UnsupportedModelError(
|
||||
f"unsupported model version"
|
||||
)
|
||||
def __init__(self, model: gptc.model.Model, max_ngram_length: int = 1):
|
||||
self.model = model
|
||||
model_ngrams = cast(int, model.get("__ngrams__", 1))
|
||||
model_ngrams = model.max_ngram_length
|
||||
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
||||
self.has_emoji = (
|
||||
gptc.tokenizer.has_emoji and gptc.model_info.model_has_emoji(model)
|
||||
)
|
||||
self.has_emoji = gptc.tokenizer.has_emoji and model.has_emoji
|
||||
|
||||
def confidence(self, text: str) -> Dict[str, float]:
|
||||
"""Classify text with confidence.
|
||||
|
@ -53,7 +47,7 @@ class Classifier:
|
|||
|
||||
"""
|
||||
|
||||
model = self.model
|
||||
model = self.model.weights
|
||||
|
||||
tokens = gptc.tokenizer.tokenize(
|
||||
text, self.max_ngram_length, self.has_emoji
|
||||
|
@ -73,7 +67,7 @@ class Classifier:
|
|||
pass
|
||||
total = sum(numbered_probs.values())
|
||||
probs: Dict[str, float] = {
|
||||
cast(List[str], model["__names__"])[category]: value / total
|
||||
self.model.names[category]: value / total
|
||||
for category, value in numbered_probs.items()
|
||||
}
|
||||
return probs
|
||||
|
|
|
@ -1,18 +1,15 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import gptc.tokenizer
|
||||
import gptc.model
|
||||
from typing import Iterable, Mapping, List, Dict, Union
|
||||
|
||||
WEIGHTS_T = List[int]
|
||||
CONFIG_T = Union[List[str], int, str]
|
||||
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
|
||||
|
||||
|
||||
def compile(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int = 1,
|
||||
min_count: int = 1,
|
||||
) -> MODEL:
|
||||
) -> gptc.model.Model:
|
||||
"""Compile a raw model.
|
||||
|
||||
Parameters
|
||||
|
@ -30,7 +27,7 @@ def compile(
|
|||
|
||||
"""
|
||||
|
||||
categories: Dict[str, List[str]] = {}
|
||||
categories: Dict[str, List[int]] = {}
|
||||
|
||||
for portion in raw_model:
|
||||
text = gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
|
||||
|
@ -40,7 +37,7 @@ def compile(
|
|||
except KeyError:
|
||||
categories[category] = text
|
||||
|
||||
word_counts: Dict[str, Dict[str, float]] = {}
|
||||
word_counts: Dict[int, Dict[str, int]] = {}
|
||||
|
||||
names = []
|
||||
|
||||
|
@ -66,7 +63,7 @@ def compile(
|
|||
if sum(counts.values()) >= min_count
|
||||
}
|
||||
|
||||
word_weights: Dict[str, Dict[str, float]] = {}
|
||||
word_weights: Dict[int, Dict[str, float]] = {}
|
||||
for word, values in word_counts.items():
|
||||
for category, value in values.items():
|
||||
try:
|
||||
|
@ -76,7 +73,7 @@ def compile(
|
|||
category: value / len(categories[category])
|
||||
}
|
||||
|
||||
model: MODEL = {}
|
||||
model: Dict[int, List[int]] = {}
|
||||
for word, weights in word_weights.items():
|
||||
total = sum(weights.values())
|
||||
new_weights: List[int] = []
|
||||
|
@ -86,9 +83,4 @@ def compile(
|
|||
)
|
||||
model[word] = new_weights
|
||||
|
||||
model["__names__"] = names
|
||||
model["__ngrams__"] = max_ngram_length
|
||||
model["__version__"] = 3
|
||||
model["__emoji__"] = int(gptc.tokenizer.has_emoji)
|
||||
|
||||
return model
|
||||
return gptc.model.Model(model, names, max_ngram_length)
|
||||
|
|
|
@ -9,5 +9,5 @@ class ModelError(GPTCError):
|
|||
pass
|
||||
|
||||
|
||||
class UnsupportedModelError(ModelError):
|
||||
class InvalidModelError(ModelError):
|
||||
pass
|
||||
|
|
89
gptc/model.py
Normal file
89
gptc/model.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import gptc.tokenizer
|
||||
from gptc.exceptions import InvalidModelError
|
||||
from typing import Iterable, Mapping, List, Dict, Union
|
||||
import json
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(
|
||||
self,
|
||||
weights: Dict[int, List[int]],
|
||||
names: List[str],
|
||||
max_ngram_length: int,
|
||||
has_emoji: Union[None, bool] = None,
|
||||
):
|
||||
self.weights = weights
|
||||
self.names = names
|
||||
self.max_ngram_length = max_ngram_length
|
||||
self.has_emoji = (
|
||||
gptc.tokenizer.has_emoji if has_emoji is None else has_emoji
|
||||
)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
out = b"GPTC model v4\n"
|
||||
out += (
|
||||
json.dumps(
|
||||
{
|
||||
"names": self.names,
|
||||
"max_ngram_length": self.max_ngram_length,
|
||||
"has_emoji": self.has_emoji,
|
||||
}
|
||||
).encode("utf-8")
|
||||
+ b"\n"
|
||||
)
|
||||
for word, weights in self.weights.items():
|
||||
out += word.to_bytes(6, "big") + b"".join(
|
||||
[weight.to_bytes(2, "big") for weight in weights]
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def deserialize(encoded_model: bytes) -> Model:
|
||||
try:
|
||||
prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2)
|
||||
except ValueError:
|
||||
raise InvalidModelError()
|
||||
|
||||
if prefix != b"GPTC model v4":
|
||||
raise InvalidModelError()
|
||||
|
||||
try:
|
||||
config = json.loads(config_json.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
raise InvalidModelError()
|
||||
|
||||
try:
|
||||
names = config["names"]
|
||||
max_ngram_length = config["max_ngram_length"]
|
||||
has_emoji = config["has_emoji"]
|
||||
except KeyError:
|
||||
raise InvalidModelError()
|
||||
|
||||
if not (
|
||||
isinstance(names, list)
|
||||
and isinstance(max_ngram_length, int)
|
||||
and isinstance(has_emoji, bool)
|
||||
) or not all([isinstance(name, str) for name in names]):
|
||||
raise InvalidModelError()
|
||||
|
||||
weight_code_length = 6 + 2 * len(names)
|
||||
|
||||
if len(encoded_weights) % weight_code_length != 0:
|
||||
raise InvalidModelError()
|
||||
|
||||
weight_codes = [
|
||||
encoded_weights[x : x + weight_code_length]
|
||||
for x in range(0, len(encoded_weights), weight_code_length)
|
||||
]
|
||||
|
||||
weights = {
|
||||
int.from_bytes(code[:6], "big"): [
|
||||
int.from_bytes(value, "big")
|
||||
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
||||
]
|
||||
for code in weight_codes
|
||||
}
|
||||
|
||||
return Model(weights, names, max_ngram_length, has_emoji)
|
|
@ -1,8 +0,0 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import gptc.compiler
|
||||
from typing import Dict, Union, cast, List
|
||||
|
||||
|
||||
def model_has_emoji(model: gptc.compiler.MODEL) -> bool:
|
||||
return cast(int, model.get("__emoji__", 0)) == 1
|
|
@ -14,7 +14,7 @@ except ImportError:
|
|||
|
||||
def tokenize(
|
||||
text: str, max_ngram_length: int = 1, use_emoji: bool = True
|
||||
) -> List[str]:
|
||||
) -> List[int]:
|
||||
"""Convert a string to a list of lemmas."""
|
||||
converted_text: Union[str, List[str]] = text.lower()
|
||||
|
||||
|
@ -51,8 +51,8 @@ def tokenize(
|
|||
ngrams.append(" ".join(tokens[index : index + ngram_length]))
|
||||
|
||||
return [
|
||||
base64.b64encode(
|
||||
hashlib.sha256(token.encode("utf-8")).digest()[:6]
|
||||
).decode("ascii")
|
||||
int.from_bytes(
|
||||
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
|
||||
)
|
||||
for token in ngrams
|
||||
]
|
||||
|
|
BIN
models/compiled.gptc
Normal file
BIN
models/compiled.gptc
Normal file
Binary file not shown.
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user