Compare commits
14 Commits
Author | SHA1 | Date | |
---|---|---|---|
71e9249ff4 | |||
97c4eef086 | |||
457b569741 | |||
4546c4cffa | |||
7b7ef39d0b | |||
a252a15e9d | |||
9513025e60 | |||
2c3fc77ba6 | |||
d8f3d2e701 | |||
7f68dc6fc6 | |||
99ad07a876 | |||
f38f4ca801 | |||
56550ca457 | |||
75fdb5ba3c |
|
@ -48,7 +48,7 @@ example of the format. Any exceptions will be printed to stderr.
|
||||||
|
|
||||||
Write binary data representing the model to `file`.
|
Write binary data representing the model to `file`.
|
||||||
|
|
||||||
### `gptc.deserialize(encoded_model)`
|
### `Model.deserialize(encoded_model)`
|
||||||
|
|
||||||
Deserialize a `Model` from a file containing data from `Model.serialize()`.
|
Deserialize a `Model` from a file containing data from `Model.serialize()`.
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ Return a confidence dict for the given token or ngram. This function is very
|
||||||
similar to `Model.confidence()`, except it treats the input as a single token
|
similar to `Model.confidence()`, except it treats the input as a single token
|
||||||
or ngram.
|
or ngram.
|
||||||
|
|
||||||
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")`
|
### `Model.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")`
|
||||||
|
|
||||||
Compile a raw model (as a list, not JSON) and return the compiled model (as a
|
Compile a raw model (as a list, not JSON) and return the compiled model (as a
|
||||||
`gptc.Model` object).
|
`gptc.Model` object).
|
||||||
|
@ -115,7 +115,7 @@ See `models/unpacked/` for an example of the format.
|
||||||
### `gptc.Classifier(model, max_ngram_length=1)`
|
### `gptc.Classifier(model, max_ngram_length=1)`
|
||||||
|
|
||||||
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
|
`Classifier` objects are deprecated starting with GPTC 3.1.0, and will be
|
||||||
removed in 4.0.0. See [the README from
|
removed in 5.0.0. See [the README from
|
||||||
3.0.2](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.1/README.md) if you need
|
3.0.2](https://git.kj7rrv.com/kj7rrv/gptc/src/tag/v3.0.1/README.md) if you need
|
||||||
documentation.
|
documentation.
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ print(
|
||||||
round(
|
round(
|
||||||
1000000
|
1000000
|
||||||
* timeit.timeit(
|
* timeit.timeit(
|
||||||
"gptc.compile(raw_model, max_ngram_length)",
|
"gptc.Model.compile(raw_model, max_ngram_length)",
|
||||||
number=compile_iterations,
|
number=compile_iterations,
|
||||||
globals=globals(),
|
globals=globals(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,13 +2,11 @@
|
||||||
|
|
||||||
"""General-Purpose Text Classifier"""
|
"""General-Purpose Text Classifier"""
|
||||||
|
|
||||||
from gptc.compiler import compile as compile
|
from gptc.pack import pack
|
||||||
from gptc.classifier import Classifier as Classifier
|
from gptc.model import Model
|
||||||
from gptc.pack import pack as pack
|
from gptc.tokenizer import normalize
|
||||||
from gptc.model import Model as Model, deserialize as deserialize
|
|
||||||
from gptc.tokenizer import normalize as normalize
|
|
||||||
from gptc.exceptions import (
|
from gptc.exceptions import (
|
||||||
GPTCError as GPTCError,
|
GPTCError,
|
||||||
ModelError as ModelError,
|
ModelError,
|
||||||
InvalidModelError as InvalidModelError,
|
InvalidModelError,
|
||||||
)
|
)
|
||||||
|
|
|
@ -44,19 +44,6 @@ def main() -> None:
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=1,
|
||||||
)
|
)
|
||||||
group = classify_parser.add_mutually_exclusive_group()
|
|
||||||
group.add_argument(
|
|
||||||
"-j",
|
|
||||||
"--json",
|
|
||||||
help="output confidence dict as JSON (default)",
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"-c",
|
|
||||||
"--category",
|
|
||||||
help="output most likely category or `None`",
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
|
|
||||||
check_parser = subparsers.add_parser(
|
check_parser = subparsers.add_parser(
|
||||||
"check", help="check one word or ngram in model"
|
"check", help="check one word or ngram in model"
|
||||||
|
@ -72,31 +59,26 @@ def main() -> None:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.subparser_name == "compile":
|
if args.subparser_name == "compile":
|
||||||
with open(args.model, "r") as f:
|
with open(args.model, "r", encoding="utf-8") as input_file:
|
||||||
model = json.load(f)
|
model = json.load(input_file)
|
||||||
|
|
||||||
with open(args.out, "wb+") as f:
|
with open(args.out, "wb+") as output_file:
|
||||||
gptc.compile(
|
gptc.Model.compile(
|
||||||
model, args.max_ngram_length, args.min_count
|
model, args.max_ngram_length, args.min_count
|
||||||
).serialize(f)
|
).serialize(output_file)
|
||||||
elif args.subparser_name == "classify":
|
elif args.subparser_name == "classify":
|
||||||
with open(args.model, "rb") as f:
|
with open(args.model, "rb") as model_file:
|
||||||
model = gptc.deserialize(f)
|
model = gptc.Model.deserialize(model_file)
|
||||||
|
|
||||||
if sys.stdin.isatty():
|
if sys.stdin.isatty():
|
||||||
text = input("Text to analyse: ")
|
text = input("Text to analyse: ")
|
||||||
else:
|
else:
|
||||||
text = sys.stdin.read()
|
text = sys.stdin.read()
|
||||||
|
|
||||||
if args.category:
|
print(json.dumps(model.confidence(text, args.max_ngram_length)))
|
||||||
classifier = gptc.Classifier(model, args.max_ngram_length)
|
|
||||||
print(classifier.classify(text))
|
|
||||||
else:
|
|
||||||
probabilities = model.confidence(text, args.max_ngram_length)
|
|
||||||
print(json.dumps(probabilities))
|
|
||||||
elif args.subparser_name == "check":
|
elif args.subparser_name == "check":
|
||||||
with open(args.model, "rb") as f:
|
with open(args.model, "rb") as model_file:
|
||||||
model = gptc.deserialize(f)
|
model = gptc.Model.deserialize(model_file)
|
||||||
print(json.dumps(model.get(args.token)))
|
print(json.dumps(model.get(args.token)))
|
||||||
else:
|
else:
|
||||||
print(json.dumps(gptc.pack(args.model, True)[0]))
|
print(json.dumps(gptc.pack(args.model, True)[0]))
|
||||||
|
|
|
@ -1,68 +0,0 @@
|
||||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
||||||
|
|
||||||
import gptc.model
|
|
||||||
from typing import Dict, Union
|
|
||||||
|
|
||||||
|
|
||||||
class Classifier:
|
|
||||||
"""A text classifier.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model : dict
|
|
||||||
A compiled GPTC model.
|
|
||||||
|
|
||||||
max_ngram_length : int
|
|
||||||
The maximum ngram length to use when tokenizing input. If this is
|
|
||||||
greater than the value used when the model was compiled, it will be
|
|
||||||
silently lowered to that value.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
model : dict
|
|
||||||
The model used.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model: gptc.model.Model, max_ngram_length: int = 1):
|
|
||||||
self.model = model
|
|
||||||
model_ngrams = model.max_ngram_length
|
|
||||||
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
|
||||||
|
|
||||||
def confidence(self, text: str) -> Dict[str, float]:
|
|
||||||
"""Classify text with confidence.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
text : str
|
|
||||||
The text to classify
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
{category:probability, category:probability...} or {} if no words
|
|
||||||
matching any categories in the model were found
|
|
||||||
|
|
||||||
"""
|
|
||||||
return self.model.confidence(text, self.max_ngram_length)
|
|
||||||
|
|
||||||
def classify(self, text: str) -> Union[str, None]:
|
|
||||||
"""Classify text.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
text : str
|
|
||||||
The text to classify
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str or None
|
|
||||||
The most likely category, or None if no words matching any
|
|
||||||
category in the model were found.
|
|
||||||
|
|
||||||
"""
|
|
||||||
probs: Dict[str, float] = self.confidence(text)
|
|
||||||
try:
|
|
||||||
return sorted(probs.items(), key=lambda x: x[1])[-1][0]
|
|
||||||
except IndexError:
|
|
||||||
return None
|
|
|
@ -1,73 +0,0 @@
|
||||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
||||||
|
|
||||||
import gptc.tokenizer
|
|
||||||
import gptc.model
|
|
||||||
from typing import Iterable, Mapping, List, Dict, Union
|
|
||||||
|
|
||||||
|
|
||||||
def compile(
|
|
||||||
raw_model: Iterable[Mapping[str, str]],
|
|
||||||
max_ngram_length: int = 1,
|
|
||||||
min_count: int = 1,
|
|
||||||
hash_algorithm: str = "sha256",
|
|
||||||
) -> gptc.model.Model:
|
|
||||||
"""Compile a raw model.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
raw_model : list of dict
|
|
||||||
A raw GPTC model.
|
|
||||||
|
|
||||||
max_ngram_length : int
|
|
||||||
Maximum ngram lenght to compile with.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
A compiled GPTC model.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
word_counts: Dict[int, Dict[str, int]] = {}
|
|
||||||
category_lengths: Dict[str, int] = {}
|
|
||||||
names: List[str] = []
|
|
||||||
|
|
||||||
for portion in raw_model:
|
|
||||||
text = gptc.tokenizer.hash(
|
|
||||||
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
|
||||||
hash_algorithm,
|
|
||||||
)
|
|
||||||
category = portion["category"]
|
|
||||||
|
|
||||||
if not category in names:
|
|
||||||
names.append(category)
|
|
||||||
|
|
||||||
category_lengths[category] = category_lengths.get(category, 0) + len(
|
|
||||||
text
|
|
||||||
)
|
|
||||||
|
|
||||||
for word in text:
|
|
||||||
if word in word_counts:
|
|
||||||
try:
|
|
||||||
word_counts[word][category] += 1
|
|
||||||
except KeyError:
|
|
||||||
word_counts[word][category] = 1
|
|
||||||
else:
|
|
||||||
word_counts[word] = {category: 1}
|
|
||||||
|
|
||||||
model: Dict[int, List[int]] = {}
|
|
||||||
for word, counts in word_counts.items():
|
|
||||||
if sum(counts.values()) >= min_count:
|
|
||||||
weights = {
|
|
||||||
category: value / category_lengths[category]
|
|
||||||
for category, value in counts.items()
|
|
||||||
}
|
|
||||||
total = sum(weights.values())
|
|
||||||
new_weights: List[int] = []
|
|
||||||
for category in names:
|
|
||||||
new_weights.append(
|
|
||||||
round((weights.get(category, 0) / total) * 65535)
|
|
||||||
)
|
|
||||||
model[word] = new_weights
|
|
||||||
|
|
||||||
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
|
|
273
gptc/model.py
273
gptc/model.py
|
@ -1,10 +1,120 @@
|
||||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Iterable,
|
||||||
|
Mapping,
|
||||||
|
List,
|
||||||
|
Dict,
|
||||||
|
cast,
|
||||||
|
BinaryIO,
|
||||||
|
Tuple,
|
||||||
|
TypedDict,
|
||||||
|
)
|
||||||
|
import json
|
||||||
import gptc.tokenizer
|
import gptc.tokenizer
|
||||||
from gptc.exceptions import InvalidModelError
|
from gptc.exceptions import InvalidModelError
|
||||||
import gptc.weighting
|
import gptc.weighting
|
||||||
from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
|
|
||||||
import json
|
def _count_words(
|
||||||
|
raw_model: Iterable[Mapping[str, str]],
|
||||||
|
max_ngram_length: int,
|
||||||
|
hash_algorithm: str,
|
||||||
|
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
|
||||||
|
word_counts: Dict[int, Dict[str, int]] = {}
|
||||||
|
category_lengths: Dict[str, int] = {}
|
||||||
|
names: List[str] = []
|
||||||
|
|
||||||
|
for portion in raw_model:
|
||||||
|
text = gptc.tokenizer.hash_list(
|
||||||
|
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
||||||
|
hash_algorithm,
|
||||||
|
)
|
||||||
|
category = portion["category"]
|
||||||
|
|
||||||
|
if not category in names:
|
||||||
|
names.append(category)
|
||||||
|
|
||||||
|
category_lengths[category] = category_lengths.get(category, 0) + len(
|
||||||
|
text
|
||||||
|
)
|
||||||
|
|
||||||
|
for word in text:
|
||||||
|
if word in word_counts:
|
||||||
|
try:
|
||||||
|
word_counts[word][category] += 1
|
||||||
|
except KeyError:
|
||||||
|
word_counts[word][category] = 1
|
||||||
|
else:
|
||||||
|
word_counts[word] = {category: 1}
|
||||||
|
|
||||||
|
return word_counts, category_lengths, names
|
||||||
|
|
||||||
|
|
||||||
|
def _get_weights(
|
||||||
|
min_count: int,
|
||||||
|
word_counts: Dict[int, Dict[str, int]],
|
||||||
|
category_lengths: Dict[str, int],
|
||||||
|
names: List[str],
|
||||||
|
) -> Dict[int, List[int]]:
|
||||||
|
model: Dict[int, List[int]] = {}
|
||||||
|
for word, counts in word_counts.items():
|
||||||
|
if sum(counts.values()) >= min_count:
|
||||||
|
weights = {
|
||||||
|
category: value / category_lengths[category]
|
||||||
|
for category, value in counts.items()
|
||||||
|
}
|
||||||
|
total = sum(weights.values())
|
||||||
|
new_weights: List[int] = []
|
||||||
|
for category in names:
|
||||||
|
new_weights.append(
|
||||||
|
round((weights.get(category, 0) / total) * 65535)
|
||||||
|
)
|
||||||
|
model[word] = new_weights
|
||||||
|
return model
|
||||||
|
|
||||||
|
class ExplanationEntry(TypedDict):
|
||||||
|
weight: float
|
||||||
|
probabilities: Dict[str, float]
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
|
Explanation = Dict[
|
||||||
|
str,
|
||||||
|
ExplanationEntry,
|
||||||
|
]
|
||||||
|
|
||||||
|
Log = List[Tuple[str, float, List[float]]]
|
||||||
|
|
||||||
|
|
||||||
|
class Confidences(dict[str, float]):
|
||||||
|
def __init__(self, probs: Dict[str, float]):
|
||||||
|
dict.__init__(self, probs)
|
||||||
|
|
||||||
|
|
||||||
|
class TransparentConfidences(Confidences):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
probs: Dict[str, float],
|
||||||
|
explanation: Explanation,
|
||||||
|
):
|
||||||
|
self.explanation = explanation
|
||||||
|
Confidences.__init__(self, probs)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_log(log: Log, names: List[str]) -> Explanation:
|
||||||
|
explanation: Explanation = {}
|
||||||
|
for word2, weight, word_probs in log:
|
||||||
|
if word2 in explanation:
|
||||||
|
explanation[word2]["count"] += 1
|
||||||
|
else:
|
||||||
|
explanation[word2] = {
|
||||||
|
"weight": weight,
|
||||||
|
"probabilities": {
|
||||||
|
name: word_probs[index] for index, name in enumerate(names)
|
||||||
|
},
|
||||||
|
"count": 1,
|
||||||
|
}
|
||||||
|
return explanation
|
||||||
|
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
|
@ -20,7 +130,9 @@ class Model:
|
||||||
self.max_ngram_length = max_ngram_length
|
self.max_ngram_length = max_ngram_length
|
||||||
self.hash_algorithm = hash_algorithm
|
self.hash_algorithm = hash_algorithm
|
||||||
|
|
||||||
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
|
def confidence(
|
||||||
|
self, text: str, max_ngram_length: int, transparent: bool = False
|
||||||
|
) -> Confidences:
|
||||||
"""Classify text with confidence.
|
"""Classify text with confidence.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -40,19 +152,42 @@ class Model:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model = self.weights
|
model = self.weights
|
||||||
|
max_ngram_length = min(self.max_ngram_length, max_ngram_length)
|
||||||
|
|
||||||
tokens = gptc.tokenizer.hash(
|
raw_tokens = gptc.tokenizer.tokenize(
|
||||||
gptc.tokenizer.tokenize(
|
text, min(max_ngram_length, self.max_ngram_length)
|
||||||
text, min(max_ngram_length, self.max_ngram_length)
|
)
|
||||||
),
|
|
||||||
|
tokens = gptc.tokenizer.hash_list(
|
||||||
|
raw_tokens,
|
||||||
self.hash_algorithm,
|
self.hash_algorithm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if transparent:
|
||||||
|
token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
|
||||||
|
log: Log = []
|
||||||
|
|
||||||
numbered_probs: Dict[int, float] = {}
|
numbered_probs: Dict[int, float] = {}
|
||||||
|
|
||||||
for word in tokens:
|
for word in tokens:
|
||||||
try:
|
try:
|
||||||
weighted_numbers = gptc.weighting.weight(
|
unweighted_numbers = [
|
||||||
[i / 65535 for i in cast(List[float], model[word])]
|
i / 65535 for i in cast(List[float], model[word])
|
||||||
|
]
|
||||||
|
|
||||||
|
weight, weighted_numbers = gptc.weighting.weight(
|
||||||
|
unweighted_numbers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if transparent:
|
||||||
|
log.append(
|
||||||
|
(
|
||||||
|
token_map[word],
|
||||||
|
weight,
|
||||||
|
unweighted_numbers,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
for category, value in enumerate(weighted_numbers):
|
for category, value in enumerate(weighted_numbers):
|
||||||
try:
|
try:
|
||||||
numbered_probs[category] += value
|
numbered_probs[category] += value
|
||||||
|
@ -60,12 +195,18 @@ class Model:
|
||||||
numbered_probs[category] = value
|
numbered_probs[category] = value
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
total = sum(numbered_probs.values())
|
total = sum(numbered_probs.values())
|
||||||
probs: Dict[str, float] = {
|
probs: Dict[str, float] = {
|
||||||
self.names[category]: value / total
|
self.names[category]: value / total
|
||||||
for category, value in numbered_probs.items()
|
for category, value in numbered_probs.items()
|
||||||
}
|
}
|
||||||
return probs
|
|
||||||
|
if transparent:
|
||||||
|
explanation = convert_log(log, self.names)
|
||||||
|
return TransparentConfidences(probs, explanation)
|
||||||
|
|
||||||
|
return Confidences(probs)
|
||||||
|
|
||||||
def get(self, token: str) -> Dict[str, float]:
|
def get(self, token: str) -> Dict[str, float]:
|
||||||
try:
|
try:
|
||||||
|
@ -82,7 +223,7 @@ class Model:
|
||||||
}
|
}
|
||||||
|
|
||||||
def serialize(self, file: BinaryIO) -> None:
|
def serialize(self, file: BinaryIO) -> None:
|
||||||
file.write(b"GPTC model v5\n")
|
file.write(b"GPTC model v6\n")
|
||||||
file.write(
|
file.write(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
{
|
{
|
||||||
|
@ -99,53 +240,83 @@ class Model:
|
||||||
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
|
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compile(
|
||||||
|
raw_model: Iterable[Mapping[str, str]],
|
||||||
|
max_ngram_length: int = 1,
|
||||||
|
min_count: int = 1,
|
||||||
|
hash_algorithm: str = "sha256",
|
||||||
|
) -> 'Model':
|
||||||
|
"""Compile a raw model.
|
||||||
|
|
||||||
def deserialize(encoded_model: BinaryIO) -> Model:
|
Parameters
|
||||||
prefix = encoded_model.read(14)
|
----------
|
||||||
if prefix != b"GPTC model v5\n":
|
raw_model : list of dict
|
||||||
raise InvalidModelError()
|
A raw GPTC model.
|
||||||
|
|
||||||
config_json = b""
|
max_ngram_length : int
|
||||||
while True:
|
Maximum ngram lenght to compile with.
|
||||||
byte = encoded_model.read(1)
|
|
||||||
if byte == b"\n":
|
Returns
|
||||||
break
|
-------
|
||||||
elif byte == b"":
|
dict
|
||||||
|
A compiled GPTC model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
word_counts, category_lengths, names = _count_words(
|
||||||
|
raw_model, max_ngram_length, hash_algorithm
|
||||||
|
)
|
||||||
|
model = _get_weights(min_count, word_counts, category_lengths, names)
|
||||||
|
return Model(model, names, max_ngram_length, hash_algorithm)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(encoded_model: BinaryIO) -> "Model":
|
||||||
|
prefix = encoded_model.read(14)
|
||||||
|
if prefix != b"GPTC model v6\n":
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
else:
|
|
||||||
|
config_json = b""
|
||||||
|
while True:
|
||||||
|
byte = encoded_model.read(1)
|
||||||
|
if byte == b"\n":
|
||||||
|
break
|
||||||
|
|
||||||
|
if byte == b"":
|
||||||
|
raise InvalidModelError()
|
||||||
|
|
||||||
config_json += byte
|
config_json += byte
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = json.loads(config_json.decode("utf-8"))
|
config = json.loads(config_json.decode("utf-8"))
|
||||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
||||||
raise InvalidModelError()
|
raise InvalidModelError() from exc
|
||||||
|
|
||||||
try:
|
try:
|
||||||
names = config["names"]
|
names = config["names"]
|
||||||
max_ngram_length = config["max_ngram_length"]
|
max_ngram_length = config["max_ngram_length"]
|
||||||
hash_algorithm = config["hash_algorithm"]
|
hash_algorithm = config["hash_algorithm"]
|
||||||
except KeyError:
|
except KeyError as exc:
|
||||||
raise InvalidModelError()
|
raise InvalidModelError() from exc
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
isinstance(names, list) and isinstance(max_ngram_length, int)
|
isinstance(names, list) and isinstance(max_ngram_length, int)
|
||||||
) or not all([isinstance(name, str) for name in names]):
|
) or not all(isinstance(name, str) for name in names):
|
||||||
raise InvalidModelError()
|
|
||||||
|
|
||||||
weight_code_length = 6 + 2 * len(names)
|
|
||||||
|
|
||||||
weights: Dict[int, List[int]] = {}
|
|
||||||
|
|
||||||
while True:
|
|
||||||
code = encoded_model.read(weight_code_length)
|
|
||||||
if not code:
|
|
||||||
break
|
|
||||||
elif len(code) != weight_code_length:
|
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
weights[int.from_bytes(code[:6], "big")] = [
|
weight_code_length = 6 + 2 * len(names)
|
||||||
int.from_bytes(value, "big")
|
|
||||||
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
|
||||||
]
|
|
||||||
|
|
||||||
return Model(weights, names, max_ngram_length, hash_algorithm)
|
weights: Dict[int, List[int]] = {}
|
||||||
|
|
||||||
|
while True:
|
||||||
|
code = encoded_model.read(weight_code_length)
|
||||||
|
if not code:
|
||||||
|
break
|
||||||
|
if len(code) != weight_code_length:
|
||||||
|
raise InvalidModelError()
|
||||||
|
|
||||||
|
weights[int.from_bytes(code[:6], "big")] = [
|
||||||
|
int.from_bytes(value, "big")
|
||||||
|
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
|
||||||
|
]
|
||||||
|
|
||||||
|
return Model(weights, names, max_ngram_length, hash_algorithm)
|
||||||
|
|
20
gptc/pack.py
20
gptc/pack.py
|
@ -7,7 +7,7 @@ from typing import List, Dict, Tuple
|
||||||
|
|
||||||
def pack(
|
def pack(
|
||||||
directory: str, print_exceptions: bool = False
|
directory: str, print_exceptions: bool = False
|
||||||
) -> Tuple[List[Dict[str, str]], List[Tuple[Exception]]]:
|
) -> Tuple[List[Dict[str, str]], List[Tuple[OSError]]]:
|
||||||
paths = os.listdir(directory)
|
paths = os.listdir(directory)
|
||||||
texts: Dict[str, List[str]] = {}
|
texts: Dict[str, List[str]] = {}
|
||||||
exceptions = []
|
exceptions = []
|
||||||
|
@ -17,16 +17,18 @@ def pack(
|
||||||
try:
|
try:
|
||||||
for file in os.listdir(os.path.join(directory, path)):
|
for file in os.listdir(os.path.join(directory, path)):
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(directory, path, file)) as f:
|
with open(
|
||||||
texts[path].append(f.read())
|
os.path.join(directory, path, file), encoding="utf-8"
|
||||||
except Exception as e:
|
) as input_file:
|
||||||
exceptions.append((e,))
|
texts[path].append(input_file.read())
|
||||||
|
except OSError as error:
|
||||||
|
exceptions.append((error,))
|
||||||
if print_exceptions:
|
if print_exceptions:
|
||||||
print(e, file=sys.stderr)
|
print(error, file=sys.stderr)
|
||||||
except Exception as e:
|
except OSError as error:
|
||||||
exceptions.append((e,))
|
exceptions.append((error,))
|
||||||
if print_exceptions:
|
if print_exceptions:
|
||||||
print(e, file=sys.stderr)
|
print(error, file=sys.stderr)
|
||||||
|
|
||||||
raw_model = []
|
raw_model = []
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
from typing import List, Union, Callable, Any, cast
|
import unicodedata
|
||||||
|
from typing import List, cast
|
||||||
import hashlib
|
import hashlib
|
||||||
import emoji
|
import emoji
|
||||||
import unicodedata
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
||||||
text = unicodedata.normalize("NFKD", text).lower()
|
text = unicodedata.normalize("NFKD", text).casefold()
|
||||||
parts = []
|
parts = []
|
||||||
highest_end = 0
|
highest_end = 0
|
||||||
for emoji_part in emoji.emoji_list(text):
|
for emoji_part in emoji.emoji_list(text):
|
||||||
|
@ -37,12 +37,12 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
||||||
|
|
||||||
if max_ngram_length == 1:
|
if max_ngram_length == 1:
|
||||||
return tokens
|
return tokens
|
||||||
else:
|
|
||||||
ngrams = []
|
ngrams = []
|
||||||
for ngram_length in range(1, max_ngram_length + 1):
|
for ngram_length in range(1, max_ngram_length + 1):
|
||||||
for index in range(len(tokens) + 1 - ngram_length):
|
for index in range(len(tokens) + 1 - ngram_length):
|
||||||
ngrams.append(" ".join(tokens[index : index + ngram_length]))
|
ngrams.append(" ".join(tokens[index : index + ngram_length]))
|
||||||
return ngrams
|
return ngrams
|
||||||
|
|
||||||
|
|
||||||
def _hash_single(token: str, hash_function: type) -> int:
|
def _hash_single(token: str, hash_function: type) -> int:
|
||||||
|
@ -69,15 +69,15 @@ def _get_hash_function(hash_algorithm: str) -> type:
|
||||||
"sha3_384",
|
"sha3_384",
|
||||||
}:
|
}:
|
||||||
return cast(type, getattr(hashlib, hash_algorithm))
|
return cast(type, getattr(hashlib, hash_algorithm))
|
||||||
else:
|
|
||||||
raise ValueError("not a valid hash function: " + hash_algorithm)
|
raise ValueError("not a valid hash function: " + hash_algorithm)
|
||||||
|
|
||||||
|
|
||||||
def hash_single(token: str, hash_algorithm: str) -> int:
|
def hash_single(token: str, hash_algorithm: str) -> int:
|
||||||
return _hash_single(token, _get_hash_function(hash_algorithm))
|
return _hash_single(token, _get_hash_function(hash_algorithm))
|
||||||
|
|
||||||
|
|
||||||
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
|
def hash_list(tokens: List[str], hash_algorithm: str) -> List[int]:
|
||||||
hash_function = _get_hash_function(hash_algorithm)
|
hash_function = _get_hash_function(hash_algorithm)
|
||||||
return [_hash_single(token, hash_function) for token in tokens]
|
return [_hash_single(token, hash_function) for token in tokens]
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Sequence, Union, Tuple, List
|
from typing import Sequence, Tuple, List
|
||||||
|
|
||||||
|
|
||||||
def _mean(numbers: Sequence[float]) -> float:
|
def _mean(numbers: Sequence[float]) -> float:
|
||||||
|
@ -39,8 +39,8 @@ def _standard_deviation(numbers: Sequence[float]) -> float:
|
||||||
return math.sqrt(_mean(squared_deviations))
|
return math.sqrt(_mean(squared_deviations))
|
||||||
|
|
||||||
|
|
||||||
def weight(numbers: Sequence[float]) -> List[float]:
|
def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]:
|
||||||
standard_deviation = _standard_deviation(numbers)
|
standard_deviation = _standard_deviation(numbers)
|
||||||
weight = standard_deviation * 2
|
weight_assigned = standard_deviation * 2
|
||||||
weighted_numbers = [i * weight for i in numbers]
|
weighted_numbers = [i * weight_assigned for i in numbers]
|
||||||
return weighted_numbers
|
return weight_assigned, weighted_numbers
|
||||||
|
|
Binary file not shown.
16
profiler.py
Normal file
16
profiler.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
import cProfile
|
||||||
|
import gptc
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
max_ngram_length = 10
|
||||||
|
|
||||||
|
with open("models/raw.json") as f:
|
||||||
|
raw_model = json.load(f)
|
||||||
|
|
||||||
|
with open("models/benchmark_text.txt") as f:
|
||||||
|
text = f.read()
|
||||||
|
|
||||||
|
cProfile.run("gptc.Model.compile(raw_model, max_ngram_length)")
|
Loading…
Reference in New Issue
Block a user