diff --git a/gptc/model.py b/gptc/model.py index bd64761..53bcf01 100644 --- a/gptc/model.py +++ b/gptc/model.py @@ -3,7 +3,7 @@ import gptc.tokenizer from gptc.exceptions import InvalidModelError import gptc.weighting -from typing import Iterable, Mapping, List, Dict, Union, cast +from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO import json @@ -76,7 +76,7 @@ class Model: for index, category in enumerate(self.names) } - def serialize(self, file): + def serialize(self, file: BinaryIO): file.write(b"GPTC model v4\n") file.write( json.dumps(