Add confidence to Model; deprecate Classifier

This commit is contained in:
Samuel Sloniker 2022-11-26 16:41:29 -08:00
parent b4766cb613
commit 448f200923
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
4 changed files with 73 additions and 61 deletions

View File

@ -43,14 +43,15 @@ example of the format. Any exceptions will be printed to stderr.
## Library
### `gptc.Classifier(model, max_ngram_length=1)`
### `Model.serialize()`
Create a `Classifier` object using the given compiled model (as a `gptc.Model`
object, not as a serialized byte string).
Returns a `bytes` representing the model.
For information about `max_ngram_length`, see section "Ngrams."
### `gptc.deserialize(encoded_model)`
#### `Classifier.confidence(text)`
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
### `Model.confidence(text, max_ngram_length)`
Classify `text`. Returns a dict of the format `{category: probability,
category: probability, ...}`
@ -60,14 +61,7 @@ common words between the input and the training data (likely, for example, with
input in a different language from the training data), an empty dict will be
returned.
#### `Classifier.classify(text)`
Classify `text`. Returns the category into which the text is placed (as a
string), or `None` when it cannot classify the text.
#### `Classifier.model`
The classifier's model.
For information about `max_ngram_length`, see section "Ngrams."
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
@ -79,14 +73,6 @@ For information about `max_ngram_length`, see section "Ngrams."
Words or ngrams used less than `min_count` times throughout the input text are
excluded from the model.
### `gptc.Model.serialize()`
Returns a `bytes` representing the model.
### `gptc.deserialize(encoded_model)`
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
### `gptc.pack(directory, print_exceptions=False)`
Pack the model in `directory` and return a tuple of the format:
@ -99,6 +85,13 @@ GPTC.
See `models/unpacked/` for an example of the format.
### `gptc.Classifier(model, max_ngram_length=1)`
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
removed in 4.0.0. See [the README from
3.0.2](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.2/README.md) if you need
documentation.
## Ngrams
GPTC optionally supports using ngrams to improve classification accuracy. They

View File

@ -13,9 +13,7 @@ def main() -> None:
)
subparsers = parser.add_subparsers(dest="subparser_name", required=True)
compile_parser = subparsers.add_parser(
"compile", help="compile a raw model"
)
compile_parser = subparsers.add_parser("compile", help="compile a raw model")
compile_parser.add_argument("model", help="raw model to compile")
compile_parser.add_argument(
"--max-ngram-length",
@ -55,9 +53,7 @@ def main() -> None:
action="store_true",
)
pack_parser = subparsers.add_parser(
"pack", help="pack a model from a directory"
)
pack_parser = subparsers.add_parser("pack", help="pack a model from a directory")
pack_parser.add_argument("model", help="directory containing model")
args = parser.parse_args()
@ -67,25 +63,26 @@ def main() -> None:
model = json.load(f)
sys.stdout.buffer.write(
gptc.compile(
model, args.max_ngram_length, args.min_count
).serialize()
gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
)
elif args.subparser_name == "classify":
with open(args.model, "rb") as f:
model = gptc.deserialize(f.read())
classifier = gptc.Classifier(model, args.max_ngram_length)
if sys.stdin.isatty():
text = input("Text to analyse: ")
else:
text = sys.stdin.read()
probabilities = model.confidence(text, args.max_ngram_length)
if args.category:
print(classifier.classify(text))
try:
print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0])
except IndexError:
print(None)
else:
print(json.dumps(classifier.confidence(text)))
print(json.dumps(probabilities))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@ -1,8 +1,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
import warnings
from typing import Dict, Union, cast, List
import gptc.model
from typing import Dict, Union
class Classifier:
@ -45,29 +44,7 @@ class Classifier:
matching any categories in the model were found
"""
model = self.model.weights
tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length)
numbered_probs: Dict[int, float] = {}
for word in tokens:
try:
weighted_numbers = gptc.weighting.weight(
[i / 65535 for i in cast(List[float], model[word])]
)
for category, value in enumerate(weighted_numbers):
try:
numbered_probs[category] += value
except KeyError:
numbered_probs[category] = value
except KeyError:
pass
total = sum(numbered_probs.values())
probs: Dict[str, float] = {
self.model.names[category]: value / total
for category, value in numbered_probs.items()
}
return probs
return self.model.confidence(text, self.max_ngram_length)
def classify(self, text: str) -> Union[str, None]:
"""Classify text.

View File

@ -2,7 +2,8 @@
import gptc.tokenizer
from gptc.exceptions import InvalidModelError
from typing import Iterable, Mapping, List, Dict, Union
import gptc.weighting
from typing import Iterable, Mapping, List, Dict, Union, cast
import json
@ -17,6 +18,50 @@ class Model:
self.names = names
self.max_ngram_length = max_ngram_length
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
    """Classify text with confidence.

    Parameters
    ----------
    text : str
        The text to classify
    max_ngram_length : int
        The maximum ngram length to use in classifying

    Returns
    -------
    dict
        {category:probability, category:probability...} or {} if no words
        matching any categories in the model were found

    """
    weights = self.weights
    # Never exceed the ngram length the model was compiled with.
    effective_length = min(max_ngram_length, self.max_ngram_length)

    # Accumulate per-category scores, keyed by category index.
    raw_scores: Dict[int, float] = {}
    for token in gptc.tokenizer.tokenize(text, effective_length):
        stored = weights.get(token)
        if stored is None:
            # Token absent from the model; contributes nothing.
            continue
        # Stored weights are 16-bit integers; scale back to [0, 1].
        scaled = [value / 65535 for value in cast(List[float], stored)]
        for index, share in enumerate(gptc.weighting.weight(scaled)):
            raw_scores[index] = raw_scores.get(index, 0.0) + share

    # Normalize so the probabilities sum to 1; empty input yields {}.
    total = sum(raw_scores.values())
    return {
        self.names[index]: score / total
        for index, score in raw_scores.items()
    }
def serialize(self) -> bytes:
out = b"GPTC model v4\n"
out += (