Lightweight NLP library in pure Python - currently implements a text classifier
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

322 lines
9.0 KiB

# SPDX-License-Identifier: GPL-3.0-or-later
from typing import (
Iterable,
Mapping,
List,
Dict,
cast,
BinaryIO,
Tuple,
TypedDict,
)
import json
import gptc.tokenizer
from gptc.exceptions import InvalidModelError
import gptc.weighting
def _count_words(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int,
hash_algorithm: str,
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
for portion in raw_model:
text = gptc.tokenizer.hash_list(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
hash_algorithm,
)
category = portion["category"]
if not category in names:
names.append(category)
category_lengths[category] = category_lengths.get(category, 0) + len(
text
)
for word in text:
if word in word_counts:
try:
word_counts[word][category] += 1
except KeyError:
word_counts[word][category] = 1
else:
word_counts[word] = {category: 1}
return word_counts, category_lengths, names
def _get_weights(
min_count: int,
word_counts: Dict[int, Dict[str, int]],
category_lengths: Dict[str, int],
names: List[str],
) -> Dict[int, List[int]]:
model: Dict[int, List[int]] = {}
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
}
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
return model
class ExplanationEntry(TypedDict):
    """Details recorded for one token in an explanation."""

    # Weight reported for the token.
    weight: float
    # Category name -> probability for the token.
    probabilities: Dict[str, float]
    # How many times the token was seen.
    count: int
# Maps a token to the entry describing it.
Explanation = Dict[str, ExplanationEntry]

# One (token, weight, per-category probabilities) tuple per logged token.
Log = List[Tuple[str, float, List[float]]]
class Confidences(dict[str, float]):
    """A dict of category name -> probability."""

    def __init__(self, probs: Dict[str, float]):
        super().__init__(probs)
class TransparentConfidences(Confidences):
    """Confidences that also carry an ``explanation`` attribute."""

    def __init__(
        self,
        probs: Dict[str, float],
        explanation: Explanation,
    ):
        super().__init__(probs)
        self.explanation = explanation
def convert_log(log: Log, names: List[str]) -> Explanation:
    """Collapse a log into one explanation entry per distinct token.

    The first log record seen for a token supplies its "weight" and its
    "probabilities" (``names`` paired with the record's probability list
    by position); every further record for the same token only
    increments its "count".
    """
    explanation: Explanation = {}
    for token, weight, word_probs in log:
        entry = explanation.get(token)
        if entry is not None:
            entry["count"] += 1
            continue

        explanation[token] = {
            "weight": weight,
            "probabilities": {
                name: word_probs[position]
                for position, name in enumerate(names)
            },
            "count": 1,
        }
    return explanation
class Model:
    """A compiled GPTC text-classification model.

    Attributes
    ----------
    weights : dict
        Maps a token hash (int) to a list of integer weights, one per
        entry of ``names`` and in the same order, scaled to 0..65535.
    names : list of str
        The category names.
    max_ngram_length : int
        The maximum ngram length the model was built with.
    hash_algorithm : str
        Name of the hash algorithm passed to the tokenizer.
    """

    # Header line that starts every serialized model.
    _MAGIC = b"GPTC model v6\n"
    # Size in bytes of a token hash and of a single weight in a record.
    _HASH_BYTES = 6
    _WEIGHT_BYTES = 2
    # Stored weights are probabilities scaled to this integer range.
    _WEIGHT_SCALE = 65535

    def __init__(
        self,
        weights: Dict[int, List[int]],
        names: List[str],
        max_ngram_length: int,
        hash_algorithm: str,
    ):
        self.weights = weights
        self.names = names
        self.max_ngram_length = max_ngram_length
        self.hash_algorithm = hash_algorithm

    def confidence(
        self, text: str, max_ngram_length: int, transparent: bool = False
    ) -> "Confidences":
        """Classify text with confidence.

        Parameters
        ----------
        text : str
            The text to classify
        max_ngram_length : int
            The maximum ngram length to use in classifying; capped at the
            model's own max_ngram_length
        transparent : bool
            If true, return a TransparentConfidences whose ``explanation``
            attribute describes every token that was found in the model

        Returns
        -------
        dict
            {category:probability, category:probability...} or {} if no words
            matching any categories in the model were found (or if the
            weighted values of the matching words sum to zero)
        """
        max_ngram_length = min(self.max_ngram_length, max_ngram_length)
        raw_tokens = gptc.tokenizer.tokenize(text, max_ngram_length)
        tokens = gptc.tokenizer.hash_list(raw_tokens, self.hash_algorithm)

        log: Log = []
        numbered_probs: Dict[int, float] = {}
        for index, word in enumerate(tokens):
            # Only a token missing from the model is skipped; errors raised
            # by the weighting code are no longer silently swallowed.
            stored = self.weights.get(word)
            if stored is None:
                continue

            unweighted_numbers = [
                value / self._WEIGHT_SCALE for value in stored
            ]
            weight, weighted_numbers = gptc.weighting.weight(
                unweighted_numbers
            )
            if transparent:
                # raw_tokens and tokens are index-aligned, so this reports
                # the text of the token at this exact position.
                log.append((raw_tokens[index], weight, unweighted_numbers))
            for category, value in enumerate(weighted_numbers):
                numbered_probs[category] = (
                    numbered_probs.get(category, 0.0) + value
                )

        total = sum(numbered_probs.values())
        probs: Dict[str, float] = {}
        # Defensive: a zero total would otherwise raise ZeroDivisionError.
        if total:
            probs = {
                self.names[category]: value / total
                for category, value in numbered_probs.items()
            }

        if transparent:
            return TransparentConfidences(
                probs, convert_log(log, self.names)
            )
        return Confidences(probs)

    def get(self, token: str) -> Dict[str, float]:
        """Return the stored per-category values for a single token.

        The token is normalized and hashed with the model's hash
        algorithm before lookup.

        Returns
        -------
        dict
            {category: weight / 65535, ...} for every category in the
            model, or {} if the token is not in the model
        """
        hashed = gptc.tokenizer.hash_single(
            gptc.tokenizer.normalize(token), self.hash_algorithm
        )
        weights = self.weights.get(hashed)
        if weights is None:
            return {}
        return {
            category: weights[index] / self._WEIGHT_SCALE
            for index, category in enumerate(self.names)
        }

    def serialize(self, file: BinaryIO) -> None:
        """Write the model to ``file`` in the "GPTC model v6" format.

        The output is the header line, one line of UTF-8 JSON holding
        ``names``, ``max_ngram_length`` and ``hash_algorithm``, then one
        fixed-size record per token: the token hash as 6 big-endian bytes
        followed by each weight as 2 big-endian bytes.
        """
        file.write(self._MAGIC)
        config = {
            "names": self.names,
            "max_ngram_length": self.max_ngram_length,
            "hash_algorithm": self.hash_algorithm,
        }
        file.write(json.dumps(config).encode("utf-8") + b"\n")
        for word, weights in self.weights.items():
            file.write(
                word.to_bytes(self._HASH_BYTES, "big")
                + b"".join(
                    weight.to_bytes(self._WEIGHT_BYTES, "big")
                    for weight in weights
                )
            )

    @staticmethod
    def compile(
        raw_model: Iterable[Mapping[str, str]],
        max_ngram_length: int = 1,
        min_count: int = 1,
        hash_algorithm: str = "sha256",
    ) -> "Model":
        """Compile a raw model.

        Parameters
        ----------
        raw_model : list of dict
            A raw GPTC model; each item provides "text" and "category".
        max_ngram_length : int
            Maximum ngram length to compile with.
        min_count : int
            Tokens occurring fewer than this many times in total are
            left out of the compiled model.
        hash_algorithm : str
            Name of the hash algorithm used to hash tokens.

        Returns
        -------
        Model
            A compiled GPTC model.
        """
        word_counts, category_lengths, names = _count_words(
            raw_model, max_ngram_length, hash_algorithm
        )
        model = _get_weights(min_count, word_counts, category_lengths, names)
        return Model(model, names, max_ngram_length, hash_algorithm)

    @staticmethod
    def deserialize(encoded_model: BinaryIO) -> "Model":
        """Read a model previously written by ``serialize``.

        Raises
        ------
        InvalidModelError
            If the header is wrong, the configuration line is missing,
            is not valid UTF-8 JSON, is not a JSON object, lacks a
            required key or has values of the wrong type, or if the
            stream ends in the middle of a weight record.
        """
        if encoded_model.read(len(Model._MAGIC)) != Model._MAGIC:
            raise InvalidModelError()

        # The configuration occupies exactly one newline-terminated line;
        # hitting end-of-stream before the newline means it is truncated.
        config_line = encoded_model.readline()
        if not config_line.endswith(b"\n"):
            raise InvalidModelError()
        try:
            config = json.loads(config_line[:-1].decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise InvalidModelError() from exc

        # Valid JSON that is not an object (e.g. "[]") would otherwise make
        # the key lookups below raise TypeError.
        if not isinstance(config, dict):
            raise InvalidModelError()
        try:
            names = config["names"]
            max_ngram_length = config["max_ngram_length"]
            hash_algorithm = config["hash_algorithm"]
        except KeyError as exc:
            raise InvalidModelError() from exc

        if not (
            isinstance(names, list)
            and all(isinstance(name, str) for name in names)
            and isinstance(max_ngram_length, int)
            and isinstance(hash_algorithm, str)
        ):
            raise InvalidModelError()

        record_length = Model._HASH_BYTES + Model._WEIGHT_BYTES * len(names)
        weights: Dict[int, List[int]] = {}
        while record := encoded_model.read(record_length):
            if len(record) != record_length:
                raise InvalidModelError()
            token_hash = int.from_bytes(record[: Model._HASH_BYTES], "big")
            weights[token_hash] = [
                int.from_bytes(
                    record[offset : offset + Model._WEIGHT_BYTES], "big"
                )
                for offset in range(
                    Model._HASH_BYTES, record_length, Model._WEIGHT_BYTES
                )
            ]
        return Model(weights, names, max_ngram_length, hash_algorithm)