# Based on # https://github.com/sizmailov/pybind11-stubgen/blob/master/pybind11_stubgen/__init__.py from __future__ import annotations import importlib import logging import re from argparse import ArgumentParser, Namespace from pathlib import Path import glob from pybind11_stubgen.parser.interface import IParser from pybind11_stubgen.parser.mixins.error_handlers import ( IgnoreAllErrors, IgnoreInvalidExpressionErrors, IgnoreInvalidIdentifierErrors, IgnoreUnresolvedNameErrors, LogErrors, LoggerData, SuggestCxxSignatureFix, TerminateOnFatalErrors, ) from pybind11_stubgen.parser.mixins.filter import ( FilterClassMembers, FilterInvalidIdentifiers, FilterPybind11ViewClasses, FilterPybindInternals, FilterTypingModuleAttributes, ) from pybind11_stubgen.parser.mixins.fix import ( FixBuiltinTypes, FixCurrentModulePrefixInTypeNames, FixMissing__all__Attribute, FixMissing__future__AnnotationsImport, FixMissingEnumMembersAnnotation, FixMissingFixedSizeImport, FixMissingImports, FixMissingNoneHashFieldAnnotation, FixNumpyArrayDimAnnotation, FixNumpyArrayDimTypeVar, FixNumpyArrayFlags, FixNumpyArrayRemoveParameters, FixNumpyDtype, FixPEP585CollectionNames, FixPybind11EnumStrDoc, FixRedundantBuiltinsAnnotation, FixRedundantMethodsFromBuiltinObject, FixScipyTypeArguments, FixTypingTypeNames, FixValueReprRandomAddress, OverridePrintSafeValues, RemoveSelfAnnotation, ReplaceReadWritePropertyWithField, RewritePybind11EnumValueRepr, ) from pybind11_stubgen.parser.mixins.parse import ( BaseParser, ExtractSignaturesFromPybind11Docstrings, ParserDispatchMixin, ) from pybind11_stubgen.printer import Printer from pybind11_stubgen.structs import QualifiedName from pybind11_stubgen.writer import Writer class CLIArgs(Namespace): output_dir: str root_suffix: str ignore_invalid_expressions: re.Pattern | None ignore_invalid_identifiers: re.Pattern | None ignore_unresolved_names: re.Pattern | None ignore_all_errors: bool enum_class_locations: list[tuple[re.Pattern, str]] numpy_array_wrap_with_annotated: bool numpy_array_use_type_var: bool numpy_array_remove_parameters: bool print_invalid_expressions_as_is: bool print_safe_value_reprs: re.Pattern | None exit_code: bool dry_run: bool stub_extension: str module_name: str def arg_parser() -> ArgumentParser: def regex(pattern_str: str) -> re.Pattern: try: return re.compile(pattern_str) except re.error as e: raise ValueError(f"Invalid REGEX pattern: {e}") def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]: pattern_str, path = regex_path.rsplit(":", maxsplit=1) if any(not part.isidentifier() for part in path.split(".")): raise ValueError(f"Invalid PATH: {path}") return regex(pattern_str), path parser = ArgumentParser( prog="pybind11-stubgen", description="Generates stubs for specified modules" ) parser.add_argument( "-o", "--output-dir", help="The root directory for output stubs", default=".", ) parser.add_argument( "--root-suffix", type=str, default=None, dest="root_suffix", help="Top-level module directory suffix", ) parser.add_argument( "--ignore-invalid-expressions", metavar="REGEX", default=None, type=regex, help="Ignore invalid expressions matching REGEX", ) parser.add_argument( "--ignore-invalid-identifiers", metavar="REGEX", default=None, type=regex, help="Ignore invalid identifiers matching REGEX", ) parser.add_argument( "--ignore-unresolved-names", metavar="REGEX", default=None, type=regex, help="Ignore unresolved names matching REGEX", ) parser.add_argument( "--ignore-all-errors", default=False, action="store_true", help="Ignore all errors during module parsing", ) parser.add_argument( "--enum-class-locations", dest="enum_class_locations", metavar="REGEX:LOC", action="append", default=[], type=regex_colon_path, help="Locations of enum classes in " ": format. " "Example: `MyEnum:foo.bar.Baz`", ) numpy_array_fix = parser.add_mutually_exclusive_group() numpy_array_fix.add_argument( "--numpy-array-wrap-with-annotated", default=False, action="store_true", help="Replace numpy/scipy arrays of " "'ARRAY_T[TYPE, [*DIMS], *FLAGS]' format with " "'Annotated[ARRAY_T, TYPE, FixedSize|DynamicSize(*DIMS), *FLAGS]'", ) numpy_array_fix.add_argument( "--numpy-array-use-type-var", default=False, action="store_true", help="Replace 'numpy.ndarray[numpy.float32[m, 1]]' with " "'numpy.ndarray[tuple[M, typing.Literal[1]], numpy.dtype[numpy.float32]]'", ) numpy_array_fix.add_argument( "--numpy-array-remove-parameters", default=False, action="store_true", help="Replace 'numpy.ndarray[...]' with 'numpy.ndarray'", ) parser.add_argument( "--print-invalid-expressions-as-is", default=False, action="store_true", help="Suppress the replacement with '...' of invalid expressions" "found in annotations", ) parser.add_argument( "--print-safe-value-reprs", metavar="REGEX", default=None, type=regex, help="Override the print-safe check for values matching REGEX", ) parser.add_argument( "--exit-code", action="store_true", dest="exit_code", help="On error exits with 1 and skips stub generation", ) parser.add_argument( "--dry-run", action="store_true", dest="dry_run", help="Don't write stubs. Parses module and report errors", ) parser.add_argument( "--stub-extension", type=str, default="pyi", metavar="EXT", choices=["pyi", "py"], help="The file extension of the generated stubs. " "Must be 'pyi' (default) or 'py'", ) return parser def stub_parser_from_args(args: CLIArgs) -> IParser: error_handlers_top: list[type] = [ LoggerData, *([IgnoreAllErrors] if args.ignore_all_errors else []), *([IgnoreInvalidIdentifierErrors] if args.ignore_invalid_identifiers else []), *([IgnoreInvalidExpressionErrors] if args.ignore_invalid_expressions else []), *([IgnoreUnresolvedNameErrors] if args.ignore_unresolved_names else []), ] error_handlers_bottom: list[type] = [ LogErrors, SuggestCxxSignatureFix, *([TerminateOnFatalErrors] if args.exit_code else []), ] numpy_fixes: list[type] = [ *([FixNumpyArrayDimAnnotation] if args.numpy_array_wrap_with_annotated else []), *([FixNumpyArrayDimTypeVar] if args.numpy_array_use_type_var else []), *( [FixNumpyArrayRemoveParameters] if args.numpy_array_remove_parameters else [] ), ] class Parser( *error_handlers_top, # type: ignore[misc] FixMissing__future__AnnotationsImport, FixMissing__all__Attribute, FixMissingNoneHashFieldAnnotation, FixMissingImports, FilterTypingModuleAttributes, FixPEP585CollectionNames, FixTypingTypeNames, FixScipyTypeArguments, FixMissingFixedSizeImport, FixMissingEnumMembersAnnotation, OverridePrintSafeValues, *numpy_fixes, # type: ignore[misc] FixNumpyDtype, FixNumpyArrayFlags, FixCurrentModulePrefixInTypeNames, FixBuiltinTypes, RewritePybind11EnumValueRepr, FilterClassMembers, ReplaceReadWritePropertyWithField, FilterInvalidIdentifiers, FixValueReprRandomAddress, FixRedundantBuiltinsAnnotation, FilterPybindInternals, FilterPybind11ViewClasses, FixRedundantMethodsFromBuiltinObject, RemoveSelfAnnotation, FixPybind11EnumStrDoc, ExtractSignaturesFromPybind11Docstrings, ParserDispatchMixin, BaseParser, *error_handlers_bottom, # type: ignore[misc] ): pass parser = Parser() if args.enum_class_locations: parser.set_pybind11_enum_locations(dict(args.enum_class_locations)) if args.ignore_invalid_identifiers is not None: parser.set_ignored_invalid_identifiers(args.ignore_invalid_identifiers) if args.ignore_invalid_expressions is not None: parser.set_ignored_invalid_expressions(args.ignore_invalid_expressions) if args.ignore_unresolved_names is not None: parser.set_ignored_unresolved_names(args.ignore_unresolved_names) if args.print_safe_value_reprs is not None: parser.set_print_safe_value_pattern(args.print_safe_value_reprs) return parser def main() -> None: files = glob.glob("*.so") for fid in files: idx: int = fid.find(".") module_name: str = fid[:idx] print("Processing: " + module_name) logging.basicConfig( level=logging.INFO, format="%(name)s - [%(levelname)7s] %(message)s", ) args = arg_parser().parse_args(namespace=CLIArgs()) parser = stub_parser_from_args(args) printer = Printer( invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is ) out_dir, sub_dir = to_output_and_subdir( output_dir=args.output_dir, module_name=module_name, root_suffix=args.root_suffix, ) run( parser, printer, module_name, out_dir, sub_dir=sub_dir, dry_run=args.dry_run, writer=Writer(stub_ext=args.stub_extension), ) def to_output_and_subdir( output_dir: str, module_name: str, root_suffix: str | None ) -> tuple[Path, Path | None]: out_dir = Path(output_dir) module_path = module_name.split(".") if root_suffix is None: return out_dir.joinpath(*module_path[:-1]), None else: module_path = [f"{module_path[0]}{root_suffix}", *module_path[1:]] if len(module_path) == 1: sub_dir = Path(module_path[-1]) else: sub_dir = None return out_dir.joinpath(*module_path[:-1]), sub_dir def run( parser: IParser, printer: Printer, module_name: str, out_dir: Path, sub_dir: Path | None, dry_run: bool, writer: Writer, ): module = parser.handle_module( QualifiedName.from_str(module_name), importlib.import_module(module_name) ) parser.finalize() if module is None: raise RuntimeError(f"Can't parse {module_name}") if dry_run: return out_dir.mkdir(exist_ok=True, parents=True) writer.write_module(module, printer, to=out_dir, sub_dir=sub_dir) if __name__ == "__main__": main()