From 7e7b5f3e9ca3c361dcb47e8e1f83201036c6f78c Mon Sep 17 00:00:00 2001 From: Samuel Sloniker Date: Thu, 22 Dec 2022 18:01:37 -0800 Subject: [PATCH] Performance improvements --- README.md | 4 ++-- gptc/__main__.py | 4 +--- gptc/compiler.py | 51 ++++++++++++++++++++++-------------------------- gptc/model.py | 22 +++++++++++---------- 4 files changed, 38 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index dbc13fd..51ab811 100644 --- a/README.md +++ b/README.md @@ -43,9 +43,9 @@ example of the format. Any exceptions will be printed to stderr. ## Library -### `Model.serialize()` +### `Model.serialize(file)` -Returns a `bytes` representing the model. +Write binary data representing the model to `file`. ### `gptc.deserialize(encoded_model)` diff --git a/gptc/__main__.py b/gptc/__main__.py index 32ec754..eaa3143 100644 --- a/gptc/__main__.py +++ b/gptc/__main__.py @@ -66,9 +66,7 @@ def main() -> None: with open(args.model, "r") as f: model = json.load(f) - sys.stdout.buffer.write( - gptc.compile(model, args.max_ngram_length, args.min_count).serialize() - ) + gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer) elif args.subparser_name == "classify": with open(args.model, "rb") as f: model = gptc.deserialize(f.read()) diff --git a/gptc/compiler.py b/gptc/compiler.py index c67aeee..e2ab349 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -27,23 +27,23 @@ def compile( """ - categories: Dict[str, List[int]] = {} + word_counts: Dict[int, Dict[str, int]] = {} + category_lengths: Dict[str, int] = {} + names: List[str] = [] for portion in raw_model: text = gptc.tokenizer.hash( gptc.tokenizer.tokenize(portion["text"], max_ngram_length) ) category = portion["category"] - try: - categories[category] += text - except KeyError: - categories[category] = text - word_counts: Dict[int, Dict[str, int]] = {} + if not category in names: + names.append(category) - names = list(categories.keys()) + category_lengths[category] = category_lengths.get(category, 0) + len( + text + ) - for category, text in categories.items(): for word in text: if word in word_counts: try: @@ -53,27 +53,22 @@ def compile( else: word_counts[word] = {category: 1} - category_lengths = { - category: len(text) for category, text in categories.items() - } - - word_weights: Dict[int, Dict[str, float]] = { - word: { - category: value / category_lengths[category] - for category, value in values.items() - } - for word, values in word_counts.items() - if sum(values.values()) >= min_count - } + print("counted") model: Dict[int, List[int]] = {} - for word, weights in word_weights.items(): - total = sum(weights.values()) - new_weights: List[int] = [] - for category in names: - new_weights.append( - round((weights.get(category, 0) / total) * 65535) - ) - model[word] = new_weights + for word, counts in word_counts.items(): + if sum(counts.values()) >= min_count: + weights = { + category: value / category_lengths[category] + for category, value in counts.items() + } + total = sum(weights.values()) + new_weights: List[int] = [] + for category in names: + new_weights.append( + round((weights.get(category, 0) / total) * 65535) + ) + model[word] = new_weights + print("weighted") return gptc.model.Model(model, names, max_ngram_length) diff --git a/gptc/model.py b/gptc/model.py index 53e7bbc..bd64761 100644 --- a/gptc/model.py +++ b/gptc/model.py @@ -40,7 +40,9 @@ class Model: model = self.weights tokens = gptc.tokenizer.hash( - gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length)) + gptc.tokenizer.tokenize( + text, min(max_ngram_length, self.max_ngram_length) + ) ) numbered_probs: Dict[int, float] = {} for word in tokens: @@ -74,9 +76,9 @@ class Model: for index, category in enumerate(self.names) } - def serialize(self) -> bytes: - out = b"GPTC model v4\n" - out += ( + def serialize(self, file): + file.write(b"GPTC model v4\n") + file.write( json.dumps( { "names": self.names, @@ -96,10 +98,10 @@ class Model: + b"\n" ) for word, weights in self.weights.items(): - out += word.to_bytes(6, "big") + b"".join( - [weight.to_bytes(2, "big") for weight in weights] + file.write( + word.to_bytes(6, "big") + + b"".join([weight.to_bytes(2, "big") for weight in weights]) ) - return out def deserialize(encoded_model: bytes) -> Model: @@ -122,9 +124,9 @@ def deserialize(encoded_model: bytes) -> Model: except KeyError: raise InvalidModelError() - if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all( - [isinstance(name, str) for name in names] - ): + if not ( + isinstance(names, list) and isinstance(max_ngram_length, int) + ) or not all([isinstance(name, str) for name in names]): raise InvalidModelError() weight_code_length = 6 + 2 * len(names)