diff --git a/README.md b/README.md index 00a10d7..c64ae33 100644 --- a/README.md +++ b/README.md @@ -43,14 +43,15 @@ example of the format. Any exceptions will be printed to stderr. ## Library -### `gptc.Classifier(model, max_ngram_length=1)` +### `Model.serialize()` -Create a `Classifier` object using the given compiled model (as a `gptc.Model` -object, not as a serialized byte string). +Returns a `bytes` representing the model. -For information about `max_ngram_length`, see section "Ngrams." +### `gptc.deserialize(encoded_model)` -#### `Classifier.confidence(text)` +Deserialize a `Model` from a `bytes` returned by `Model.serialize()`. + +### `Model.confidence(text, max_ngram_length)` Classify `text`. Returns a dict of the format `{category: probability, category:probability, ...}` @@ -60,14 +61,7 @@ common words between the input and the training data (likely, for example, with input in a different language from the training data), an empty dict will be returned. -#### `Classifier.classify(text)` - -Classify `text`. Returns the category into which the text is placed (as a -string), or `None` when it cannot classify the text. - -#### `Classifier.model` - -The classifier's model. +For information about `max_ngram_length`, see section "Ngrams." ### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)` @@ -79,14 +73,6 @@ For information about `max_ngram_length`, see section "Ngrams." Words or ngrams used less than `min_count` times throughout the input text are excluded from the model. -### `gptc.Model.serialize()` - -Returns a `bytes` representing the model. - -### `gptc.deserialize(encoded_model)` - -Deserialize a `Model` from a `bytes` returned by `Model.serialize()`. - ### `gptc.pack(directory, print_exceptions=False) Pack the model in `directory` and return a tuple of the format: @@ -99,6 +85,13 @@ GPTC. See `models/unpacked/` for an example of the format. 
+### `gptc.Classifier(model, max_ngram_length=1)` + +`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be +removed in 4.0.0. See [the README from +3.0.2](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.2/README.md) if you need +documentation. + ## Ngrams GPTC optionally supports using ngrams to improve classification accuracy. They diff --git a/gptc/__main__.py b/gptc/__main__.py index d98d4bd..6c3e99b 100644 --- a/gptc/__main__.py +++ b/gptc/__main__.py @@ -13,9 +13,7 @@ def main() -> None: ) subparsers = parser.add_subparsers(dest="subparser_name", required=True) - compile_parser = subparsers.add_parser( - "compile", help="compile a raw model" - ) + compile_parser = subparsers.add_parser("compile", help="compile a raw model") compile_parser.add_argument("model", help="raw model to compile") compile_parser.add_argument( "--max-ngram-length", @@ -55,9 +53,7 @@ def main() -> None: action="store_true", ) - pack_parser = subparsers.add_parser( - "pack", help="pack a model from a directory" - ) + pack_parser = subparsers.add_parser("pack", help="pack a model from a directory") pack_parser.add_argument("model", help="directory containing model") args = parser.parse_args() @@ -67,25 +63,26 @@ def main() -> None: model = json.load(f) sys.stdout.buffer.write( - gptc.compile( - model, args.max_ngram_length, args.min_count - ).serialize() + gptc.compile(model, args.max_ngram_length, args.min_count).serialize() ) elif args.subparser_name == "classify": with open(args.model, "rb") as f: model = gptc.deserialize(f.read()) - classifier = gptc.Classifier(model, args.max_ngram_length) - if sys.stdin.isatty(): text = input("Text to analyse: ") else: text = sys.stdin.read() + probabilities = model.confidence(text, args.max_ngram_length) + if args.category: - print(classifier.classify(text)) + try: + print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0]) + except IndexError: + print(None) else: - print(json.dumps(classifier.confidence(text))) + 
print(json.dumps(probabilities)) else: print(json.dumps(gptc.pack(args.model, True)[0])) diff --git a/gptc/classifier.py b/gptc/classifier.py index de9e0a8..f2e2bf4 100755 --- a/gptc/classifier.py +++ b/gptc/classifier.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting -import warnings -from typing import Dict, Union, cast, List +import gptc.model +from typing import Dict, Union class Classifier: @@ -45,29 +44,7 @@ class Classifier: matching any categories in the model were found """ - - model = self.model.weights - - tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length) - numbered_probs: Dict[int, float] = {} - for word in tokens: - try: - weighted_numbers = gptc.weighting.weight( - [i / 65535 for i in cast(List[float], model[word])] - ) - for category, value in enumerate(weighted_numbers): - try: - numbered_probs[category] += value - except KeyError: - numbered_probs[category] = value - except KeyError: - pass - total = sum(numbered_probs.values()) - probs: Dict[str, float] = { - self.model.names[category]: value / total - for category, value in numbered_probs.items() - } - return probs + return self.model.confidence(text, self.max_ngram_length) def classify(self, text: str) -> Union[str, None]: """Classify text. diff --git a/gptc/model.py b/gptc/model.py index 014189b..e105674 100644 --- a/gptc/model.py +++ b/gptc/model.py @@ -2,7 +2,8 @@ import gptc.tokenizer from gptc.exceptions import InvalidModelError -from typing import Iterable, Mapping, List, Dict, Union +import gptc.weighting +from typing import Iterable, Mapping, List, Dict, Union, cast import json @@ -17,6 +18,50 @@ class Model: self.names = names self.max_ngram_length = max_ngram_length + def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]: + """Classify text with confidence. 
+ + Parameters + ---------- + text : str + The text to classify + + max_ngram_length : int + The maximum ngram length to use in classifying + + Returns + ------- + dict + {category:probability, category:probability...} or {} if no words + matching any categories in the model were found + + """ + + model = self.weights + + tokens = gptc.tokenizer.tokenize( + text, min(max_ngram_length, self.max_ngram_length) + ) + numbered_probs: Dict[int, float] = {} + for word in tokens: + try: + weighted_numbers = gptc.weighting.weight( + [i / 65535 for i in cast(List[float], model[word])] + ) + for category, value in enumerate(weighted_numbers): + try: + numbered_probs[category] += value + except KeyError: + numbered_probs[category] = value + except KeyError: + pass + total = sum(numbered_probs.values()) + probs: Dict[str, float] = { + self.names[category]: value / total + for category, value in numbered_probs.items() + } + return probs + def serialize(self) -> bytes: out = b"GPTC model v4\n" out += (