Compare commits

...

3 Commits

Author SHA1 Message Date
f38f4ca801
Add profiler 2023-04-16 14:27:31 -07:00
56550ca457
Remove Classifier objects
Closes #16
2023-04-16 14:27:07 -07:00
75fdb5ba3c
Split compiler into two functions 2023-01-15 09:39:35 -08:00
6 changed files with 59 additions and 118 deletions

View File

@ -112,13 +112,6 @@ GPTC.
See `models/unpacked/` for an example of the format.
### `gptc.Classifier(model, max_ngram_length=1)`
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
removed in 4.0.0. See [the README from
3.0.2](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.1/README.md) if you need
documentation.
## Ngrams
GPTC optionally supports using ngrams to improve classification accuracy. They

View File

@ -3,7 +3,6 @@
"""General-Purpose Text Classifier""" """General-Purpose Text Classifier"""
from gptc.compiler import compile as compile from gptc.compiler import compile as compile
from gptc.classifier import Classifier as Classifier
from gptc.pack import pack as pack from gptc.pack import pack as pack
from gptc.model import Model as Model, deserialize as deserialize from gptc.model import Model as Model, deserialize as deserialize
from gptc.tokenizer import normalize as normalize from gptc.tokenizer import normalize as normalize

View File

@ -44,19 +44,6 @@ def main() -> None:
type=int, type=int,
default=1, default=1,
) )
group = classify_parser.add_mutually_exclusive_group()
group.add_argument(
"-j",
"--json",
help="output confidence dict as JSON (default)",
action="store_true",
)
group.add_argument(
"-c",
"--category",
help="output most likely category or `None`",
action="store_true",
)
check_parser = subparsers.add_parser( check_parser = subparsers.add_parser(
"check", help="check one word or ngram in model" "check", help="check one word or ngram in model"
@ -88,12 +75,7 @@ def main() -> None:
else: else:
text = sys.stdin.read() text = sys.stdin.read()
if args.category: print(json.dumps(model.confidence(text, args.max_ngram_length)))
classifier = gptc.Classifier(model, args.max_ngram_length)
print(classifier.classify(text))
else:
probabilities = model.confidence(text, args.max_ngram_length)
print(json.dumps(probabilities))
elif args.subparser_name == "check": elif args.subparser_name == "check":
with open(args.model, "rb") as f: with open(args.model, "rb") as f:
model = gptc.deserialize(f) model = gptc.deserialize(f)

View File

@ -1,68 +0,0 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import gptc.model
from typing import Dict, Union
class Classifier:
    """Deprecated wrapper that classifies text with a compiled GPTC model.

    Parameters
    ----------
    model : dict
        A compiled GPTC model.
    max_ngram_length : int
        Longest ngram to use when tokenizing input. Values above the
        model's own compiled maximum are silently clamped down to it.

    Attributes
    ----------
    model : dict
        The model used.
    """

    def __init__(self, model: gptc.model.Model, max_ngram_length: int = 1):
        self.model = model
        # The model cannot score ngrams longer than it was compiled with,
        # so clamp the requested length to the model's own maximum.
        self.max_ngram_length = min(max_ngram_length, model.max_ngram_length)

    def confidence(self, text: str) -> Dict[str, float]:
        """Classify text with confidence.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        dict
            {category:probability, category:probability...} or {} if no
            words matching any categories in the model were found
        """
        return self.model.confidence(text, self.max_ngram_length)

    def classify(self, text: str) -> Union[str, None]:
        """Classify text.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        str or None
            The most likely category, or None if no words matching any
            category in the model were found.
        """
        scores: Dict[str, float] = self.confidence(text)
        if not scores:
            # No recognized words: nothing to rank.
            return None
        # Ascending sort by probability; the last entry is the winner.
        # (Stable sort keeps the original tie-breaking behavior.)
        ranked = sorted(scores.items(), key=lambda pair: pair[1])
        return ranked[-1][0]

View File

@ -2,32 +2,14 @@
import gptc.tokenizer import gptc.tokenizer
import gptc.model import gptc.model
from typing import Iterable, Mapping, List, Dict, Union from typing import Iterable, Mapping, List, Dict, Union, Tuple
def compile( def _count_words(
raw_model: Iterable[Mapping[str, str]], raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1, max_ngram_length: int,
min_count: int = 1, hash_algorithm: str,
hash_algorithm: str = "sha256", ) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
) -> gptc.model.Model:
"""Compile a raw model.
Parameters
----------
raw_model : list of dict
A raw GPTC model.
max_ngram_length : int
Maximum ngram length to compile with.
Returns
-------
dict
A compiled GPTC model.
"""
word_counts: Dict[int, Dict[str, int]] = {} word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {} category_lengths: Dict[str, int] = {}
names: List[str] = [] names: List[str] = []
@ -55,6 +37,15 @@ def compile(
else: else:
word_counts[word] = {category: 1} word_counts[word] = {category: 1}
return word_counts, category_lengths, names
def _get_weights(
min_count: int,
word_counts: Dict[int, Dict[str, int]],
category_lengths: Dict[str, int],
names: List[str],
) -> Dict[int, List[int]]:
model: Dict[int, List[int]] = {} model: Dict[int, List[int]] = {}
for word, counts in word_counts.items(): for word, counts in word_counts.items():
if sum(counts.values()) >= min_count: if sum(counts.values()) >= min_count:
@ -69,5 +60,33 @@ def compile(
round((weights.get(category, 0) / total) * 65535) round((weights.get(category, 0) / total) * 65535)
) )
model[word] = new_weights model[word] = new_weights
return model
def compile(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1,
min_count: int = 1,
hash_algorithm: str = "sha256",
) -> gptc.model.Model:
"""Compile a raw model.
Parameters
----------
raw_model : list of dict
A raw GPTC model.
max_ngram_length : int
Maximum ngram length to compile with.
Returns
-------
dict
A compiled GPTC model.
"""
word_counts, category_lengths, names = _count_words(
raw_model, max_ngram_length, hash_algorithm
)
model = _get_weights(min_count, word_counts, category_lengths, names)
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm) return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)

16
profiler.py Normal file
View File

@ -0,0 +1,16 @@
# SPDX-License-Identifier: GPL-3.0-or-later

# Benchmark script: profile gptc.compile on the bundled raw benchmark model.
import cProfile
import gptc
import json
import sys

# Ngram length for the profiled compile run; referenced by name inside the
# cProfile.run() string below.
max_ngram_length = 10

# Raw (uncompiled) model used as the compile workload.
with open("models/raw.json") as f:
    raw_model = json.load(f)

# Benchmark text loaded alongside the model.
# NOTE(review): `text` and `sys` are unused in the lines shown here —
# presumably kept for profiling classification too; confirm before removing.
with open("models/benchmark_text.txt") as f:
    text = f.read()

# cProfile.run() exec's the string in the current module's namespace, so it
# sees the raw_model and max_ngram_length globals defined above.
cProfile.run("gptc.compile(raw_model, max_ngram_length)")