diff --git a/README.md b/README.md
index 391c648..c69d2a2 100644
--- a/README.md
+++ b/README.md
@@ -29,10 +29,13 @@ stdout (or "None" if it cannot determine anything).
 
 ### Compiling models
 
-    gptc compile [-n <max_ngram_length>] <raw_model_file>
+    gptc compile [-n <max_ngram_length>] [-c <min_count>] <raw_model_file>
 
 This will print the compiled model in JSON to stdout.
 
+If `-c` is specified, words and ngrams used less than `min_count` times will be
+excluded from the compiled model.
+
 ### Packing models
 
     gptc pack <directory>
@@ -68,13 +71,16 @@ The classifier's model.
 Check whether emojis are supported by the `Classifier`. (See section "Emoji.")
 Equivalent to `gptc.has_emoji and gptc.model_has_emoji(model)`.
 
-### `gptc.compile(raw_model, max_ngram_length=1)`
+### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
 
 Compile a raw model (as a list, not JSON) and return the compiled model (as a
 dict).
 
 For information about `max_ngram_length`, see section "Ngrams."
 
+Words or ngrams used less than `min_count` times throughout the input text are
+excluded from the model.
+
 ### `gptc.pack(directory, print_exceptions=False)
 
 Pack the model in `directory` and return a tuple of the format:
diff --git a/gptc/__main__.py b/gptc/__main__.py
index bebac66..d6dab3c 100644
--- a/gptc/__main__.py
+++ b/gptc/__main__.py
@@ -24,6 +24,13 @@ def main() -> None:
         type=int,
         default=1,
     )
+    compile_parser.add_argument(
+        "--min-count",
+        "-c",
+        help="minimum use count for word/ngram to be included in model",
+        type=int,
+        default=1,
+    )
 
     classify_parser = subparsers.add_parser("classify", help="classify text")
     classify_parser.add_argument("model", help="compiled model to use")
@@ -59,7 +66,7 @@ def main() -> None:
         with open(args.model, "r") as f:
             model = json.load(f)
 
-        print(json.dumps(gptc.compile(model, args.max_ngram_length)))
+        print(json.dumps(gptc.compile(model, args.max_ngram_length, args.min_count)))
     elif args.subparser_name == "classify":
         with open(args.model, "r") as f:
             model = json.load(f)
diff --git a/gptc/compiler.py b/gptc/compiler.py
index 667f4e7..7b3fb73 100755
--- a/gptc/compiler.py
+++ b/gptc/compiler.py
@@ -9,7 +9,9 @@ MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
 
 
 def compile(
-    raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
+    raw_model: Iterable[Mapping[str, str]],
+    max_ngram_length: int = 1,
+    min_count: int = 1,
 ) -> MODEL:
     """Compile a raw model.
 
@@ -58,6 +60,12 @@ def compile(
             except KeyError:
                 word_counts[word][category] = 1
 
+    word_counts = {
+        word: counts
+        for word, counts in word_counts.items()
+        if sum(counts.values()) >= min_count
+    }
+
     word_weights: Dict[str, Dict[str, float]] = {}
     for word, values in word_counts.items():
         for category, value in values.items():
@@ -84,4 +92,3 @@ def compile(
     model["__emoji__"] = int(gptc.tokenizer.has_emoji)
 
     return model
-
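
A minimal usage sketch of the new `min_count` parameter (illustrative only: the
two-entry raw model below is made up, the `text`/`category` keys are assumed
from the README's raw model format, and the patch above must be applied):

    import gptc

    # Illustrative raw model: a list of mappings, one per category.
    raw_model = [
        {"category": "fruit", "text": "apple banana apple pear banana apple"},
        {"category": "color", "text": "red green blue red blue red"},
    ]

    # Default behaviour: every word/ngram is kept, as before.
    full_model = gptc.compile(raw_model, max_ngram_length=1)

    # With min_count=2, words seen fewer than twice across all of the input
    # text ("pear" and "green" above) are dropped from the compiled model.
    pruned_model = gptc.compile(raw_model, max_ngram_length=1, min_count=2)

Note that the filter sums a word's counts across every category
(`sum(counts.values()) >= min_count`), so a word only has to reach `min_count`
in total, not within any single category.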