Add min_count

This commit is contained in:
Samuel Sloniker 2022-11-23 11:42:58 -08:00
parent e17c79c231
commit 1d1ccbb7cc
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
3 changed files with 25 additions and 5 deletions

View File

@ -29,10 +29,13 @@ stdout (or "None" if it cannot determine anything).
### Compiling models
gptc compile [-n <max_ngram_length>] <raw model file>
gptc compile [-n <max_ngram_length>] [-c <min_count>] <raw model file>
This will print the compiled model in JSON to stdout.
If `-c` is specified, words and ngrams used less than `min_count` times will be
excluded from the compiled model.
### Packing models
gptc pack <dir>
@ -68,13 +71,16 @@ The classifier's model.
Check whether emojis are supported by the `Classifier`. (See section "Emoji.")
Equivalent to `gptc.has_emoji and gptc.model_has_emoji(model)`.
### `gptc.compile(raw_model, max_ngram_length=1)`
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
Compile a raw model (as a list, not JSON) and return the compiled model (as a
dict).
For information about `max_ngram_length`, see section "Ngrams."
Words or ngrams used less than `min_count` times throughout the input text are
excluded from the model.
### `gptc.pack(directory, print_exceptions=False)
Pack the model in `directory` and return a tuple of the format:

View File

@ -24,6 +24,13 @@ def main() -> None:
type=int,
default=1,
)
compile_parser.add_argument(
"--min-count",
"-c",
help="minimum use count for word/ngram to be included in model",
type=int,
default=1,
)
classify_parser = subparsers.add_parser("classify", help="classify text")
classify_parser.add_argument("model", help="compiled model to use")
@ -59,7 +66,7 @@ def main() -> None:
with open(args.model, "r") as f:
model = json.load(f)
print(json.dumps(gptc.compile(model, args.max_ngram_length)))
print(json.dumps(gptc.compile(model, args.max_ngram_length, args.min_count)))
elif args.subparser_name == "classify":
with open(args.model, "r") as f:
model = json.load(f)

View File

@ -9,7 +9,9 @@ MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
def compile(
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1,
min_count: int = 1,
) -> MODEL:
"""Compile a raw model.
@ -58,6 +60,12 @@ def compile(
except KeyError:
word_counts[word][category] = 1
word_counts = {
word: counts
for word, counts in word_counts.items()
if sum(counts.values()) >= min_count
}
word_weights: Dict[str, Dict[str, float]] = {}
for word, values in word_counts.items():
for category, value in values.items():
@ -84,4 +92,3 @@ def compile(
model["__emoji__"] = int(gptc.tokenizer.has_emoji)
return model