# SPDX-License-Identifier: LGPL-3.0-or-later

import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
import warnings


class Classifier:
    """A text classifier.

    Parameters
    ----------
    model : dict
        A compiled GPTC model. Its ``"__version__"`` entry must be 3.

    Attributes
    ----------
    model : dict
        The model used.

    Raises
    ------
    gptc.exceptions.UnsupportedModelError
        If the model's ``"__version__"`` entry is missing or is not 3.

    """

    def __init__(self, model):
        # A model without a "__version__" key is treated as version 0 and
        # therefore rejected as well.
        if model.get("__version__", 0) != 3:
            raise gptc.exceptions.UnsupportedModelError("unsupported model version")
        self.model = model

    def confidence(self, text):
        """Classify text with confidence.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        dict
            {category:probability, category:probability...} or {} if no words
            matching any categories in the model were found, or if the
            matched words contributed a total weight of zero.

        """
        model = self.model
        scores = {}

        for word in gptc.tokenizer.tokenize(text):
            # Only the model lookup may legitimately miss; keeping the
            # ``try`` this narrow stops unrelated KeyErrors raised further
            # down from being silently swallowed.
            try:
                frequencies = model[word]
            except KeyError:
                continue

            # Stored per-category values are scaled down by 65535 before
            # weighting. The first element returned by ``weight`` is not
            # needed here.
            _, weighted_numbers = gptc.weighting.weight(
                [i / 65535 for i in frequencies]
            )
            for category, value in enumerate(weighted_numbers):
                scores[category] = scores.get(category, 0) + value

        # Translate category indices into their names.
        probs = {
            model["__names__"][category]: value
            for category, value in scores.items()
        }

        total = sum(probs.values())
        if not total:
            # Nothing matched, or everything that matched summed to zero:
            # there is nothing to normalise, and dividing would raise
            # ZeroDivisionError.
            return {}

        return {category: value / total for category, value in probs.items()}

    def classify(self, text):
        """Classify text.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        str or None
            The most likely category, or None if no words matching any
            category in the model were found.

        """
        probs = self.confidence(text)
        if not probs:
            return None

        # ``sorted`` is stable, so among equally likely categories the one
        # appearing last in ``probs`` is returned.
        return sorted(probs.items(), key=lambda item: item[1])[-1][0]