gptc/gptc/classifier.py

86 lines
2.1 KiB
Python
Executable File

# SPDX-License-Identifier: LGPL-3.0-or-later
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
import warnings
class Classifier:
    """A text classifier.

    Parameters
    ----------
    model : dict
        A compiled GPTC model.

    Attributes
    ----------
    model : dict
        The model used.
    """

    # Keys stored alongside the word entries in a compiled model; they hold
    # metadata, not per-category counts, so they must never be scored.
    _METADATA_KEYS = ("__version__", "__names__")

    def __init__(self, model):
        # Only version-3 compiled models are accepted; a model without a
        # "__version__" key is treated as version 0 and rejected.
        if model.get("__version__", 0) != 3:
            raise gptc.exceptions.UnsupportedModelError("unsupported model version")
        self.model = model

    def confidence(self, text):
        """Classify text with confidence.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        dict
            {category:probability, category:probability...} or {} if no words
            matching any categories in the model were found
        """
        model = self.model
        probs = {}
        for word in gptc.tokenizer.tokenize(text):
            if word in self._METADATA_KEYS:
                continue
            # Keep the try minimal: only a missing word should be skipped; a
            # KeyError raised inside the weighting code must propagate.
            try:
                counts = model[word]
            except KeyError:
                continue
            # Stored values are divided by 65535 before weighting (presumably
            # 16-bit fixed-point frequencies -- confirm against gptc.compiler).
            _, weighted_numbers = gptc.weighting.weight(
                [i / 65535 for i in counts]
            )
            for category, value in enumerate(weighted_numbers):
                probs[category] = probs.get(category, 0) + value

        # Map category indices to their human-readable names.
        names = model["__names__"]
        probs = {names[category]: value for category, value in probs.items()}

        total = sum(probs.values())
        if not total:
            # Nothing matched (or every weight was zero): report "no match"
            # instead of dividing by zero.
            return {}
        return {category: value / total for category, value in probs.items()}

    def classify(self, text):
        """Classify text.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        str or None
            The most likely category, or None if no words matching any
            category in the model were found.
        """
        probs = self.confidence(text)
        if not probs:
            return None
        # max() returns the first maximal element it sees; scanning in reverse
        # makes the LAST of several equally likely categories win, which is
        # what the stable sorted(...)[-1] formulation used to return.
        return max(reversed(list(probs)), key=probs.get)