Performance improvements
parent a76c6d3da8
commit 7e7b5f3e9c
@@ -43,9 +43,9 @@ example of the format. Any exceptions will be printed to stderr.

 ## Library

-### `Model.serialize()`
+### `Model.serialize(file)`

-Returns a `bytes` representing the model.
+Write binary data representing the model to `file`.

 ### `gptc.deserialize(encoded_model)`
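
The README hunk above documents the new streaming interface: `Model.serialize` now writes to a binary file object instead of returning `bytes`. A minimal usage sketch, with hypothetical training data; `gptc.compile`'s argument order follows the CLI hunk below:

```python
import gptc

# Hypothetical raw model: a list of {"text": ..., "category": ...} portions,
# matching the fields the compiler reads (portion["text"], portion["category"]).
raw_model = [
    {"text": "the quick brown fox", "category": "animals"},
    {"text": "stocks fell sharply today", "category": "finance"},
]

model = gptc.compile(raw_model, 2, 1)  # max_ngram_length=2, min_count=1

# New API: write the binary model directly to a file object.
with open("model.gptc", "wb") as f:
    model.serialize(f)

# Deserialization is unchanged and still takes bytes.
with open("model.gptc", "rb") as f:
    model = gptc.deserialize(f.read())
```
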
@@ -66,9 +66,7 @@ def main() -> None:
         with open(args.model, "r") as f:
             model = json.load(f)

-        sys.stdout.buffer.write(
-            gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
-        )
+        gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
     elif args.subparser_name == "classify":
         with open(args.model, "rb") as f:
             model = gptc.deserialize(f.read())
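
This CLI change is where the memory win for serialization shows up: the old code materialized the entire serialized model as one `bytes` value before writing it, while the new code hands `sys.stdout.buffer` to `serialize` so each record is written as it is produced. When `bytes` are still wanted, say in a test, the old behavior is easy to recover; a sketch, assuming a `model` compiled as in the sketch above:

```python
import io

buf = io.BytesIO()
model.serialize(buf)   # writes the header and per-word records into the buffer
data = buf.getvalue()  # equivalent to the old serialize() return value
```
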
@@ -27,23 +27,23 @@ def compile(
     """

     categories: Dict[str, List[int]] = {}
+    word_counts: Dict[int, Dict[str, int]] = {}
+    category_lengths: Dict[str, int] = {}
+    names: List[str] = []

     for portion in raw_model:
         text = gptc.tokenizer.hash(
             gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
         )
         category = portion["category"]
         try:
             categories[category] += text
         except KeyError:
             categories[category] = text
-
-    word_counts: Dict[int, Dict[str, int]] = {}
-
-    names = list(categories.keys())
+        if not category in names:
+            names.append(category)
+        category_lengths[category] = category_lengths.get(category, 0) + len(
+            text
+        )

     for category, text in categories.items():
         for word in text:
             if word in word_counts:
                 try:
@@ -53,27 +53,22 @@ def compile(
             else:
                 word_counts[word] = {category: 1}

-    category_lengths = {
-        category: len(text) for category, text in categories.items()
-    }
-
-    word_weights: Dict[int, Dict[str, float]] = {
-        word: {
-            category: value / category_lengths[category]
-            for category, value in values.items()
-        }
-        for word, values in word_counts.items()
-        if sum(values.values()) >= min_count
-    }
+    print("counted")

     model: Dict[int, List[int]] = {}
-    for word, weights in word_weights.items():
-        total = sum(weights.values())
-        new_weights: List[int] = []
-        for category in names:
-            new_weights.append(
-                round((weights.get(category, 0) / total) * 65535)
-            )
-        model[word] = new_weights
+    for word, counts in word_counts.items():
+        if sum(counts.values()) >= min_count:
+            weights = {
+                category: value / category_lengths[category]
+                for category, value in counts.items()
+            }
+            total = sum(weights.values())
+            new_weights: List[int] = []
+            for category in names:
+                new_weights.append(
+                    round((weights.get(category, 0) / total) * 65535)
+                )
+            model[word] = new_weights

+    print("weighted")
     return gptc.model.Model(model, names, max_ngram_length)
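
The two compiler hunks carry the actual performance change: `names` and `category_lengths` are now accumulated inside the tokenization loop instead of being rebuilt in separate passes afterwards, and the intermediate `word_weights` dict is gone, with the `min_count` filter and the weight normalization folded into a single loop over `word_counts`. A standalone sketch of the single-pass accumulation pattern, with whitespace-split words standing in for `gptc.tokenizer` output (the helper name is hypothetical):

```python
from typing import Dict, List, Tuple

def count_portions(
    raw_model: List[dict],
) -> Tuple[Dict[str, List[str]], Dict[str, int], List[str]]:
    # One pass: extend the category's token list, record the category
    # name on first sight, and keep a running token count per category.
    categories: Dict[str, List[str]] = {}
    category_lengths: Dict[str, int] = {}
    names: List[str] = []
    for portion in raw_model:
        text = portion["text"].split()  # stand-in for tokenize + hash
        category = portion["category"]
        categories.setdefault(category, []).extend(text)
        if category not in names:
            names.append(category)
        category_lengths[category] = category_lengths.get(category, 0) + len(text)
    return categories, category_lengths, names
```
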
@@ -40,7 +40,9 @@ class Model:
         model = self.weights

         tokens = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
+            gptc.tokenizer.tokenize(
+                text, min(max_ngram_length, self.max_ngram_length)
+            )
         )
         numbered_probs: Dict[int, float] = {}
         for word in tokens:
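
The rewrapped `tokenize` call above is only a formatting change, but the call itself is worth a note: the classifying code clamps the requested `max_ngram_length` to the value the model was compiled with, so asking for longer n-grams than the model stores quietly falls back to the model's maximum. A toy check with hypothetical values:

```python
model_max_ngram_length = 3  # model compiled with up to trigrams
requested = 5               # caller asks for 5-grams
effective = min(requested, model_max_ngram_length)
assert effective == 3       # tokenization never exceeds what the model stores
```
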
@@ -74,9 +76,9 @@ class Model:
             for index, category in enumerate(self.names)
         }

-    def serialize(self) -> bytes:
-        out = b"GPTC model v4\n"
-        out += (
+    def serialize(self, file):
+        file.write(b"GPTC model v4\n")
+        file.write(
             json.dumps(
                 {
                     "names": self.names,
@@ -96,10 +98,10 @@ class Model:
             + b"\n"
         )
         for word, weights in self.weights.items():
-            out += word.to_bytes(6, "big") + b"".join(
-                [weight.to_bytes(2, "big") for weight in weights]
+            file.write(
+                word.to_bytes(6, "big")
+                + b"".join([weight.to_bytes(2, "big") for weight in weights])
             )
-        return out


 def deserialize(encoded_model: bytes) -> Model:
@@ -122,9 +124,9 @@ def deserialize(encoded_model: bytes) -> Model:
     except KeyError:
         raise InvalidModelError()

-    if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
-        [isinstance(name, str) for name in names]
-    ):
+    if not (
+        isinstance(names, list) and isinstance(max_ngram_length, int)
+    ) or not all([isinstance(name, str) for name in names]):
         raise InvalidModelError()

     weight_code_length = 6 + 2 * len(names)
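
Taken together, the `serialize` hunks pin down the v4 wire format: the magic line `GPTC model v4\n`, one line of JSON configuration (at least `"names"`; the hunk is truncated before any other keys), then one fixed-width record per word consisting of a 6-byte big-endian word hash followed by a 2-byte big-endian weight per category, matching `weight_code_length = 6 + 2 * len(names)`. A sketch of a reader for just the record section (this is not the library's `deserialize`; it assumes the JSON stays on one line, which `json.dumps` guarantees here):

```python
import json
from typing import Iterator, List, Tuple

def iter_records(encoded_model: bytes) -> Iterator[Tuple[int, List[int]]]:
    magic, config_line, body = encoded_model.split(b"\n", 2)
    if magic != b"GPTC model v4":
        raise ValueError("not a v4 model")
    names = json.loads(config_line)["names"]
    record_size = 6 + 2 * len(names)  # same arithmetic as weight_code_length
    for offset in range(0, len(body), record_size):
        record = body[offset : offset + record_size]
        word = int.from_bytes(record[:6], "big")
        weights = [
            int.from_bytes(record[6 + 2 * i : 8 + 2 * i], "big")
            for i in range(len(names))
        ]
        yield word, weights
```
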