Compare commits
3 Commits
8a1cb6105e
...
f38f4ca801
Author | SHA1 | Date | |
---|---|---|---|
f38f4ca801 | |||
56550ca457 | |||
75fdb5ba3c |
|
@ -112,13 +112,6 @@ GPTC.
|
|||
|
||||
See `models/unpacked/` for an example of the format.
|
||||
|
||||
### `gptc.Classifier(model, max_ngram_length=1)`
|
||||
|
||||
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
|
||||
removed in 4.0.0. See [the README from
|
||||
3.0.2](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.1/README.md) if you need
|
||||
documentation.
|
||||
|
||||
## Ngrams
|
||||
|
||||
GPTC optionally supports using ngrams to improve classification accuracy. They
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
"""General-Purpose Text Classifier"""
|
||||
|
||||
from gptc.compiler import compile as compile
|
||||
from gptc.classifier import Classifier as Classifier
|
||||
from gptc.pack import pack as pack
|
||||
from gptc.model import Model as Model, deserialize as deserialize
|
||||
from gptc.tokenizer import normalize as normalize
|
||||
|
|
|
@ -44,19 +44,6 @@ def main() -> None:
|
|||
type=int,
|
||||
default=1,
|
||||
)
|
||||
group = classify_parser.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
"-j",
|
||||
"--json",
|
||||
help="output confidence dict as JSON (default)",
|
||||
action="store_true",
|
||||
)
|
||||
group.add_argument(
|
||||
"-c",
|
||||
"--category",
|
||||
help="output most likely category or `None`",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
check_parser = subparsers.add_parser(
|
||||
"check", help="check one word or ngram in model"
|
||||
|
@ -88,12 +75,7 @@ def main() -> None:
|
|||
else:
|
||||
text = sys.stdin.read()
|
||||
|
||||
if args.category:
|
||||
classifier = gptc.Classifier(model, args.max_ngram_length)
|
||||
print(classifier.classify(text))
|
||||
else:
|
||||
probabilities = model.confidence(text, args.max_ngram_length)
|
||||
print(json.dumps(probabilities))
|
||||
print(json.dumps(model.confidence(text, args.max_ngram_length)))
|
||||
elif args.subparser_name == "check":
|
||||
with open(args.model, "rb") as f:
|
||||
model = gptc.deserialize(f)
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import gptc.model
|
||||
from typing import Dict, Union
|
||||
|
||||
|
||||
class Classifier:
|
||||
"""A text classifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : dict
|
||||
A compiled GPTC model.
|
||||
|
||||
max_ngram_length : int
|
||||
The maximum ngram length to use when tokenizing input. If this is
|
||||
greater than the value used when the model was compiled, it will be
|
||||
silently lowered to that value.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
model : dict
|
||||
The model used.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model: gptc.model.Model, max_ngram_length: int = 1):
|
||||
self.model = model
|
||||
model_ngrams = model.max_ngram_length
|
||||
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
||||
|
||||
def confidence(self, text: str) -> Dict[str, float]:
|
||||
"""Classify text with confidence.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text : str
|
||||
The text to classify
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{category:probability, category:probability...} or {} if no words
|
||||
matching any categories in the model were found
|
||||
|
||||
"""
|
||||
return self.model.confidence(text, self.max_ngram_length)
|
||||
|
||||
def classify(self, text: str) -> Union[str, None]:
|
||||
"""Classify text.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text : str
|
||||
The text to classify
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The most likely category, or None if no words matching any
|
||||
category in the model were found.
|
||||
|
||||
"""
|
||||
probs: Dict[str, float] = self.confidence(text)
|
||||
try:
|
||||
return sorted(probs.items(), key=lambda x: x[1])[-1][0]
|
||||
except IndexError:
|
||||
return None
|
109
gptc/compiler.py
109
gptc/compiler.py
|
@ -2,7 +2,65 @@
|
|||
|
||||
import gptc.tokenizer
|
||||
import gptc.model
|
||||
from typing import Iterable, Mapping, List, Dict, Union
|
||||
from typing import Iterable, Mapping, List, Dict, Union, Tuple
|
||||
|
||||
|
||||
def _count_words(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int,
|
||||
hash_algorithm: str,
|
||||
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
|
||||
word_counts: Dict[int, Dict[str, int]] = {}
|
||||
category_lengths: Dict[str, int] = {}
|
||||
names: List[str] = []
|
||||
|
||||
for portion in raw_model:
|
||||
text = gptc.tokenizer.hash(
|
||||
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
||||
hash_algorithm,
|
||||
)
|
||||
category = portion["category"]
|
||||
|
||||
if not category in names:
|
||||
names.append(category)
|
||||
|
||||
category_lengths[category] = category_lengths.get(category, 0) + len(
|
||||
text
|
||||
)
|
||||
|
||||
for word in text:
|
||||
if word in word_counts:
|
||||
try:
|
||||
word_counts[word][category] += 1
|
||||
except KeyError:
|
||||
word_counts[word][category] = 1
|
||||
else:
|
||||
word_counts[word] = {category: 1}
|
||||
|
||||
return word_counts, category_lengths, names
|
||||
|
||||
|
||||
def _get_weights(
|
||||
min_count: int,
|
||||
word_counts: Dict[int, Dict[str, int]],
|
||||
category_lengths: Dict[str, int],
|
||||
names: List[str],
|
||||
) -> Dict[int, List[int]]:
|
||||
model: Dict[int, List[int]] = {}
|
||||
for word, counts in word_counts.items():
|
||||
if sum(counts.values()) >= min_count:
|
||||
weights = {
|
||||
category: value / category_lengths[category]
|
||||
for category, value in counts.items()
|
||||
}
|
||||
total = sum(weights.values())
|
||||
new_weights: List[int] = []
|
||||
for category in names:
|
||||
new_weights.append(
|
||||
round((weights.get(category, 0) / total) * 65535)
|
||||
)
|
||||
model[word] = new_weights
|
||||
return model
|
||||
|
||||
|
||||
def compile(
|
||||
|
@ -27,49 +85,8 @@ def compile(
|
|||
A compiled GPTC model.
|
||||
|
||||
"""
|
||||
|
||||
word_counts: Dict[int, Dict[int, int]] = {}
|
||||
category_lengths: Dict[int, int] = {}
|
||||
names: List[str] = []
|
||||
|
||||
for portion in raw_model:
|
||||
text = gptc.tokenizer.hash(
|
||||
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
||||
hash_algorithm,
|
||||
)
|
||||
category_name = portion["category"]
|
||||
|
||||
if not category_name in names:
|
||||
names.append(category_name)
|
||||
|
||||
category = names.index(category_name)
|
||||
|
||||
category_lengths[category] = category_lengths.get(category, 0) + len(
|
||||
text
|
||||
)
|
||||
|
||||
for word in text:
|
||||
if word in word_counts:
|
||||
try:
|
||||
word_counts[word][category] += 1
|
||||
except KeyError:
|
||||
word_counts[word][category] = 1
|
||||
else:
|
||||
word_counts[word] = {category: 1}
|
||||
|
||||
model: Dict[int, List[int]] = {}
|
||||
for word, counts in word_counts.items():
|
||||
if sum(counts.values()) >= min_count:
|
||||
weights = {
|
||||
category: value / category_lengths[category]
|
||||
for category, value in counts.items()
|
||||
}
|
||||
total = sum(weights.values())
|
||||
new_weights: List[int] = []
|
||||
for category in range(len(names)):
|
||||
new_weights.append(
|
||||
round((weights.get(category, 0) / total) * 65535)
|
||||
)
|
||||
model[word] = new_weights
|
||||
|
||||
word_counts, category_lengths, names = _count_words(
|
||||
raw_model, max_ngram_length, hash_algorithm
|
||||
)
|
||||
model = _get_weights(min_count, word_counts, category_lengths, names)
|
||||
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
|
||||
|
|
16
profiler.py
Normal file
16
profiler.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import cProfile
|
||||
import gptc
|
||||
import json
|
||||
import sys
|
||||
|
||||
max_ngram_length = 10
|
||||
|
||||
with open("models/raw.json") as f:
|
||||
raw_model = json.load(f)
|
||||
|
||||
with open("models/benchmark_text.txt") as f:
|
||||
text = f.read()
|
||||
|
||||
cProfile.run("gptc.compile(raw_model, max_ngram_length)")
|
Loading…
Reference in New Issue
Block a user