Performance improvements

Samuel Sloniker 2022-12-22 18:01:37 -08:00
parent a76c6d3da8
commit 7e7b5f3e9c
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
4 changed files with 38 additions and 43 deletions


@@ -43,9 +43,9 @@ example of the format. Any exceptions will be printed to stderr.
 ## Library
 
-### `Model.serialize()`
+### `Model.serialize(file)`
 
-Returns a `bytes` representing the model.
+Writes binary data representing the model to `file`.
 
 ### `gptc.deserialize(encoded_model)`
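The README change above swaps a return-value API for a streaming one. A minimal sketch of the new call pattern, assuming only what the diff shows (the file names `model.json` and `compiled.gptc` are illustrative, not from the commit):

    import json

    import gptc

    # Load a raw model and compile it (positional arguments as in the
    # CLI: raw model, max_ngram_length, min_count).
    with open("model.json") as f:  # hypothetical input path
        raw_model = json.load(f)

    compiled = gptc.compile(raw_model, 1, 1)

    # New API: serialize() no longer returns bytes; it writes to an
    # open binary file object.
    with open("compiled.gptc", "wb") as f:  # hypothetical output path
        compiled.serialize(f)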


@@ -66,9 +66,7 @@ def main() -> None:
         with open(args.model, "r") as f:
             model = json.load(f)
-        sys.stdout.buffer.write(
-            gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
-        )
+        gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
     elif args.subparser_name == "classify":
         with open(args.model, "rb") as f:
             model = gptc.deserialize(f.read())
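Passing `sys.stdout.buffer` into `serialize()` lets the compiled model be written record by record instead of being assembled into one large `bytes` object first. Code that still wants the serialized model in memory can recover the old behavior through a file-like buffer; a sketch under that assumption (any object with a `write(bytes)` method should do, and the sample portions are hypothetical):

    import io

    import gptc

    # Hypothetical raw model in the list-of-portions format compile()
    # iterates over ("text" and "category" keys, per the diff below).
    raw_model = [
        {"text": "good great excellent", "category": "positive"},
        {"text": "bad awful terrible", "category": "negative"},
    ]

    compiled = gptc.compile(raw_model, 1, 1)

    # io.BytesIO is file-like, so serialize() can target it; getvalue()
    # yields the same bytes the old serialize() returned.
    buf = io.BytesIO()
    compiled.serialize(buf)
    encoded = buf.getvalue()

    model = gptc.deserialize(encoded)  # round-trip still works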


@@ -27,23 +27,23 @@ def compile(
     """
-    categories: Dict[str, List[int]] = {}
+    word_counts: Dict[int, Dict[str, int]] = {}
+    category_lengths: Dict[str, int] = {}
+    names: List[str] = []
 
     for portion in raw_model:
         text = gptc.tokenizer.hash(
             gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
         )
         category = portion["category"]
-        try:
-            categories[category] += text
-        except KeyError:
-            categories[category] = text
-
-    word_counts: Dict[int, Dict[str, int]] = {}
-
-    names = list(categories.keys())
-
-    for category, text in categories.items():
+
+        if category not in names:
+            names.append(category)
+
+        category_lengths[category] = category_lengths.get(category, 0) + len(
+            text
+        )
+
         for word in text:
             if word in word_counts:
                 try:
@@ -53,27 +53,22 @@ def compile(
             else:
                 word_counts[word] = {category: 1}
 
-    category_lengths = {
-        category: len(text) for category, text in categories.items()
-    }
-
-    word_weights: Dict[int, Dict[str, float]] = {
-        word: {
-            category: value / category_lengths[category]
-            for category, value in values.items()
-        }
-        for word, values in word_counts.items()
-        if sum(values.values()) >= min_count
-    }
+    print("counted")
 
     model: Dict[int, List[int]] = {}
-    for word, weights in word_weights.items():
-        total = sum(weights.values())
-        new_weights: List[int] = []
-        for category in names:
-            new_weights.append(
-                round((weights.get(category, 0) / total) * 65535)
-            )
-        model[word] = new_weights
+    for word, counts in word_counts.items():
+        if sum(counts.values()) >= min_count:
+            weights = {
+                category: value / category_lengths[category]
+                for category, value in counts.items()
+            }
+            total = sum(weights.values())
+            new_weights: List[int] = []
+            for category in names:
+                new_weights.append(
+                    round((weights.get(category, 0) / total) * 65535)
+                )
+            model[word] = new_weights
+
+    print("weighted")
 
     return gptc.model.Model(model, names, max_ngram_length)
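These two hunks are the core of the performance improvement: the old `compile()` concatenated every portion's tokens into per-category lists (`categories[category] += text`) and then made a second pass to count them, so peak memory grew with the total number of tokens in the training data. The new version folds counting into the single pass over `raw_model`, keeping only `word_counts`, `category_lengths`, and `names`. A self-contained sketch of the same pattern, with plain strings standing in for GPTC's hashed tokens:

    from typing import Dict, List

    portions = [
        {"text": ["spam", "spam", "eggs"], "category": "breakfast"},
        {"text": ["spam", "ham"], "category": "lunch"},
    ]

    word_counts: Dict[str, Dict[str, int]] = {}
    category_lengths: Dict[str, int] = {}
    names: List[str] = []

    for portion in portions:
        text = portion["text"]
        category = portion["category"]
        if category not in names:
            names.append(category)
        # Track category sizes incrementally rather than storing every
        # token and measuring the concatenated lists afterwards.
        category_lengths[category] = category_lengths.get(category, 0) + len(text)
        for word in text:
            counts = word_counts.setdefault(word, {})
            counts[category] = counts.get(category, 0) + 1

    assert category_lengths == {"breakfast": 3, "lunch": 2}
    assert word_counts["spam"] == {"breakfast": 2, "lunch": 1}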


@@ -40,7 +40,9 @@ class Model:
         model = self.weights
         tokens = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
+            gptc.tokenizer.tokenize(
+                text, min(max_ngram_length, self.max_ngram_length)
+            )
         )
 
         numbered_probs: Dict[int, float] = {}
         for word in tokens:
@@ -74,9 +76,9 @@ class Model:
             for index, category in enumerate(self.names)
         }
 
-    def serialize(self) -> bytes:
-        out = b"GPTC model v4\n"
-        out += (
+    def serialize(self, file):
+        file.write(b"GPTC model v4\n")
+        file.write(
             json.dumps(
                 {
                     "names": self.names,
@@ -96,10 +98,10 @@ class Model:
             + b"\n"
         )
         for word, weights in self.weights.items():
-            out += word.to_bytes(6, "big") + b"".join(
-                [weight.to_bytes(2, "big") for weight in weights]
+            file.write(
+                word.to_bytes(6, "big")
+                + b"".join([weight.to_bytes(2, "big") for weight in weights])
             )
-        return out
 
 
 def deserialize(encoded_model: bytes) -> Model:
@@ -122,9 +124,9 @@ def deserialize(encoded_model: bytes) -> Model:
     except KeyError:
         raise InvalidModelError()
 
-    if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
-        [isinstance(name, str) for name in names]
-    ):
+    if not (
+        isinstance(names, list) and isinstance(max_ngram_length, int)
+    ) or not all([isinstance(name, str) for name in names]):
         raise InvalidModelError()
 
     weight_code_length = 6 + 2 * len(names)
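Read together, `serialize` and `deserialize` pin down the v4 on-disk layout: the magic line `b"GPTC model v4\n"`, one newline-terminated line of JSON config carrying at least `names` and `max_ngram_length`, then back-to-back fixed-width records of `6 + 2 * len(names)` bytes each (a 6-byte big-endian word hash followed by one 2-byte big-endian weight per category). A reader sketch under those assumptions; `parse_v4` is an illustrative name, not part of GPTC, and it assumes the JSON line itself contains no raw newline:

    import json
    from typing import Dict, List

    def parse_v4(encoded_model: bytes) -> Dict[int, List[int]]:
        # Only the first two newlines delimit structure; weight bytes
        # after them may legitimately contain 0x0A, hence maxsplit=2.
        header, config_line, body = encoded_model.split(b"\n", 2)
        if header != b"GPTC model v4":
            raise ValueError("not a v4 model")
        names = json.loads(config_line)["names"]

        record_length = 6 + 2 * len(names)  # weight_code_length in the diff
        weights: Dict[int, List[int]] = {}
        for i in range(0, len(body), record_length):
            record = body[i : i + record_length]
            word = int.from_bytes(record[:6], "big")
            weights[word] = [
                int.from_bytes(record[j : j + 2], "big")
                for j in range(6, record_length, 2)
            ]
        return weights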