86 lines
2.1 KiB
Python
Executable File
86 lines
2.1 KiB
Python
Executable File
# SPDX-License-Identifier: LGPL-3.0-or-later
|
|
|
|
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
|
|
import warnings
|
|
|
|
|
|
class Classifier:
    """A text classifier backed by a compiled GPTC model.

    Parameters
    ----------
    model : dict
        A compiled GPTC model. Its ``"__version__"`` entry must equal 3.

    Attributes
    ----------
    model : dict
        The model used.

    Raises
    ------
    gptc.exceptions.UnsupportedModelError
        If the model's ``"__version__"`` entry is missing or is not 3.

    """

    def __init__(self, model):
        # A missing "__version__" entry is treated as version 0, which is
        # rejected just like any other unsupported version.
        if model.get("__version__", 0) != 3:
            raise gptc.exceptions.UnsupportedModelError("unsupported model version")
        self.model = model

    def confidence(self, text):
        """Classify text with confidence.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        dict
            {category:probability, category:probability...} or {} if no words
            matching any categories in the model were found, or if the
            weighted values of the matching words sum to zero

        """
        model = self.model

        # Accumulated weighted value per category index.
        totals = {}
        for word in gptc.tokenizer.tokenize(text):
            # Only the model lookup is expected to fail; keeping the try this
            # narrow stops a KeyError raised elsewhere from being hidden.
            try:
                raw_values = model[word]
            except KeyError:
                continue  # word is not in the model

            # Stored values are divided by 65535 before weighting
            # (presumably they are 16-bit scaled integers -- TODO confirm).
            _, weighted_numbers = gptc.weighting.weight(
                [i / 65535 for i in raw_values]
            )
            for category, value in enumerate(weighted_numbers):
                totals[category] = totals.get(category, 0) + value

        if not totals:
            return {}

        # Translate category indexes into their names.
        names = model["__names__"]
        probs = {names[category]: value for category, value in totals.items()}

        total = sum(probs.values())
        if total == 0:
            # No usable evidence for any category; normalising would divide
            # by zero, so report "nothing found" instead.
            return {}
        return {category: value / total for category, value in probs.items()}

    def classify(self, text):
        """Classify text.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        str or None
            The most likely category, or None if no words matching any
            category in the model were found.

        """
        probs = self.confidence(text)
        if not probs:
            return None
        # sorted() is stable, so when several categories tie for the highest
        # probability the one that appears last in ``probs`` is returned.
        return sorted(probs.items(), key=lambda item: item[1])[-1][0]
|