Add confidence to Model; deprecate Classifier

This commit is contained in:
Samuel Sloniker 2022-11-26 16:41:29 -08:00
parent b4766cb613
commit 448f200923
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
4 changed files with 73 additions and 61 deletions

View File

@ -43,14 +43,15 @@ example of the format. Any exceptions will be printed to stderr.
## Library ## Library
### `gptc.Classifier(model, max_ngram_length=1)` ### `Model.serialize()`
Create a `Classifier` object using the given compiled model (as a `gptc.Model` Returns a `bytes` representing the model.
object, not as a serialized byte string).
For information about `max_ngram_length`, see section "Ngrams." ### `gptc.deserialize(encoded_model)`
#### `Classifier.confidence(text)` Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
### `Model.confidence(text, max_ngram_length)`
Classify `text`. Returns a dict of the format `{category: probability, Classify `text`. Returns a dict of the format `{category: probability,
category:probability, ...}` category:probability, ...}`
@ -60,14 +61,7 @@ common words between the input and the training data (likely, for example, with
input in a different language from the training data), an empty dict will be input in a different language from the training data), an empty dict will be
returned. returned.
#### `Classifier.classify(text)` For information about `max_ngram_length`, see section "Ngrams."
Classify `text`. Returns the category into which the text is placed (as a
string), or `None` when it cannot classify the text.
#### `Classifier.model`
The classifier's model.
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)` ### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
@ -79,14 +73,6 @@ For information about `max_ngram_length`, see section "Ngrams."
Words or ngrams used less than `min_count` times throughout the input text are Words or ngrams used less than `min_count` times throughout the input text are
excluded from the model. excluded from the model.
### `gptc.Model.serialize()`
Returns a `bytes` representing the model.
### `gptc.deserialize(encoded_model)`
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
### `gptc.pack(directory, print_exceptions=False)` ### `gptc.pack(directory, print_exceptions=False)`
Pack the model in `directory` and return a tuple of the format: Pack the model in `directory` and return a tuple of the format:
@ -99,6 +85,13 @@ GPTC.
See `models/unpacked/` for an example of the format. See `models/unpacked/` for an example of the format.
### `gptc.Classifier(model, max_ngram_length=1)`
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
removed in 4.0.0. See [the README from
3.0.1](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.1/README.md) if you need
documentation.
## Ngrams ## Ngrams
GPTC optionally supports using ngrams to improve classification accuracy. They GPTC optionally supports using ngrams to improve classification accuracy. They

View File

@ -13,9 +13,7 @@ def main() -> None:
) )
subparsers = parser.add_subparsers(dest="subparser_name", required=True) subparsers = parser.add_subparsers(dest="subparser_name", required=True)
compile_parser = subparsers.add_parser( compile_parser = subparsers.add_parser("compile", help="compile a raw model")
"compile", help="compile a raw model"
)
compile_parser.add_argument("model", help="raw model to compile") compile_parser.add_argument("model", help="raw model to compile")
compile_parser.add_argument( compile_parser.add_argument(
"--max-ngram-length", "--max-ngram-length",
@ -55,9 +53,7 @@ def main() -> None:
action="store_true", action="store_true",
) )
pack_parser = subparsers.add_parser( pack_parser = subparsers.add_parser("pack", help="pack a model from a directory")
"pack", help="pack a model from a directory"
)
pack_parser.add_argument("model", help="directory containing model") pack_parser.add_argument("model", help="directory containing model")
args = parser.parse_args() args = parser.parse_args()
@ -67,25 +63,26 @@ def main() -> None:
model = json.load(f) model = json.load(f)
sys.stdout.buffer.write( sys.stdout.buffer.write(
gptc.compile( gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
model, args.max_ngram_length, args.min_count
).serialize()
) )
elif args.subparser_name == "classify": elif args.subparser_name == "classify":
with open(args.model, "rb") as f: with open(args.model, "rb") as f:
model = gptc.deserialize(f.read()) model = gptc.deserialize(f.read())
classifier = gptc.Classifier(model, args.max_ngram_length)
if sys.stdin.isatty(): if sys.stdin.isatty():
text = input("Text to analyse: ") text = input("Text to analyse: ")
else: else:
text = sys.stdin.read() text = sys.stdin.read()
probabilities = model.confidence(text, args.max_ngram_length)
if args.category: if args.category:
print(classifier.classify(text)) try:
print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0])
except IndexError:
print(None)
else: else:
print(json.dumps(classifier.confidence(text))) print(json.dumps(probabilities))
else: else:
print(json.dumps(gptc.pack(args.model, True)[0])) print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@ -1,8 +1,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting import gptc.model
import warnings from typing import Dict, Union
from typing import Dict, Union, cast, List
class Classifier: class Classifier:
@ -45,29 +44,7 @@ class Classifier:
matching any categories in the model were found matching any categories in the model were found
""" """
return self.model.confidence(text, self.max_ngram_length)
model = self.model.weights
tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length)
numbered_probs: Dict[int, float] = {}
for word in tokens:
try:
weighted_numbers = gptc.weighting.weight(
[i / 65535 for i in cast(List[float], model[word])]
)
for category, value in enumerate(weighted_numbers):
try:
numbered_probs[category] += value
except KeyError:
numbered_probs[category] = value
except KeyError:
pass
total = sum(numbered_probs.values())
probs: Dict[str, float] = {
self.model.names[category]: value / total
for category, value in numbered_probs.items()
}
return probs
def classify(self, text: str) -> Union[str, None]: def classify(self, text: str) -> Union[str, None]:
"""Classify text. """Classify text.

View File

@ -2,7 +2,8 @@
import gptc.tokenizer import gptc.tokenizer
from gptc.exceptions import InvalidModelError from gptc.exceptions import InvalidModelError
from typing import Iterable, Mapping, List, Dict, Union import gptc.weighting
from typing import Iterable, Mapping, List, Dict, Union, cast
import json import json
@ -17,6 +18,50 @@ class Model:
self.names = names self.names = names
self.max_ngram_length = max_ngram_length self.max_ngram_length = max_ngram_length
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
    """Classify text with confidence.

    Parameters
    ----------
    text : str
        The text to classify

    max_ngram_length : int
        The maximum ngram length to use in classifying

    Returns
    -------
    dict
        {category:probability, category:probability...} or {} if no words
        matching any categories in the model were found

    """
    # Never exceed the ngram length the model was compiled with.
    effective_length = min(max_ngram_length, self.max_ngram_length)
    tokens = gptc.tokenizer.tokenize(text, effective_length)

    # Accumulate weighted scores per category index.
    raw_scores: Dict[int, float] = {}
    for token in tokens:
        stored = self.weights.get(token)
        if stored is None:
            # Token absent from the model; contributes nothing.
            continue
        adjusted = gptc.weighting.weight(
            [value / 65535 for value in cast(List[float], stored)]
        )
        for index, score in enumerate(adjusted):
            raw_scores[index] = raw_scores.get(index, 0.0) + score

    # Normalize to probabilities keyed by category name; empty dict if
    # nothing matched (the comprehension never divides in that case).
    total = sum(raw_scores.values())
    return {
        self.names[index]: score / total
        for index, score in raw_scores.items()
    }
def serialize(self) -> bytes: def serialize(self) -> bytes:
out = b"GPTC model v4\n" out = b"GPTC model v4\n"
out += ( out += (