Compare commits
3 Commits
8a1cb6105e
...
f38f4ca801
Author | SHA1 | Date | |
---|---|---|---|
f38f4ca801 | |||
56550ca457 | |||
75fdb5ba3c |
|
@ -112,13 +112,6 @@ GPTC.
|
||||||
|
|
||||||
See `models/unpacked/` for an example of the format.
|
See `models/unpacked/` for an example of the format.
|
||||||
|
|
||||||
### `gptc.Classifier(model, max_ngram_length=1)`
|
|
||||||
|
|
||||||
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
|
|
||||||
removed in 4.0.0. See [the README from
|
|
||||||
3.0.2](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.1/README.md) if you need
|
|
||||||
documentation.
|
|
||||||
|
|
||||||
## Ngrams
|
## Ngrams
|
||||||
|
|
||||||
GPTC optionally supports using ngrams to improve classification accuracy. They
|
GPTC optionally supports using ngrams to improve classification accuracy. They
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
"""General-Purpose Text Classifier"""
|
"""General-Purpose Text Classifier"""
|
||||||
|
|
||||||
from gptc.compiler import compile as compile
|
from gptc.compiler import compile as compile
|
||||||
from gptc.classifier import Classifier as Classifier
|
|
||||||
from gptc.pack import pack as pack
|
from gptc.pack import pack as pack
|
||||||
from gptc.model import Model as Model, deserialize as deserialize
|
from gptc.model import Model as Model, deserialize as deserialize
|
||||||
from gptc.tokenizer import normalize as normalize
|
from gptc.tokenizer import normalize as normalize
|
||||||
|
|
|
@ -44,19 +44,6 @@ def main() -> None:
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
)
|
)
|
||||||
group = classify_parser.add_mutually_exclusive_group()
|
|
||||||
group.add_argument(
|
|
||||||
"-j",
|
|
||||||
"--json",
|
|
||||||
help="output confidence dict as JSON (default)",
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"-c",
|
|
||||||
"--category",
|
|
||||||
help="output most likely category or `None`",
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
|
|
||||||
check_parser = subparsers.add_parser(
|
check_parser = subparsers.add_parser(
|
||||||
"check", help="check one word or ngram in model"
|
"check", help="check one word or ngram in model"
|
||||||
|
@ -88,12 +75,7 @@ def main() -> None:
|
||||||
else:
|
else:
|
||||||
text = sys.stdin.read()
|
text = sys.stdin.read()
|
||||||
|
|
||||||
if args.category:
|
print(json.dumps(model.confidence(text, args.max_ngram_length)))
|
||||||
classifier = gptc.Classifier(model, args.max_ngram_length)
|
|
||||||
print(classifier.classify(text))
|
|
||||||
else:
|
|
||||||
probabilities = model.confidence(text, args.max_ngram_length)
|
|
||||||
print(json.dumps(probabilities))
|
|
||||||
elif args.subparser_name == "check":
|
elif args.subparser_name == "check":
|
||||||
with open(args.model, "rb") as f:
|
with open(args.model, "rb") as f:
|
||||||
model = gptc.deserialize(f)
|
model = gptc.deserialize(f)
|
||||||
|
|
|
@ -1,68 +0,0 @@
|
||||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
||||||
|
|
||||||
import gptc.model
|
|
||||||
from typing import Dict, Union
|
|
||||||
|
|
||||||
|
|
||||||
class Classifier:
|
|
||||||
"""A text classifier.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model : dict
|
|
||||||
A compiled GPTC model.
|
|
||||||
|
|
||||||
max_ngram_length : int
|
|
||||||
The maximum ngram length to use when tokenizing input. If this is
|
|
||||||
greater than the value used when the model was compiled, it will be
|
|
||||||
silently lowered to that value.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
model : dict
|
|
||||||
The model used.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model: gptc.model.Model, max_ngram_length: int = 1):
|
|
||||||
self.model = model
|
|
||||||
model_ngrams = model.max_ngram_length
|
|
||||||
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
|
||||||
|
|
||||||
def confidence(self, text: str) -> Dict[str, float]:
|
|
||||||
"""Classify text with confidence.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
text : str
|
|
||||||
The text to classify
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
{category:probability, category:probability...} or {} if no words
|
|
||||||
matching any categories in the model were found
|
|
||||||
|
|
||||||
"""
|
|
||||||
return self.model.confidence(text, self.max_ngram_length)
|
|
||||||
|
|
||||||
def classify(self, text: str) -> Union[str, None]:
|
|
||||||
"""Classify text.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
text : str
|
|
||||||
The text to classify
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str or None
|
|
||||||
The most likely category, or None if no words matching any
|
|
||||||
category in the model were found.
|
|
||||||
|
|
||||||
"""
|
|
||||||
probs: Dict[str, float] = self.confidence(text)
|
|
||||||
try:
|
|
||||||
return sorted(probs.items(), key=lambda x: x[1])[-1][0]
|
|
||||||
except IndexError:
|
|
||||||
return None
|
|
|
@ -2,32 +2,14 @@
|
||||||
|
|
||||||
import gptc.tokenizer
|
import gptc.tokenizer
|
||||||
import gptc.model
|
import gptc.model
|
||||||
from typing import Iterable, Mapping, List, Dict, Union
|
from typing import Iterable, Mapping, List, Dict, Union, Tuple
|
||||||
|
|
||||||
|
|
||||||
def compile(
|
def _count_words(
|
||||||
raw_model: Iterable[Mapping[str, str]],
|
raw_model: Iterable[Mapping[str, str]],
|
||||||
max_ngram_length: int = 1,
|
max_ngram_length: int,
|
||||||
min_count: int = 1,
|
hash_algorithm: str,
|
||||||
hash_algorithm: str = "sha256",
|
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
|
||||||
) -> gptc.model.Model:
|
|
||||||
"""Compile a raw model.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
raw_model : list of dict
|
|
||||||
A raw GPTC model.
|
|
||||||
|
|
||||||
max_ngram_length : int
|
|
||||||
Maximum ngram lenght to compile with.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
A compiled GPTC model.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
word_counts: Dict[int, Dict[str, int]] = {}
|
word_counts: Dict[int, Dict[str, int]] = {}
|
||||||
category_lengths: Dict[str, int] = {}
|
category_lengths: Dict[str, int] = {}
|
||||||
names: List[str] = []
|
names: List[str] = []
|
||||||
|
@ -55,6 +37,15 @@ def compile(
|
||||||
else:
|
else:
|
||||||
word_counts[word] = {category: 1}
|
word_counts[word] = {category: 1}
|
||||||
|
|
||||||
|
return word_counts, category_lengths, names
|
||||||
|
|
||||||
|
|
||||||
|
def _get_weights(
|
||||||
|
min_count: int,
|
||||||
|
word_counts: Dict[int, Dict[str, int]],
|
||||||
|
category_lengths: Dict[str, int],
|
||||||
|
names: List[str],
|
||||||
|
) -> Dict[int, List[int]]:
|
||||||
model: Dict[int, List[int]] = {}
|
model: Dict[int, List[int]] = {}
|
||||||
for word, counts in word_counts.items():
|
for word, counts in word_counts.items():
|
||||||
if sum(counts.values()) >= min_count:
|
if sum(counts.values()) >= min_count:
|
||||||
|
@ -69,5 +60,33 @@ def compile(
|
||||||
round((weights.get(category, 0) / total) * 65535)
|
round((weights.get(category, 0) / total) * 65535)
|
||||||
)
|
)
|
||||||
model[word] = new_weights
|
model[word] = new_weights
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def compile(
|
||||||
|
raw_model: Iterable[Mapping[str, str]],
|
||||||
|
max_ngram_length: int = 1,
|
||||||
|
min_count: int = 1,
|
||||||
|
hash_algorithm: str = "sha256",
|
||||||
|
) -> gptc.model.Model:
|
||||||
|
"""Compile a raw model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw_model : list of dict
|
||||||
|
A raw GPTC model.
|
||||||
|
|
||||||
|
max_ngram_length : int
|
||||||
|
Maximum ngram lenght to compile with.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
A compiled GPTC model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
word_counts, category_lengths, names = _count_words(
|
||||||
|
raw_model, max_ngram_length, hash_algorithm
|
||||||
|
)
|
||||||
|
model = _get_weights(min_count, word_counts, category_lengths, names)
|
||||||
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
|
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
|
||||||
|
|
16
profiler.py
Normal file
16
profiler.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
import cProfile
|
||||||
|
import gptc
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
max_ngram_length = 10
|
||||||
|
|
||||||
|
with open("models/raw.json") as f:
|
||||||
|
raw_model = json.load(f)
|
||||||
|
|
||||||
|
with open("models/benchmark_text.txt") as f:
|
||||||
|
text = f.read()
|
||||||
|
|
||||||
|
cProfile.run("gptc.compile(raw_model, max_ngram_length)")
|
Loading…
Reference in New Issue
Block a user