Performance improvements

This commit is contained in:
Samuel Sloniker 2022-12-22 18:01:37 -08:00
parent a76c6d3da8
commit 7e7b5f3e9c
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
4 changed files with 38 additions and 43 deletions

View File

@ -43,9 +43,9 @@ example of the format. Any exceptions will be printed to stderr.
## Library
### `Model.serialize()`
### `Model.serialize(file)`
Returns a `bytes` representing the model.
Write binary data representing the model to `file`.
### `gptc.deserialize(encoded_model)`

View File

@ -66,9 +66,7 @@ def main() -> None:
with open(args.model, "r") as f:
model = json.load(f)
sys.stdout.buffer.write(
gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
)
gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
elif args.subparser_name == "classify":
with open(args.model, "rb") as f:
model = gptc.deserialize(f.read())

View File

@ -27,23 +27,23 @@ def compile(
"""
categories: Dict[str, List[int]] = {}
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
for portion in raw_model:
text = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
)
category = portion["category"]
try:
categories[category] += text
except KeyError:
categories[category] = text
word_counts: Dict[int, Dict[str, int]] = {}
if not category in names:
names.append(category)
names = list(categories.keys())
category_lengths[category] = category_lengths.get(category, 0) + len(
text
)
for category, text in categories.items():
for word in text:
if word in word_counts:
try:
@ -53,27 +53,22 @@ def compile(
else:
word_counts[word] = {category: 1}
category_lengths = {
category: len(text) for category, text in categories.items()
}
word_weights: Dict[int, Dict[str, float]] = {
word: {
category: value / category_lengths[category]
for category, value in values.items()
}
for word, values in word_counts.items()
if sum(values.values()) >= min_count
}
print("counted")
model: Dict[int, List[int]] = {}
for word, weights in word_weights.items():
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
}
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
print("weighted")
return gptc.model.Model(model, names, max_ngram_length)

View File

@ -40,7 +40,9 @@ class Model:
model = self.weights
tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
)
)
numbered_probs: Dict[int, float] = {}
for word in tokens:
@ -74,9 +76,9 @@ class Model:
for index, category in enumerate(self.names)
}
def serialize(self) -> bytes:
out = b"GPTC model v4\n"
out += (
def serialize(self, file):
file.write(b"GPTC model v4\n")
file.write(
json.dumps(
{
"names": self.names,
@ -96,10 +98,10 @@ class Model:
+ b"\n"
)
for word, weights in self.weights.items():
out += word.to_bytes(6, "big") + b"".join(
[weight.to_bytes(2, "big") for weight in weights]
file.write(
word.to_bytes(6, "big")
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
)
return out
def deserialize(encoded_model: bytes) -> Model:
@ -122,9 +124,9 @@ def deserialize(encoded_model: bytes) -> Model:
except KeyError:
raise InvalidModelError()
if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
[isinstance(name, str) for name in names]
):
if not (
isinstance(names, list) and isinstance(max_ngram_length, int)
) or not all([isinstance(name, str) for name in names]):
raise InvalidModelError()
weight_code_length = 6 + 2 * len(names)