Compare commits

..

3 Commits

Author SHA1 Message Date
f38f4ca801
Add profiler 2023-04-16 14:27:31 -07:00
56550ca457
Remove Classifier objects
Closes #16
2023-04-16 14:27:07 -07:00
75fdb5ba3c
Split compiler into two functions 2023-01-15 09:39:35 -08:00
6 changed files with 80 additions and 141 deletions

View File

@ -112,13 +112,6 @@ GPTC.
See `models/unpacked/` for an example of the format.
### `gptc.Classifier(model, max_ngram_length=1)`
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
removed in 4.0.0. See [the README from
3.0.2](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.1/README.md) if you need
documentation.
## Ngrams
GPTC optionally supports using ngrams to improve classification accuracy. They

View File

@ -3,7 +3,6 @@
"""General-Purpose Text Classifier"""
from gptc.compiler import compile as compile
from gptc.classifier import Classifier as Classifier
from gptc.pack import pack as pack
from gptc.model import Model as Model, deserialize as deserialize
from gptc.tokenizer import normalize as normalize

View File

@ -44,19 +44,6 @@ def main() -> None:
type=int,
default=1,
)
group = classify_parser.add_mutually_exclusive_group()
group.add_argument(
"-j",
"--json",
help="output confidence dict as JSON (default)",
action="store_true",
)
group.add_argument(
"-c",
"--category",
help="output most likely category or `None`",
action="store_true",
)
check_parser = subparsers.add_parser(
"check", help="check one word or ngram in model"
@ -88,12 +75,7 @@ def main() -> None:
else:
text = sys.stdin.read()
if args.category:
classifier = gptc.Classifier(model, args.max_ngram_length)
print(classifier.classify(text))
else:
probabilities = model.confidence(text, args.max_ngram_length)
print(json.dumps(probabilities))
print(json.dumps(model.confidence(text, args.max_ngram_length)))
elif args.subparser_name == "check":
with open(args.model, "rb") as f:
model = gptc.deserialize(f)

View File

@ -1,68 +0,0 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import gptc.model
from typing import Dict, Union
class Classifier:
"""A text classifier.
Parameters
----------
model : dict
A compiled GPTC model.
max_ngram_length : int
The maximum ngram length to use when tokenizing input. If this is
greater than the value used when the model was compiled, it will be
silently lowered to that value.
Attributes
----------
model : dict
The model used.
"""
def __init__(self, model: gptc.model.Model, max_ngram_length: int = 1):
self.model = model
model_ngrams = model.max_ngram_length
self.max_ngram_length = min(max_ngram_length, model_ngrams)
def confidence(self, text: str) -> Dict[str, float]:
"""Classify text with confidence.
Parameters
----------
text : str
The text to classify
Returns
-------
dict
{category:probability, category:probability...} or {} if no words
matching any categories in the model were found
"""
return self.model.confidence(text, self.max_ngram_length)
def classify(self, text: str) -> Union[str, None]:
"""Classify text.
Parameters
----------
text : str
The text to classify
Returns
-------
str or None
The most likely category, or None if no words matching any
category in the model were found.
"""
probs: Dict[str, float] = self.confidence(text)
try:
return sorted(probs.items(), key=lambda x: x[1])[-1][0]
except IndexError:
return None

View File

@ -2,7 +2,65 @@
import gptc.tokenizer
import gptc.model
from typing import Iterable, Mapping, List, Dict, Union
from typing import Iterable, Mapping, List, Dict, Union, Tuple
def _count_words(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int,
hash_algorithm: str,
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
for portion in raw_model:
text = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
hash_algorithm,
)
category = portion["category"]
if not category in names:
names.append(category)
category_lengths[category] = category_lengths.get(category, 0) + len(
text
)
for word in text:
if word in word_counts:
try:
word_counts[word][category] += 1
except KeyError:
word_counts[word][category] = 1
else:
word_counts[word] = {category: 1}
return word_counts, category_lengths, names
def _get_weights(
min_count: int,
word_counts: Dict[int, Dict[str, int]],
category_lengths: Dict[str, int],
names: List[str],
) -> Dict[int, List[int]]:
model: Dict[int, List[int]] = {}
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
}
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
return model
def compile(
@ -27,49 +85,8 @@ def compile(
A compiled GPTC model.
"""
word_counts: Dict[int, Dict[int, int]] = {}
category_lengths: Dict[int, int] = {}
names: List[str] = []
for portion in raw_model:
text = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
hash_algorithm,
)
category_name = portion["category"]
if not category_name in names:
names.append(category_name)
category = names.index(category_name)
category_lengths[category] = category_lengths.get(category, 0) + len(
text
)
for word in text:
if word in word_counts:
try:
word_counts[word][category] += 1
except KeyError:
word_counts[word][category] = 1
else:
word_counts[word] = {category: 1}
model: Dict[int, List[int]] = {}
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
}
total = sum(weights.values())
new_weights: List[int] = []
for category in range(len(names)):
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
word_counts, category_lengths, names = _count_words(
raw_model, max_ngram_length, hash_algorithm
)
model = _get_weights(min_count, word_counts, category_lengths, names)
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)

16
profiler.py Normal file
View File

@ -0,0 +1,16 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import cProfile
import gptc
import json
import sys
max_ngram_length = 10
with open("models/raw.json") as f:
raw_model = json.load(f)
with open("models/benchmark_text.txt") as f:
text = f.read()
cProfile.run("gptc.compile(raw_model, max_ngram_length)")