380 lines
11 KiB
Python
380 lines
11 KiB
Python
# Based on
# https://github.com/sizmailov/pybind11-stubgen/blob/master/pybind11_stubgen/__init__.py
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib
|
|
import logging
|
|
import re
|
|
from argparse import ArgumentParser, Namespace
|
|
from pathlib import Path
|
|
import glob
|
|
|
|
from pybind11_stubgen.parser.interface import IParser
|
|
from pybind11_stubgen.parser.mixins.error_handlers import (
|
|
IgnoreAllErrors,
|
|
IgnoreInvalidExpressionErrors,
|
|
IgnoreInvalidIdentifierErrors,
|
|
IgnoreUnresolvedNameErrors,
|
|
LogErrors,
|
|
LoggerData,
|
|
SuggestCxxSignatureFix,
|
|
TerminateOnFatalErrors,
|
|
)
|
|
from pybind11_stubgen.parser.mixins.filter import (
|
|
FilterClassMembers,
|
|
FilterInvalidIdentifiers,
|
|
FilterPybind11ViewClasses,
|
|
FilterPybindInternals,
|
|
FilterTypingModuleAttributes,
|
|
)
|
|
from pybind11_stubgen.parser.mixins.fix import (
|
|
FixBuiltinTypes,
|
|
FixCurrentModulePrefixInTypeNames,
|
|
FixMissing__all__Attribute,
|
|
FixMissing__future__AnnotationsImport,
|
|
FixMissingEnumMembersAnnotation,
|
|
FixMissingFixedSizeImport,
|
|
FixMissingImports,
|
|
FixMissingNoneHashFieldAnnotation,
|
|
FixNumpyArrayDimAnnotation,
|
|
FixNumpyArrayDimTypeVar,
|
|
FixNumpyArrayFlags,
|
|
FixNumpyArrayRemoveParameters,
|
|
FixNumpyDtype,
|
|
FixPEP585CollectionNames,
|
|
FixPybind11EnumStrDoc,
|
|
FixRedundantBuiltinsAnnotation,
|
|
FixRedundantMethodsFromBuiltinObject,
|
|
FixScipyTypeArguments,
|
|
FixTypingTypeNames,
|
|
FixValueReprRandomAddress,
|
|
OverridePrintSafeValues,
|
|
RemoveSelfAnnotation,
|
|
ReplaceReadWritePropertyWithField,
|
|
RewritePybind11EnumValueRepr,
|
|
)
|
|
from pybind11_stubgen.parser.mixins.parse import (
|
|
BaseParser,
|
|
ExtractSignaturesFromPybind11Docstrings,
|
|
ParserDispatchMixin,
|
|
)
|
|
from pybind11_stubgen.printer import Printer
|
|
from pybind11_stubgen.structs import QualifiedName
|
|
from pybind11_stubgen.writer import Writer
|
|
|
|
|
|
class CLIArgs(Namespace):
|
|
output_dir: str
|
|
root_suffix: str
|
|
ignore_invalid_expressions: re.Pattern | None
|
|
ignore_invalid_identifiers: re.Pattern | None
|
|
ignore_unresolved_names: re.Pattern | None
|
|
ignore_all_errors: bool
|
|
enum_class_locations: list[tuple[re.Pattern, str]]
|
|
numpy_array_wrap_with_annotated: bool
|
|
numpy_array_use_type_var: bool
|
|
numpy_array_remove_parameters: bool
|
|
print_invalid_expressions_as_is: bool
|
|
print_safe_value_reprs: re.Pattern | None
|
|
exit_code: bool
|
|
dry_run: bool
|
|
stub_extension: str
|
|
module_name: str
|
|
|
|
|
|
def arg_parser() -> ArgumentParser:
    """Build the command-line parser of the stub generator.

    Returns:
        An ``ArgumentParser`` with every supported option registered. Options
        taking a ``REGEX`` are converted to compiled patterns; the
        ``--enum-class-locations`` option yields ``(pattern, path)`` tuples.
    """
    # Imported here because the module header only pulls ArgumentParser and
    # Namespace out of argparse.
    from argparse import ArgumentTypeError

    def regex(pattern_str: str) -> re.Pattern:
        """Compile *pattern_str* or fail with a message argparse will display."""
        try:
            return re.compile(pattern_str)
        except re.error as e:
            # argparse replaces a ValueError raised by a ``type=`` callable
            # with a generic "invalid ... value" text; ArgumentTypeError keeps
            # the explanation below visible to the user.
            raise ArgumentTypeError(f"Invalid REGEX pattern: {e}") from e

    def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
        """Split ``REGEX:LOC`` at the last colon and validate both halves."""
        pattern_str, colon, path = regex_path.rpartition(":")
        if not colon:
            # Without this check the value would fail on tuple unpacking with
            # an unhelpful message.
            raise ArgumentTypeError(f"Expected REGEX:LOC, got: {regex_path}")
        if any(not part.isidentifier() for part in path.split(".")):
            raise ArgumentTypeError(f"Invalid PATH: {path}")
        return regex(pattern_str), path

    parser = ArgumentParser(
        prog="pybind11-stubgen", description="Generates stubs for specified modules"
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        help="The root directory for output stubs",
        default=".",
    )
    parser.add_argument(
        "--root-suffix",
        type=str,
        default=None,
        dest="root_suffix",
        help="Top-level module directory suffix",
    )

    parser.add_argument(
        "--ignore-invalid-expressions",
        metavar="REGEX",
        default=None,
        type=regex,
        help="Ignore invalid expressions matching REGEX",
    )
    parser.add_argument(
        "--ignore-invalid-identifiers",
        metavar="REGEX",
        default=None,
        type=regex,
        help="Ignore invalid identifiers matching REGEX",
    )

    parser.add_argument(
        "--ignore-unresolved-names",
        metavar="REGEX",
        default=None,
        type=regex,
        help="Ignore unresolved names matching REGEX",
    )

    parser.add_argument(
        "--ignore-all-errors",
        default=False,
        action="store_true",
        help="Ignore all errors during module parsing",
    )

    parser.add_argument(
        "--enum-class-locations",
        dest="enum_class_locations",
        metavar="REGEX:LOC",
        action="append",
        default=[],
        type=regex_colon_path,
        help="Locations of enum classes in "
        "<enum-class-name-regex>:<path-to-class> format. "
        "Example: `MyEnum:foo.bar.Baz`",
    )

    # At most one numpy array rewriting strategy may be selected.
    numpy_array_fix = parser.add_mutually_exclusive_group()
    numpy_array_fix.add_argument(
        "--numpy-array-wrap-with-annotated",
        default=False,
        action="store_true",
        help="Replace numpy/scipy arrays of "
        "'ARRAY_T[TYPE, [*DIMS], *FLAGS]' format with "
        "'Annotated[ARRAY_T, TYPE, FixedSize|DynamicSize(*DIMS), *FLAGS]'",
    )
    numpy_array_fix.add_argument(
        "--numpy-array-use-type-var",
        default=False,
        action="store_true",
        help="Replace 'numpy.ndarray[numpy.float32[m, 1]]' with "
        "'numpy.ndarray[tuple[M, typing.Literal[1]], numpy.dtype[numpy.float32]]'",
    )

    numpy_array_fix.add_argument(
        "--numpy-array-remove-parameters",
        default=False,
        action="store_true",
        help="Replace 'numpy.ndarray[...]' with 'numpy.ndarray'",
    )

    parser.add_argument(
        "--print-invalid-expressions-as-is",
        default=False,
        action="store_true",
        # The trailing space keeps the two literals from fusing into
        # "expressionsfound" in the rendered help.
        help="Suppress the replacement with '...' of invalid expressions "
        "found in annotations",
    )

    parser.add_argument(
        "--print-safe-value-reprs",
        metavar="REGEX",
        default=None,
        type=regex,
        help="Override the print-safe check for values matching REGEX",
    )

    parser.add_argument(
        "--exit-code",
        action="store_true",
        dest="exit_code",
        help="On error exits with 1 and skips stub generation",
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        dest="dry_run",
        help="Don't write stubs. Parses module and report errors",
    )

    parser.add_argument(
        "--stub-extension",
        type=str,
        default="pyi",
        metavar="EXT",
        choices=["pyi", "py"],
        help="The file extension of the generated stubs. "
        "Must be 'pyi' (default) or 'py'",
    )

    return parser
|
|
|
|
|
|
def stub_parser_from_args(args: CLIArgs) -> IParser:
    """Compose a parser class from mixins chosen by *args* and instantiate it.

    Optional mixins are included only when the matching option is set. The
    order of the bases is significant (it fixes the method resolution order)
    and is kept exactly as listed below.

    Args:
        args: Parsed command-line options.

    Returns:
        A configured parser instance.
    """
    # Error-handling mixins that go in front of every other base.
    leading_handlers: list[type] = [LoggerData]
    if args.ignore_all_errors:
        leading_handlers.append(IgnoreAllErrors)
    if args.ignore_invalid_identifiers:
        leading_handlers.append(IgnoreInvalidIdentifierErrors)
    if args.ignore_invalid_expressions:
        leading_handlers.append(IgnoreInvalidExpressionErrors)
    if args.ignore_unresolved_names:
        leading_handlers.append(IgnoreUnresolvedNameErrors)

    # Error-handling mixins that go after every other base.
    trailing_handlers: list[type] = [LogErrors, SuggestCxxSignatureFix]
    if args.exit_code:
        trailing_handlers.append(TerminateOnFatalErrors)

    # Optional numpy array rewriting mixins.
    numpy_mixins: list[type] = [
        mixin
        for enabled, mixin in (
            (args.numpy_array_wrap_with_annotated, FixNumpyArrayDimAnnotation),
            (args.numpy_array_use_type_var, FixNumpyArrayDimTypeVar),
            (args.numpy_array_remove_parameters, FixNumpyArrayRemoveParameters),
        )
        if enabled
    ]

    class Parser(
        *leading_handlers,  # type: ignore[misc]
        FixMissing__future__AnnotationsImport,
        FixMissing__all__Attribute,
        FixMissingNoneHashFieldAnnotation,
        FixMissingImports,
        FilterTypingModuleAttributes,
        FixPEP585CollectionNames,
        FixTypingTypeNames,
        FixScipyTypeArguments,
        FixMissingFixedSizeImport,
        FixMissingEnumMembersAnnotation,
        OverridePrintSafeValues,
        *numpy_mixins,  # type: ignore[misc]
        FixNumpyDtype,
        FixNumpyArrayFlags,
        FixCurrentModulePrefixInTypeNames,
        FixBuiltinTypes,
        RewritePybind11EnumValueRepr,
        FilterClassMembers,
        ReplaceReadWritePropertyWithField,
        FilterInvalidIdentifiers,
        FixValueReprRandomAddress,
        FixRedundantBuiltinsAnnotation,
        FilterPybindInternals,
        FilterPybind11ViewClasses,
        FixRedundantMethodsFromBuiltinObject,
        RemoveSelfAnnotation,
        FixPybind11EnumStrDoc,
        ExtractSignaturesFromPybind11Docstrings,
        ParserDispatchMixin,
        BaseParser,
        *trailing_handlers,  # type: ignore[misc]
    ):
        pass

    stub_parser = Parser()

    # Each setter is looked up only when its option was supplied, so a setter
    # is never touched unless the branch that needs it is taken.
    if args.enum_class_locations:
        enum_locations = dict(args.enum_class_locations)
        stub_parser.set_pybind11_enum_locations(enum_locations)

    identifiers_pattern = args.ignore_invalid_identifiers
    if identifiers_pattern is not None:
        stub_parser.set_ignored_invalid_identifiers(identifiers_pattern)

    expressions_pattern = args.ignore_invalid_expressions
    if expressions_pattern is not None:
        stub_parser.set_ignored_invalid_expressions(expressions_pattern)

    names_pattern = args.ignore_unresolved_names
    if names_pattern is not None:
        stub_parser.set_ignored_unresolved_names(names_pattern)

    safe_repr_pattern = args.print_safe_value_reprs
    if safe_repr_pattern is not None:
        stub_parser.set_print_safe_value_pattern(safe_repr_pattern)

    return stub_parser
|
|
|
|
|
|
def main() -> None:
    """Generate stubs for every ``*.so`` extension module in the current directory.

    The module name is the part of each file name before its first dot, so
    ``foo.cpython-311-x86_64-linux-gnu.so`` is processed as module ``foo``.
    Command-line options are parsed once and applied to every module.
    """
    # Logging setup and argument parsing do not depend on the module being
    # processed, so they run once. Doing this before the file scan also means
    # ``--help`` and invalid options are handled when no ``*.so`` file exists.
    logging.basicConfig(
        level=logging.INFO,
        format="%(name)s - [%(levelname)7s] %(message)s",
    )
    args = arg_parser().parse_args(namespace=CLIArgs())

    # A set drops duplicates (several ABI-tagged builds of one module share a
    # stem); sorting makes the processing order deterministic.
    module_names = sorted({fid.partition(".")[0] for fid in glob.glob("*.so")})

    if not module_names:
        logging.getLogger(__name__).warning(
            "No '*.so' extension modules found in the current directory"
        )
        return

    for module_name in module_names:
        print("Processing: " + module_name)

        # A fresh parser, printer and writer are built per module; run()
        # finalizes the parser after handling a single module.
        parser = stub_parser_from_args(args)
        printer = Printer(
            invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is
        )

        out_dir, sub_dir = to_output_and_subdir(
            output_dir=args.output_dir,
            module_name=module_name,
            root_suffix=args.root_suffix,
        )

        run(
            parser,
            printer,
            module_name,
            out_dir,
            sub_dir=sub_dir,
            dry_run=args.dry_run,
            writer=Writer(stub_ext=args.stub_extension),
        )
|
|
|
|
|
|
def to_output_and_subdir(
|
|
output_dir: str, module_name: str, root_suffix: str | None
|
|
) -> tuple[Path, Path | None]:
|
|
out_dir = Path(output_dir)
|
|
|
|
module_path = module_name.split(".")
|
|
|
|
if root_suffix is None:
|
|
return out_dir.joinpath(*module_path[:-1]), None
|
|
else:
|
|
module_path = [f"{module_path[0]}{root_suffix}", *module_path[1:]]
|
|
if len(module_path) == 1:
|
|
sub_dir = Path(module_path[-1])
|
|
else:
|
|
sub_dir = None
|
|
return out_dir.joinpath(*module_path[:-1]), sub_dir
|
|
|
|
|
|
def run(
    parser: IParser,
    printer: Printer,
    module_name: str,
    out_dir: Path,
    sub_dir: Path | None,
    dry_run: bool,
    writer: Writer,
):
    """Import *module_name*, parse it, and write its stubs unless *dry_run*.

    Args:
        parser: Parser that handles the imported module.
        printer: Printer handed to the writer.
        module_name: Dotted name of the module to import and parse.
        out_dir: Directory the writer writes into; created when missing.
        sub_dir: Optional sub-directory forwarded to the writer.
        dry_run: When true, parse only and write nothing.
        writer: Writer that emits the stub files.

    Raises:
        RuntimeError: If the parser returns ``None`` for the module.
    """
    qualified_name = QualifiedName.from_str(module_name)
    imported_module = importlib.import_module(module_name)

    parsed = parser.handle_module(qualified_name, imported_module)
    # finalize() runs before the None check, i.e. also when parsing failed.
    parser.finalize()

    if module_is_missing := parsed is None:
        raise RuntimeError(f"Can't parse {module_name}")

    if not dry_run:
        out_dir.mkdir(parents=True, exist_ok=True)
        writer.write_module(parsed, printer, to=out_dir, sub_dir=sub_dir)
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()
|