Working type checks
This commit is contained in:
parent
b36d8e6081
commit
67ac3a4591
|
@ -7,7 +7,7 @@ import sys
|
|||
import gptc
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="General Purpose Text Classifier", prog="gptc"
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
|
||||
import warnings
|
||||
from typing import Dict, Union, cast, List
|
||||
|
||||
|
||||
class Classifier:
|
||||
|
@ -24,17 +25,18 @@ class Classifier:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, model, max_ngram_length=1):
|
||||
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
|
||||
if model.get("__version__", 0) != 3:
|
||||
raise gptc.exceptions.UnsupportedModelError(
|
||||
f"unsupported model version"
|
||||
)
|
||||
self.model = model
|
||||
model_ngrams = cast(int, model.get("__ngrams__", 1))
|
||||
self.max_ngram_length = min(
|
||||
max_ngram_length, model.get("__ngrams__", 1)
|
||||
max_ngram_length, model_ngrams
|
||||
)
|
||||
|
||||
def confidence(self, text):
|
||||
def confidence(self, text: str) -> Dict[str, float]:
|
||||
"""Classify text with confidence.
|
||||
|
||||
Parameters
|
||||
|
@ -52,29 +54,28 @@ class Classifier:
|
|||
|
||||
model = self.model
|
||||
|
||||
text = gptc.tokenizer.tokenize(text, self.max_ngram_length)
|
||||
probs = {}
|
||||
for word in text:
|
||||
tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length)
|
||||
numbered_probs: Dict[int, float] = {}
|
||||
for word in tokens:
|
||||
try:
|
||||
weight, weighted_numbers = gptc.weighting.weight(
|
||||
[i / 65535 for i in model[word]]
|
||||
weighted_numbers = gptc.weighting.weight(
|
||||
[i / 65535 for i in cast(List[float], model[word])]
|
||||
)
|
||||
for category, value in enumerate(weighted_numbers):
|
||||
try:
|
||||
probs[category] += value
|
||||
numbered_probs[category] += value
|
||||
except KeyError:
|
||||
probs[category] = value
|
||||
numbered_probs[category] = value
|
||||
except KeyError:
|
||||
pass
|
||||
probs = {
|
||||
model["__names__"][category]: value
|
||||
for category, value in probs.items()
|
||||
total = sum(numbered_probs.values())
|
||||
probs: Dict[str, float] = {
|
||||
cast(List[str], model["__names__"])[category]: value / total
|
||||
for category, value in numbered_probs.items()
|
||||
}
|
||||
total = sum(probs.values())
|
||||
probs = {category: value / total for category, value in probs.items()}
|
||||
return probs
|
||||
|
||||
def classify(self, text):
|
||||
def classify(self, text: str) -> Union[str, None]:
|
||||
"""Classify text.
|
||||
|
||||
Parameters
|
||||
|
@ -89,7 +90,7 @@ class Classifier:
|
|||
category in the model were found.
|
||||
|
||||
"""
|
||||
probs = self.confidence(text)
|
||||
probs: Dict[str, float] = self.confidence(text)
|
||||
try:
|
||||
return sorted(probs.items(), key=lambda x: x[1])[-1][0]
|
||||
except IndexError:
|
||||
|
|
|
@ -3,10 +3,14 @@
|
|||
import gptc.tokenizer
|
||||
from typing import Iterable, Mapping, List, Dict, Union
|
||||
|
||||
WEIGHTS_T = List[int]
|
||||
CONFIG_T = Union[List[str], int, str]
|
||||
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
|
||||
|
||||
|
||||
def compile(
|
||||
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
|
||||
) -> Dict[str, Union[str, int, List[int], List[str]]]:
|
||||
) -> MODEL:
|
||||
"""Compile a raw model.
|
||||
|
||||
Parameters
|
||||
|
@ -24,7 +28,7 @@ def compile(
|
|||
|
||||
"""
|
||||
|
||||
categories: Dict[str, str] = {}
|
||||
categories: Dict[str, List[str]] = {}
|
||||
|
||||
for portion in raw_model:
|
||||
text = gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
|
||||
|
@ -60,7 +64,7 @@ def compile(
|
|||
except KeyError:
|
||||
word_weights[word] = {category: value}
|
||||
|
||||
model: Dict[str, Union[str, int, List[int], List[str]]] = {}
|
||||
model: MODEL = {}
|
||||
for word, weights in word_weights.items():
|
||||
total = sum(weights.values())
|
||||
new_weights: List[int] = []
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# SPDX-License-Identifier: LGPL-3.0-or-later
|
||||
|
||||
from typing import List, Union
|
||||
try:
|
||||
import emoji
|
||||
|
||||
|
@ -8,9 +9,9 @@ except ImportError:
|
|||
has_emoji = False
|
||||
|
||||
|
||||
def tokenize(text, max_ngram_length=1):
|
||||
def tokenize(text: str, max_ngram_length: int=1) -> List[str]:
|
||||
"""Convert a string to a list of lemmas."""
|
||||
text = text.lower()
|
||||
converted_text: Union[str, List[str]] = text.lower()
|
||||
|
||||
if has_emoji:
|
||||
parts = []
|
||||
|
@ -20,11 +21,11 @@ def tokenize(text, max_ngram_length=1):
|
|||
parts.append(emoji_part["emoji"])
|
||||
highest_end = emoji_part["match_end"]
|
||||
parts += list(text[highest_end:])
|
||||
text = [part for part in parts if part]
|
||||
converted_text = [part for part in parts if part]
|
||||
|
||||
tokens = [""]
|
||||
|
||||
for char in text:
|
||||
for char in converted_text:
|
||||
if char.isalpha() or char == "'":
|
||||
tokens[-1] += char
|
||||
elif has_emoji and emoji.is_emoji(char):
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# SPDX-License-Identifier: LGPL-3.0-or-later
|
||||
|
||||
import math
|
||||
from typing import Sequence, Union, Tuple, List
|
||||
|
||||
|
||||
def _mean(numbers):
|
||||
def _mean(numbers: Sequence[float]) -> float:
|
||||
"""Calculate the mean of a group of numbers
|
||||
|
||||
Parameters
|
||||
|
@ -19,7 +20,7 @@ def _mean(numbers):
|
|||
return sum(numbers) / len(numbers)
|
||||
|
||||
|
||||
def _standard_deviation(numbers):
|
||||
def _standard_deviation(numbers: Sequence[float]) -> float:
|
||||
"""Calculate the standard deviation of a group of numbers
|
||||
|
||||
Parameters
|
||||
|
@ -38,8 +39,8 @@ def _standard_deviation(numbers):
|
|||
return math.sqrt(_mean(squared_deviations))
|
||||
|
||||
|
||||
def weight(numbers):
|
||||
def weight(numbers: Sequence[float]) -> List[float]:
|
||||
standard_deviation = _standard_deviation(numbers)
|
||||
weight = standard_deviation * 2
|
||||
weighted_numbers = [i * weight for i in numbers]
|
||||
return weight, weighted_numbers
|
||||
return weighted_numbers
|
||||
|
|
Loading…
Reference in New Issue
Block a user