Add confidence
to Model; deprecate Classifier
This commit is contained in:
parent
b4766cb613
commit
448f200923
35
README.md
35
README.md
|
@ -43,14 +43,15 @@ example of the format. Any exceptions will be printed to stderr.
|
|||
|
||||
## Library
|
||||
|
||||
### `gptc.Classifier(model, max_ngram_length=1)`
|
||||
### `Model.serialize()`
|
||||
|
||||
Create a `Classifier` object using the given compiled model (as a `gptc.Model`
|
||||
object, not as a serialized byte string).
|
||||
Returns a `bytes` representing the model.
|
||||
|
||||
For information about `max_ngram_length`, see section "Ngrams."
|
||||
### `gptc.deserialize(encoded_model)`
|
||||
|
||||
#### `Classifier.confidence(text)`
|
||||
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
|
||||
|
||||
### `Model.confidence(text, max_ngram_length)`
|
||||
|
||||
Classify `text`. Returns a dict of the format `{category: probability,
|
||||
category:probability, ...}`
|
||||
|
@ -60,14 +61,7 @@ common words between the input and the training data (likely, for example, with
|
|||
input in a different language from the training data), an empty dict will be
|
||||
returned.
|
||||
|
||||
#### `Classifier.classify(text)`
|
||||
|
||||
Classify `text`. Returns the category into which the text is placed (as a
|
||||
string), or `None` when it cannot classify the text.
|
||||
|
||||
#### `Classifier.model`
|
||||
|
||||
The classifier's model.
|
||||
For information about `max_ngram_length`, see section "Ngrams."
|
||||
|
||||
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
|
||||
|
||||
|
@ -79,14 +73,6 @@ For information about `max_ngram_length`, see section "Ngrams."
|
|||
Words or ngrams used fewer than `min_count` times throughout the input text are
|
||||
excluded from the model.
|
||||
|
||||
### `gptc.Model.serialize()`
|
||||
|
||||
Returns a `bytes` representing the model.
|
||||
|
||||
### `gptc.deserialize(encoded_model)`
|
||||
|
||||
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
|
||||
|
||||
### `gptc.pack(directory, print_exceptions=False)`
|
||||
|
||||
Pack the model in `directory` and return a tuple of the format:
|
||||
|
@ -99,6 +85,13 @@ GPTC.
|
|||
|
||||
See `models/unpacked/` for an example of the format.
|
||||
|
||||
### `gptc.Classifier(model, max_ngram_length=1)`
|
||||
|
||||
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
|
||||
removed in 4.0.0. See [the README from
|
||||
3.0.1](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.1/README.md) if you need
|
||||
documentation.
|
||||
|
||||
## Ngrams
|
||||
|
||||
GPTC optionally supports using ngrams to improve classification accuracy. They
|
||||
|
|
|
@ -13,9 +13,7 @@ def main() -> None:
|
|||
)
|
||||
subparsers = parser.add_subparsers(dest="subparser_name", required=True)
|
||||
|
||||
compile_parser = subparsers.add_parser(
|
||||
"compile", help="compile a raw model"
|
||||
)
|
||||
compile_parser = subparsers.add_parser("compile", help="compile a raw model")
|
||||
compile_parser.add_argument("model", help="raw model to compile")
|
||||
compile_parser.add_argument(
|
||||
"--max-ngram-length",
|
||||
|
@ -55,9 +53,7 @@ def main() -> None:
|
|||
action="store_true",
|
||||
)
|
||||
|
||||
pack_parser = subparsers.add_parser(
|
||||
"pack", help="pack a model from a directory"
|
||||
)
|
||||
pack_parser = subparsers.add_parser("pack", help="pack a model from a directory")
|
||||
pack_parser.add_argument("model", help="directory containing model")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -67,25 +63,26 @@ def main() -> None:
|
|||
model = json.load(f)
|
||||
|
||||
sys.stdout.buffer.write(
|
||||
gptc.compile(
|
||||
model, args.max_ngram_length, args.min_count
|
||||
).serialize()
|
||||
gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
|
||||
)
|
||||
elif args.subparser_name == "classify":
|
||||
with open(args.model, "rb") as f:
|
||||
model = gptc.deserialize(f.read())
|
||||
|
||||
classifier = gptc.Classifier(model, args.max_ngram_length)
|
||||
|
||||
if sys.stdin.isatty():
|
||||
text = input("Text to analyse: ")
|
||||
else:
|
||||
text = sys.stdin.read()
|
||||
|
||||
probabilities = model.confidence(text, args.max_ngram_length)
|
||||
|
||||
if args.category:
|
||||
print(classifier.classify(text))
|
||||
try:
|
||||
print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0])
|
||||
except IndexError:
|
||||
print(None)
|
||||
else:
|
||||
print(json.dumps(classifier.confidence(text)))
|
||||
print(json.dumps(probabilities))
|
||||
else:
|
||||
print(json.dumps(gptc.pack(args.model, True)[0]))
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
|
||||
import warnings
|
||||
from typing import Dict, Union, cast, List
|
||||
import gptc.model
|
||||
from typing import Dict, Union
|
||||
|
||||
|
||||
class Classifier:
|
||||
|
@ -45,29 +44,7 @@ class Classifier:
|
|||
matching any categories in the model were found
|
||||
|
||||
"""
|
||||
|
||||
model = self.model.weights
|
||||
|
||||
tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length)
|
||||
numbered_probs: Dict[int, float] = {}
|
||||
for word in tokens:
|
||||
try:
|
||||
weighted_numbers = gptc.weighting.weight(
|
||||
[i / 65535 for i in cast(List[float], model[word])]
|
||||
)
|
||||
for category, value in enumerate(weighted_numbers):
|
||||
try:
|
||||
numbered_probs[category] += value
|
||||
except KeyError:
|
||||
numbered_probs[category] = value
|
||||
except KeyError:
|
||||
pass
|
||||
total = sum(numbered_probs.values())
|
||||
probs: Dict[str, float] = {
|
||||
self.model.names[category]: value / total
|
||||
for category, value in numbered_probs.items()
|
||||
}
|
||||
return probs
|
||||
return self.model.confidence(text, self.max_ngram_length)
|
||||
|
||||
def classify(self, text: str) -> Union[str, None]:
|
||||
"""Classify text.
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
|
||||
import gptc.tokenizer
|
||||
from gptc.exceptions import InvalidModelError
|
||||
from typing import Iterable, Mapping, List, Dict, Union
|
||||
import gptc.weighting
|
||||
from typing import Iterable, Mapping, List, Dict, Union, cast
|
||||
import json
|
||||
|
||||
|
||||
|
@ -17,6 +18,50 @@ class Model:
|
|||
self.names = names
|
||||
self.max_ngram_length = max_ngram_length
|
||||
|
||||
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
|
||||
"""Classify text with confidence.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text : str
|
||||
The text to classify
|
||||
|
||||
max_ngram_length : int
|
||||
The maximum ngram length to use in classifying
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{category:probability, category:probability...} or {} if no words
|
||||
matching any categories in the model were found
|
||||
|
||||
"""
|
||||
|
||||
model = self.weights
|
||||
|
||||
tokens = gptc.tokenizer.tokenize(
|
||||
text, min(max_ngram_length, self.max_ngram_length)
|
||||
)
|
||||
numbered_probs: Dict[int, float] = {}
|
||||
for word in tokens:
|
||||
try:
|
||||
weighted_numbers = gptc.weighting.weight(
|
||||
[i / 65535 for i in cast(List[float], model[word])]
|
||||
)
|
||||
for category, value in enumerate(weighted_numbers):
|
||||
try:
|
||||
numbered_probs[category] += value
|
||||
except KeyError:
|
||||
numbered_probs[category] = value
|
||||
except KeyError:
|
||||
pass
|
||||
total = sum(numbered_probs.values())
|
||||
probs: Dict[str, float] = {
|
||||
self.names[category]: value / total
|
||||
for category, value in numbered_probs.items()
|
||||
}
|
||||
return probs
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
out = b"GPTC model v4\n"
|
||||
out += (
|
||||
|
|
Loading…
Reference in New Issue
Block a user