Compare commits
No commits in common. "d8f3d2e7013d200201e4aa6759c7bf7045562735" and "f38f4ca8013baeb56ff05acea4eeddcc2664ece7" have entirely different histories.
d8f3d2e701...f38f4ca801
gptc/model.py
@@ -5,7 +5,6 @@ from gptc.exceptions import InvalidModelError
 import gptc.weighting
 from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
 import json
-import collections
 
 
 class Model:
@@ -21,9 +20,7 @@ class Model:
         self.max_ngram_length = max_ngram_length
         self.hash_algorithm = hash_algorithm
 
-    def confidence(
-        self, text: str, max_ngram_length: int, transparent: bool = False
-    ) -> Dict[str, float]:
+    def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
         """Classify text with confidence.
 
         Parameters
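On the d8f3d2e701 side (- lines), confidence() takes an optional transparent flag; the f38f4ca801 side (+ lines) keeps the plain two-argument form. A minimal call sketch, assuming a trained Model instance named model (the instance and text are illustrative, not from this diff):

    # d8f3d2e701 side: opt into the explanation-carrying result.
    result = model.confidence("example text", max_ngram_length=5, transparent=True)

    # f38f4ca801 side: always a plain mapping of category names to floats.
    result = model.confidence("example text", max_ngram_length=5)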
@@ -43,36 +40,19 @@ class Model:
         """
 
         model = self.weights
-        max_ngram_length = min(self.max_ngram_length, max_ngram_length)
-
-        raw_tokens = gptc.tokenizer.tokenize(
-            text, min(max_ngram_length, self.max_ngram_length)
-        )
 
         tokens = gptc.tokenizer.hash(
-            raw_tokens,
+            gptc.tokenizer.tokenize(
+                text, min(max_ngram_length, self.max_ngram_length)
+            ),
             self.hash_algorithm,
         )
-
-        if transparent:
-            token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
-            log = []
-
         numbered_probs: Dict[int, float] = {}
-
         for word in tokens:
             try:
-                unweighted_numbers = [
-                    i / 65535 for i in cast(List[float], model[word])
-                ]
-
-                weight, weighted_numbers = gptc.weighting.weight(
-                    unweighted_numbers
-                )
-
-                if transparent:
-                    log.append([token_map[word], weight, unweighted_numbers])
-
+                weighted_numbers = gptc.weighting.weight(
+                    [i / 65535 for i in cast(List[float], model[word])]
+                )
                 for category, value in enumerate(weighted_numbers):
                     try:
                         numbered_probs[category] += value
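Both sides normalize the stored per-category numbers the same way: each model[word] entry is a list of 16-bit integers, and dividing by 65535 maps them onto [0.0, 1.0] before weighting. A small illustration with made-up values:

    stored = [0, 32768, 65535]               # hypothetical model[word] entry
    normalized = [i / 65535 for i in stored]
    print(normalized)                        # [0.0, 0.50000763..., 1.0]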
@@ -80,30 +60,12 @@ class Model:
                         numbered_probs[category] = value
                 except KeyError:
                     pass
-
         total = sum(numbered_probs.values())
         probs: Dict[str, float] = {
             self.names[category]: value / total
             for category, value in numbered_probs.items()
         }
-
-        if transparent:
-            explanation = {}
-            for word, weight, word_probs in log:
-                if word in explanation:
-                    explanation[word]["count"] += 1
-                else:
-                    explanation[word] = {
-                        "weight": weight,
-                        "probabilities": word_probs,
-                        "count": 1,
-                    }
-
-            return TransparentConfidences(
-                probs, explanation, self, text, max_ngram_length
-            )
-        else:
-            return Confidences(probs, self, text, max_ngram_length)
+        return probs
 
     def get(self, token: str) -> Dict[str, float]:
         try:
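The f38f4ca801 side returns the bare probability dict, while the d8f3d2e701 side wraps it in the Confidences/TransparentConfidences classes removed further down. Because Confidences subclasses collections.UserDict, plain dict-style consumers behave the same on both sides; a sketch under the same assumed names as above:

    probs = model.confidence("example text", max_ngram_length=5)
    best = max(probs, key=probs.get)  # works on a plain dict and on Confidences alike
    print(best, probs[best])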
@@ -120,7 +82,7 @@ class Model:
         }
 
     def serialize(self, file: BinaryIO) -> None:
-        file.write(b"GPTC model v6\n")
+        file.write(b"GPTC model v5\n")
         file.write(
             json.dumps(
                 {
@@ -138,23 +100,9 @@ class Model:
         )
 
 
-class Confidences(collections.UserDict):
-    def __init__(self, probs, model, text, max_ngram_length):
-        collections.UserDict.__init__(self, probs)
-        self.model = model
-        self.text = text
-        self.max_ngram_length = max_ngram_length
-
-
-class TransparentConfidences(Confidences):
-    def __init__(self, probs, explanation, model, text, max_ngram_length):
-        Confidences.__init__(self, probs, model, text, max_ngram_length)
-        self.explanation = explanation
-
-
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)
-    if prefix != b"GPTC model v6\n":
+    if prefix != b"GPTC model v5\n":
         raise InvalidModelError()
 
     config_json = b""
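Deleting Confidences and TransparentConfidences is also what lets the first hunk drop import collections. Since the version string changes in both serialize() and deserialize(), each side rejects models written by the other. A round-trip sketch, assuming this module is importable as gptc.model and model is a trained instance (both names are assumptions):

    import io

    import gptc.model

    buf = io.BytesIO()
    model.serialize(buf)  # the f38f4ca801 side writes the b"GPTC model v5\n" prefix
    buf.seek(0)
    restored = gptc.model.deserialize(buf)  # raises InvalidModelError on a v6 file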
gptc/tokenizer.py
@@ -7,7 +7,7 @@ import unicodedata
 
 
 def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
-    text = unicodedata.normalize("NFKD", text).casefold()
+    text = unicodedata.normalize("NFKD", text).lower()
     parts = []
     highest_end = 0
     for emoji_part in emoji.emoji_list(text):
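casefold() is a more aggressive, Unicode-aware lowering than lower(), so the two sides tokenize some non-ASCII text differently (NFKD normalization leaves this example untouched). The standard demonstration is German ß:

    >>> "Straße".casefold()  # d8f3d2e701 side: ß folds to "ss"
    'strasse'
    >>> "Straße".lower()     # f38f4ca801 side: ß is already lowercase, unchanged
    'straße'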
gptc/weighting.py
@@ -43,4 +43,4 @@ def weight(numbers: Sequence[float]) -> List[float]:
     standard_deviation = _standard_deviation(numbers)
     weight = standard_deviation * 2
     weighted_numbers = [i * weight for i in numbers]
-    return weight, weighted_numbers
+    return weighted_numbers
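On the d8f3d2e701 side, weight() returned the scalar weight alongside the weighted list (feeding the transparency log); the f38f4ca801 side returns only the list, which is also what the unchanged -> List[float] annotation promises. Callers differ only in unpacking; a sketch with made-up inputs:

    # d8f3d2e701 side:
    weight, weighted = gptc.weighting.weight([0.25, 0.75])

    # f38f4ca801 side:
    weighted = gptc.weighting.weight([0.25, 0.75])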