From d9714e73be5cfd4d6d4394c583166df5dc3566eb Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 20 May 2026 16:34:36 +0800 Subject: [PATCH] feat(torch): expose optional codegen parameters --- scripts/generate_torch_ops.py | 447 ++++++++++++++++++++++++++++--- scripts/generate_wrappers.py | 48 +++- src/hash.h | 9 + tests/test_generate_torch_ops.py | 109 ++++++++ tests/test_torch_ops.py | 66 ++++- 5 files changed, 633 insertions(+), 46 deletions(-) diff --git a/scripts/generate_torch_ops.py b/scripts/generate_torch_ops.py index bcbe64f0f..59cf8a680 100644 --- a/scripts/generate_torch_ops.py +++ b/scripts/generate_torch_ops.py @@ -101,12 +101,11 @@ # The int-dim overload is always emitted alongside, so we lose nothing # user-visible. -# Optional ATen types we hide from the user-facing API and pass as a -# typed empty optional at the call site. Covers the common "full -# default" case for most reductions and activations. We use a typed -# `c10::optional{}` rather than bare `at::nullopt` so the compiler -# can disambiguate ops with multiple `_out` overloads (e.g. `clamp_out` -# accepts both `optional` and `optional` for `min`/`max`). +# Optional ATen types and their typed empty optional values at the call site. +# We use typed `c10::optional{}` values rather than bare `at::nullopt` so +# the compiler can disambiguate ops with multiple `_out` overloads (e.g. +# `clamp_out` accepts both `optional` and `optional` for +# `min`/`max`). _NULLOPT_BY_TYPE = { "Scalar?": "c10::optional{}", "int?": "c10::optional{}", @@ -131,7 +130,42 @@ "SymInt[3]?": "c10::optional{}", "float[]?": "c10::optional>{}", } -_HARDCODE_NULLOPT_TYPES = frozenset(_NULLOPT_BY_TYPE) + +# Optional ATen types that have a stable InfiniOps representation. These are +# exposed in generated base signatures and converted back to ATen optionals in +# the PyTorch backend. PyTorch-internal concepts without a public InfiniOps +# analogue stay hidden and are still passed as typed empty optionals. +_EXPOSED_OPTIONAL_CPP_TYPES = { + "Scalar?": "std::optional", + "int?": "std::optional", + "bool?": "std::optional", + "float?": "std::optional", + "str?": "std::optional", + "ScalarType?": "std::optional", + "Tensor?": "std::optional", + "int[]?": "std::optional>", + "int[1]?": "std::optional>", + "int[2]?": "std::optional>", + "int[3]?": "std::optional>", + "SymInt?": "std::optional", + "SymInt[]?": "std::optional>", + "SymInt[1]?": "std::optional>", + "SymInt[2]?": "std::optional>", + "SymInt[3]?": "std::optional>", + "float[]?": "std::optional>", +} +_HARDCODE_NULLOPT_TYPES = frozenset( + set(_NULLOPT_BY_TYPE) - set(_EXPOSED_OPTIONAL_CPP_TYPES) +) + + +def _normalize_cpp_type(cpp_type: str) -> str: + text = cpp_type.strip() + + if text.startswith("const "): + text = text[len("const ") :] + + return text.rstrip("&").strip() @dataclasses.dataclass @@ -140,14 +174,26 @@ class Param: aten_type: str default: str | None keyword_only: bool + cpp_type_override: str | None = None @property def is_tensor(self) -> bool: - # Real tensors only. `Tensor?` is optional and falls through to - # the hidden-param path (substituted with `at::nullopt`). + if self.cpp_type_override is not None: + return _normalize_cpp_type(self.cpp_type_override) == "Tensor" + + # Real tensors only. `Tensor?` is optional and has separate handling. return self.aten_type == "Tensor" or self.aten_type.startswith("Tensor(") + @property + def is_optional_tensor(self) -> bool: + if self.cpp_type_override is not None: + return ( + _normalize_cpp_type(self.cpp_type_override) == "std::optional" + ) + + return self.aten_type == "Tensor?" + @property def is_mutable_tensor(self) -> bool: # Mutable tensors carry `!` in their alias annotation, e.g. @@ -183,11 +229,11 @@ def is_hidden(self) -> bool: \\`int n\\` on the special chebyshev family, etc. — as missing semantic controls. They are now exposed and forwarded to ATen. - Optional ATen types (\\`Tensor?\\`, \\`Scalar?\\`, \\`int?\\`, …) remain - hidden for now — exposing them would require teaching the torch - source to thread \\`std::optional\\` through to ATen, which is a - separate refactor. The same goes for ATen-internal types like - \\`Generator?\\`/\\`Layout?\\` that have no InfiniOps analogue. + Optional ATen types with a stable InfiniOps representation + (\\`Tensor?\\`, \\`Scalar?\\`, \\`int?\\`, …) are exposed as + \\`std::optional\\`. ATen-internal types like + \\`Generator?\\`/\\`Layout?\\` still have no InfiniOps analogue, so + they remain hidden. """ return self.is_hardcoded_nullopt @@ -233,6 +279,9 @@ def hidden_value(self) -> str: @property def cpp_type(self) -> str: + if self.cpp_type_override is not None: + return self.cpp_type_override + if self.is_tensor: # `Tensor[]` / `Tensor(a!)[]` would need `std::vector` and a # different ATen call shape — not yet supported, so reject so the @@ -246,6 +295,9 @@ def cpp_type(self) -> str: return "Tensor" + if self.aten_type in _EXPOSED_OPTIONAL_CPP_TYPES: + return _EXPOSED_OPTIONAL_CPP_TYPES[self.aten_type] + if self.is_hidden: # Not exposed — the ATen call substitutes a hardcoded value # so the `cpp_type` is irrelevant. @@ -272,6 +324,8 @@ class Op: aten_name: str overload: str params: list[Param] + signature_params: list[Param] | None = None + param_bindings: list[Param | None] | None = None @property def pascal_name(self) -> str: @@ -330,8 +384,24 @@ def visible_params(self) -> list[Param]: """Params the wrapper exposes to the user; hidden ones (hardcoded optional nullopt, default-`False`/`True` bools) are filtered.""" + if self.signature_params is not None: + return self.signature_params + return [p for p in self.params if not p.is_hidden] + def api_param_for(self, schema_index: int) -> Param | None: + """Return the public signature param bound to a schema param.""" + + if self.param_bindings is not None: + return self.param_bindings[schema_index] + + param = self.params[schema_index] + + if param.is_hidden: + return None + + return param + @property def is_testable(self) -> bool: """Cheap structural check: at least one out tensor, and the first @@ -453,6 +523,161 @@ def _parse_one_arg(token: str, keyword_only: bool) -> Param: ) +_BASE_OPERATOR_RE = re.compile( + r"\bvirtual\s+void\s+operator\(\)\s*\((?P.*?)\)\s*const" + r"(?:\s*=\s*0|\s*\{)", + re.S, +) + + +def _split_cpp_params(args_str: str) -> list[str]: + parts: list[str] = [] + depth = 0 + current: list[str] = [] + + for ch in args_str: + if ch in "([<": + depth += 1 + current.append(ch) + elif ch in ")]>": + depth -= 1 + current.append(ch) + elif ch == "," and depth == 0: + piece = "".join(current).strip() + + if piece: + parts.append(piece) + + current = [] + else: + current.append(ch) + + tail = "".join(current).strip() + + if tail: + parts.append(tail) + + return parts + + +def _parse_base_signature_param(token: str) -> tuple[str, str]: + token = token.split("=", 1)[0].strip() + token = re.sub(r"\s+", " ", token) + match = re.match(r"(?P.+?)\s+(?P\w+)$", token) + + if not match: + raise ValueError(f"could not parse base operator param: {token!r}") + + return _normalize_cpp_type(match.group("type")), match.group("name") + + +def _parse_base_operator_signatures( + base_path: pathlib.Path, +) -> list[list[tuple[str, str]]]: + source = base_path.read_text() + signatures = [] + + for match in _BASE_OPERATOR_RE.finditer(source): + args = match.group("args").strip() + + if not args: + signatures.append([]) + continue + + signatures.append( + [_parse_base_signature_param(token) for token in _split_cpp_params(args)] + ) + + return signatures + + +def _cpp_types_compatible(schema_cpp_type: str, base_cpp_type: str) -> bool: + schema = _normalize_cpp_type(schema_cpp_type) + base = _normalize_cpp_type(base_cpp_type) + + if schema == base: + return True + + compatible_pairs = { + ("double", "float"), + ("int64_t", "int"), + ("std::optional", "std::optional"), + ("std::optional", "std::optional"), + } + + return (schema, base) in compatible_pairs + + +def _bind_base_signature(op: Op, signature: list[tuple[str, str]]) -> Op | None: + bindings: list[Param | None] = [None] * len(op.params) + signature_params = [] + schema_index = 0 + + for base_cpp_type, base_name in signature: + matched_index = None + + for index in range(schema_index, len(op.params)): + schema_param = op.params[index] + + if _cpp_types_compatible(schema_param.cpp_type, base_cpp_type): + matched_index = index + break + + if not _is_omittable_param(schema_param): + return None + + if matched_index is None: + return None + + schema_param = op.params[matched_index] + api_param = dataclasses.replace( + schema_param, name=base_name, cpp_type_override=base_cpp_type + ) + bindings[matched_index] = api_param + signature_params.append(api_param) + schema_index = matched_index + 1 + + if any(not _is_omittable_param(param) for param in op.params[schema_index:]): + return None + + return dataclasses.replace( + op, signature_params=signature_params, param_bindings=bindings + ) + + +def _bind_existing_base_overloads( + name: str, ops: list[Op] +) -> tuple[list[Op], list[str]]: + base_path = _base_path(name) + signatures = _parse_base_operator_signatures(base_path) + bound_ops = [] + warnings = [] + + for signature in signatures: + matches = [ + bound + for op in ops + if (bound := _bind_base_signature(op, signature)) is not None + ] + + if not matches: + rendered = ", ".join(f"{cpp_type} {param}" for cpp_type, param in signature) + warnings.append( + f"base overload `operator()({rendered})` does not match any " + "usable ATen schema" + ) + continue + + matches.sort( + key=lambda match: sum( + binding is None for binding in match.param_bindings or () + ) + ) + bound_ops.append(matches[0]) + + return bound_ops, warnings + + def _snake_to_pascal(s: str) -> str: return "".join(p.capitalize() for p in s.split("_")) @@ -483,6 +708,17 @@ def _base_path(op_name: str) -> pathlib.Path: return _BASE_DIR / f"{op_name}.h" +def _is_omittable_param(param: Param) -> bool: + return param.aten_type in _NULLOPT_BY_TYPE or param.default is not None + + +def _default_call_value(param: Param) -> str: + if param.aten_type in _NULLOPT_BY_TYPE: + return _NULLOPT_BY_TYPE[param.aten_type] + + return param.hidden_value() + + def _load_aten_yaml(version: str) -> str: """Return the `native_functions.yaml` bundled with installed `torchgen`. @@ -660,12 +896,32 @@ def _translate_default(param: Param) -> str: return raw # numeric literals (`0`, `1`, `1.0`) pass through +def _generate_base_includes(ops: list[Op]) -> str: + cpp_types = [param.cpp_type for op in ops for param in op.visible_params] + includes = [] + + if any("std::optional" in cpp_type for cpp_type in cpp_types): + includes.append("#include ") + + if any("std::string" in cpp_type for cpp_type in cpp_types): + includes.append("#include ") + + if any("std::vector" in cpp_type for cpp_type in cpp_types): + includes.append("#include ") + + includes.append('#include "operator.h"') + + return "\n".join(includes) + + def _generate_base_header(name: str, ops: list[Op]) -> str: pascal = _snake_to_pascal(name) member_decls = [] tensor_member_order = [] seen_tensor_members = set() + optional_tensor_member_order = [] + seen_optional_tensor_members = set() scalar_member_order = [] scalar_member_types = {} @@ -680,6 +936,20 @@ def _generate_base_header(name: str, ops: list[Op]) -> str: member_decls.append(f" Tensor::Strides {param.name}_strides_;") member_decls.append(f" DataType {param.name}_type_;") + for param in op.visible_params: + if ( + not param.is_optional_tensor + or param.name in seen_optional_tensor_members + ): + continue + + seen_optional_tensor_members.add(param.name) + optional_tensor_member_order.append(param.name) + member_decls.append(f" bool has_{param.name}_{{false}};") + member_decls.append(f" Tensor::Shape {param.name}_shape_;") + member_decls.append(f" Tensor::Strides {param.name}_strides_;") + member_decls.append(f" DataType {param.name}_type_{{DataType::kFloat32}};") + # Visible non-tensor params (scalars, strings, vectors) are also # stored on the base so backends can dispatch on them later — not # only at the moment `operator()` is invoked. Reviewers flagged @@ -689,7 +959,11 @@ def _generate_base_header(name: str, ops: list[Op]) -> str: # overload's member is dropped (later constructors leave it # default-initialised). for param in op.visible_params: - if param.is_tensor or param.name in scalar_member_types: + if ( + param.is_tensor + or param.is_optional_tensor + or param.name in scalar_member_types + ): continue scalar_member_order.append(param.name) @@ -704,10 +978,14 @@ def _generate_base_header(name: str, ops: list[Op]) -> str: for op in ops: init_pieces = [] tensor_params = {param.name: param for param in op.tensor_params} + optional_tensor_params = { + param.name: param for param in op.visible_params if param.is_optional_tensor + } scalar_params = { param.name: param for param in op.visible_params if not param.is_tensor + and not param.is_optional_tensor and scalar_member_types.get(param.name) == param.cpp_type } @@ -723,6 +1001,26 @@ def _generate_base_header(name: str, ops: list[Op]) -> str: ) init_pieces.append(f" {param.name}_type_{{{param.name}.dtype()}}") + for param_name in optional_tensor_member_order: + param = optional_tensor_params.get(param_name) + + if param is None: + continue + + init_pieces.append(f" has_{param.name}_{{{param.name}.has_value()}}") + init_pieces.append( + f" {param.name}_shape_{{{param.name} ? " + f"{param.name}->shape() : Tensor::Shape{{}}}}" + ) + init_pieces.append( + f" {param.name}_strides_{{{param.name} ? " + f"{param.name}->strides() : Tensor::Strides{{}}}}" + ) + init_pieces.append( + f" {param.name}_type_{{{param.name} ? " + f"{param.name}->dtype() : DataType::kFloat32}}" + ) + for param_name in scalar_member_order: param = scalar_params.get(param_name) @@ -746,6 +1044,7 @@ def _generate_base_header(name: str, ops: list[Op]) -> str: return _BASE_TEMPLATE.format( name_uc=name.upper(), pascal=pascal, + includes=_generate_base_includes(ops), constructors="\n\n".join(constructors), op_calls="\n\n".join(calls), member_decls="\n\n".join(member_decls), @@ -771,21 +1070,87 @@ def _generate_torch_method_source(name: str, op: Op) -> str: pascal = _snake_to_pascal(name) conversion_lines = [] - for param in op.tensor_params: + def _optional_aten_type(param: Param) -> str: + return _NULLOPT_BY_TYPE[param.aten_type].removesuffix("{}") + + def _optional_aten_value(schema_param: Param, api_param: Param) -> str: + if schema_param.aten_type == "Tensor?": + data_expr = f"const_cast({api_param.name}->data())" + + return ( + f"ToAtenTensor({data_expr}, {api_param.name}_shape_, " + f"{api_param.name}_strides_, {api_param.name}_type_, device_index_)" + ) + + if schema_param.aten_type == "Scalar?": + return f"at::Scalar(*{api_param.name})" + + if schema_param.aten_type == "ScalarType?": + return f"ToAtenDataType(*{api_param.name})" + + if schema_param.aten_type == "str?": + return f"c10::string_view(*{api_param.name})" + + if schema_param.aten_type.startswith( + ("int[", "SymInt[") + ) or schema_param.aten_type in { + "int[]?", + "SymInt[]?", + }: + return f"at::IntArrayRef(*{api_param.name})" + + if schema_param.aten_type == "float[]?": + return f"at::ArrayRef(*{api_param.name})" + + return f"*{api_param.name}" + + def _append_optional_conversion(schema_param: Param, api_param: Param) -> None: + optional_type = _optional_aten_type(schema_param) + conversion_lines.append(f" {optional_type} at_{schema_param.name};") + conversion_lines.append(f" if ({api_param.name}.has_value()) {{") + conversion_lines.append( + f" at_{schema_param.name} = " + f"{optional_type}{{{_optional_aten_value(schema_param, api_param)}}};" + ) + conversion_lines.append(" }") + + for schema_index, param in enumerate(op.params): + if not param.is_tensor: + continue + + api_param = op.api_param_for(schema_index) + + if api_param is None: + continue + data_expr = ( - f"{param.name}.data()" + f"{api_param.name}.data()" if param.is_mutable_tensor - else f"const_cast({param.name}.data())" + else f"const_cast({api_param.name}.data())" ) conversion_lines.append( f" auto at_{param.name} = ToAtenTensor(\n" - f" {data_expr}, {param.name}_shape_, {param.name}_strides_,\n" - f" {param.name}_type_, device_index_);" + f" {data_expr}, {api_param.name}_shape_, {api_param.name}_strides_,\n" + f" {api_param.name}_type_, device_index_);" ) - def _render_arg(p): + for schema_index, param in enumerate(op.params): + api_param = op.api_param_for(schema_index) + + if api_param is not None and param.aten_type in _EXPOSED_OPTIONAL_CPP_TYPES: + _append_optional_conversion(param, api_param) + + def _render_arg(schema_index, p): + api_param = op.api_param_for(schema_index) + + if api_param is None: + return _default_call_value(p) + if p.is_hidden: - return p.hidden_value() + return _default_call_value(p) + + if p.aten_type in _EXPOSED_OPTIONAL_CPP_TYPES: + return f"at_{p.name}" if p.is_tensor: return f"at_{p.name}" @@ -797,17 +1162,18 @@ def _render_arg(p): # unlike `_out` calls which place output tensors first. input_param = op.params[0] arg_order = op.params[1:] - aten_call = ( - f"at_{input_param.name}.{op.aten_name}" - f"({', '.join(_render_arg(p) for p in arg_order)})" + rendered_args = ", ".join( + _render_arg(index + 1, p) for index, p in enumerate(arg_order) ) + aten_call = f"at_{input_param.name}.{op.aten_name}({rendered_args})" else: # ATen `_out` form puts all out tensors first, then non-out params # in YAML order. Hardcoded-nullopt params become `at::nullopt`. - arg_order = op.out_params + [p for p in op.params if not p.is_out] - aten_call = ( - f"at::{op.aten_name}_out({', '.join(_render_arg(p) for p in arg_order)})" - ) + arg_order = [(index, p) for index, p in enumerate(op.params) if p.is_out] + [ + (index, p) for index, p in enumerate(op.params) if not p.is_out + ] + rendered_args = ", ".join(_render_arg(index, p) for index, p in arg_order) + aten_call = f"at::{op.aten_name}_out({rendered_args})" return _TORCH_METHOD_TEMPLATE.format( pascal=pascal, @@ -846,7 +1212,7 @@ def _generate_torch_source(name: str, ops: list[Op]) -> str: #ifndef INFINI_OPS_BASE_{name_uc}_H_ #define INFINI_OPS_BASE_{name_uc}_H_ -#include "operator.h" +{includes} namespace infini::ops {{ @@ -1052,6 +1418,24 @@ def main() -> int: ) ) + public_name = usable[0].infini_name + base_exists = _base_path(public_name).exists() + + if base_exists: + usable, base_warnings = _bind_existing_base_overloads(public_name, usable) + + for warning in base_warnings: + skipped.append((public_name, warning)) + + if not usable: + skipped.append( + ( + public_name, + "existing base has no overload compatible with ATen schema", + ) + ) + continue + # Emit one InfiniOps wrapper per ATen op. Distinct visible overloads # become overloaded constructors / `operator()` methods on the same # class (`Pow` exposes both tensor and scalar exponents). Overloads @@ -1062,8 +1446,7 @@ def main() -> int: # resolves through `src/` first). Signature mismatches surface as # compile errors with a clear message — drop the op from the YAML # to suppress. - public_name = usable[0].infini_name - _emit(public_name, usable, emit_base=not _base_path(public_name).exists()) + _emit(public_name, usable, emit_base=not base_exists) for op in usable: metadata.append( diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index f2ff37065..797c1a6a7 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -135,6 +135,24 @@ def _find_optional_tensor_params(op_name): return set(re.findall(r"std::optional\s+(\w+)", source)) +def _find_optional_non_tensor_params(op_name): + """Return parameter names declared as non-Tensor `std::optional`. + + Some generated ATen bases have overloads that reuse a parameter name across + different optional kinds, e.g. `clamp(..., std::optional min, ...)` + and `clamp(..., std::optional min, ...)`. The optional-Tensor + regex fallback is name-based, so record non-Tensor optionals too to avoid + treating the scalar overload as a Tensor overload. + """ + source = _find_base_header(op_name).read_text() + + return { + name + for cpp_type, name in re.findall(r"std::optional<([^>]+)>\s+(\w+)", source) + if "Tensor" not in cpp_type + } + + def _find_vector_tensor_params(op_name): """Return a set of parameter names declared as `std::vector` in the base header. @@ -162,14 +180,20 @@ def _find_vector_int64_params(op_name): def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) + optional_non_tensor_params = _find_optional_non_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) vector_int64_params = _find_vector_int64_params(operator.name) def _is_optional_tensor(arg): - if arg.spelling in optional_tensor_params: - return True + spelling = arg.type.spelling - return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + if "std::optional" in spelling: + return "Tensor" in spelling + + if arg.spelling in optional_non_tensor_params: + return False + + return arg.spelling in optional_tensor_params def _is_optional(arg): return "std::optional" in arg.type.spelling @@ -503,6 +527,22 @@ def _generate_params(node, call=False): arguments = (arguments[-1], *arguments[:-1]) + def _unwrap_std_optional(spelling): + prefix = "std::optional<" + + if not spelling.startswith(prefix): + return spelling + + inner = spelling[len(prefix) :] + + if inner.endswith(" >"): + return inner[:-2] + ">" + + if inner.endswith(">"): + return inner[:-1] + + return inner + def _handle_tensor(spelling): if call: return spelling.replace("Tensor", "void *") @@ -510,7 +550,7 @@ def _handle_tensor(spelling): return spelling.replace("Tensor", "infiniopTensorDescriptor_t") def _handle_std_optional(spelling): - return spelling.replace("std::optional<", "").replace(">", "") + return _unwrap_std_optional(spelling) return ", ".join( f"{_handle_std_optional(_handle_tensor(arg.type.spelling))} {arg.spelling}" diff --git a/src/hash.h b/src/hash.h index 4721f33f3..9c8f8b1ce 100644 --- a/src/hash.h +++ b/src/hash.h @@ -2,6 +2,7 @@ #define INFINI_OPS_HASH_H_ #include +#include #include template @@ -18,4 +19,12 @@ inline void HashCombine(std::size_t& seed, const std::vector& v) { } } +template +inline void HashCombine(std::size_t& seed, const std::optional& v) { + HashCombine(seed, v.has_value()); + if (v.has_value()) { + HashCombine(seed, *v); + } +} + #endif diff --git a/tests/test_generate_torch_ops.py b/tests/test_generate_torch_ops.py index 456535567..10301255d 100644 --- a/tests/test_generate_torch_ops.py +++ b/tests/test_generate_torch_ops.py @@ -31,3 +31,112 @@ def test_public_op_name_normalizes_aten_internal_and_inplace_names(): assert module._public_op_name("_softmax") == "aten_softmax" assert module._public_op_name("add_") == "add_inplace" assert module._public_op_name("_add_relu_") == "aten_add_relu_inplace" + + +def test_optional_tensor_params_are_exposed_and_forwarded_to_aten(): + module = _load_generator_module() + op = module._parse_func( + "batch_norm_elemt(Tensor input, Tensor? weight=None, " + "Tensor? bias=None, Tensor mean, Tensor invstd, float eps, " + "*, Tensor(a!) out) -> Tensor(a!)" + ) + + assert [param.cpp_type for param in op.visible_params] == [ + "Tensor", + "std::optional", + "std::optional", + "Tensor", + "Tensor", + "double", + "Tensor", + ] + + base = module._generate_base_header("batch_norm_elemt", [op]) + source = module._generate_torch_method_source("batch_norm_elemt", op) + + assert "#include " in base + assert "std::optional weight" in base + assert "std::optional bias" in base + assert "bool has_weight_" in base + assert "bool has_bias_" in base + assert "c10::optional at_weight" in source + assert "c10::optional at_bias" in source + assert "at::batch_norm_elemt_out" in source + assert "at_weight" in source + assert "at_bias" in source + + +def test_optional_scalar_and_array_params_are_exposed_and_forwarded_to_aten(): + module = _load_generator_module() + quantile = module._parse_func( + "quantile(Tensor input, Tensor q, int? dim=None, bool keepdim=False, " + "str interpolation='linear', *, Tensor(a!) out) -> Tensor(a!)" + ) + upsample = module._parse_func( + "upsample_bicubic2d(Tensor input, SymInt[2] output_size, " + "bool align_corners, float[]? scale_factors=None, " + "*, Tensor(a!) out) -> Tensor(a!)" + ) + + assert [param.cpp_type for param in quantile.visible_params] == [ + "Tensor", + "Tensor", + "std::optional", + "bool", + "std::string", + "Tensor", + ] + assert [param.cpp_type for param in upsample.visible_params] == [ + "Tensor", + "std::vector", + "bool", + "std::optional>", + "Tensor", + ] + + quantile_source = module._generate_torch_method_source("quantile", quantile) + upsample_source = module._generate_torch_method_source( + "upsample_bicubic2d", upsample + ) + + assert "c10::optional at_dim" in quantile_source + assert "at::quantile_out" in quantile_source + assert "at_dim" in quantile_source + assert "c10::optional> at_scale_factors" in upsample_source + assert "at::upsample_bicubic2d_out" in upsample_source + assert "at_scale_factors" in upsample_source + + +def test_existing_base_overload_can_omit_optional_schema_params(): + module = _load_generator_module() + op = module._parse_func( + "slow_conv3d(Tensor input, Tensor weight, int[3] kernel_size, " + "Tensor? bias=None, int[3] stride=1, int[3] padding=0, " + "*, Tensor(a!) out) -> Tensor(a!)" + ) + signature = [ + ("Tensor", "input"), + ("Tensor", "weight"), + ("std::vector", "kernel_size"), + ("std::vector", "stride"), + ("std::vector", "padding"), + ("Tensor", "out"), + ] + + bound = module._bind_base_signature(op, signature) + + assert bound is not None + assert [param.name for param in bound.visible_params] == [ + "input", + "weight", + "kernel_size", + "stride", + "padding", + "out", + ] + + source = module._generate_torch_method_source("slow_conv3d", bound) + + assert "std::optional bias" not in source + assert "c10::optional{}" in source + assert "at::slow_conv3d_out" in source diff --git a/tests/test_torch_ops.py b/tests/test_torch_ops.py index bc2a3337d..7fa19755b 100644 --- a/tests/test_torch_ops.py +++ b/tests/test_torch_ops.py @@ -140,6 +140,13 @@ _LIST_SIZE_RE = re.compile(r"\[(\d+)\]") +def _optional_inner_type(aten_type): + if aten_type.endswith("?"): + return aten_type[:-1] + + return None + + def _is_inplace_aten_name(name): """Return whether `name` is an ATen in-place operator name.""" @@ -203,28 +210,39 @@ def _list_default(aten_type): } ) +# Ops whose vendor `_out` path does not match the vendor functional reference +# for the generic random inputs generated by this harness. +_VENDOR_DIVERGENT_OPS = frozenset( + { + ("musa", "native_batch_norm"), + } +) + # Ops whose vendor kernel crashes the Python process, so they must be skipped # before calling into the InfiniOps/PyTorch slot. _VENDOR_CRASH_OPS = frozenset( { ("npu", "mish"), + ("npu", "mse_loss"), + ("npu", "nonzero"), + ("npu", "norm"), ("npu", "nuclear_norm"), + ("npu", "smooth_l1_loss"), + ("npu", "soft_margin_loss"), ("npu", "_linalg_svd"), ("npu", "svd"), } ) # Ops where the ATen `_out` schema and the Python reference (`torch.`, -# `torch.nn.functional.`) diverge in positional-argument ordering, so -# the harness's purely-positional reference call lands an InfiniOps -# argument on the wrong reference parameter. E.g. ATen +# `torch.nn.functional.`) diverge in positional-argument ordering or +# parameter representation, so the harness's purely-positional reference call +# lands an InfiniOps argument on the wrong reference parameter. E.g. ATen # `binary_cross_entropy_out(self, target, weight=None, reduction=Mean, out)` -# has `weight` between `target` and `reduction`; with `weight` hidden as -# `Tensor?`, our visible signature is `(self, target, reduction, out)`, -# but `torch.nn.functional.binary_cross_entropy(input, target, weight, -# reduction)` reads our `reduction:int` as `weight:Tensor` and crashes -# inside `weight.size()`. The InfiniOps wrapper itself is fine; only -# the harness's reference call is wrong. +# uses integer `reduction`, while `torch.nn.functional.binary_cross_entropy` +# has legacy `size_average` / `reduce` parameters before string `reduction`. +# The InfiniOps wrapper itself is fine; only the harness's reference call is +# wrong. _REFERENCE_SIGNATURE_MISMATCH_OPS = frozenset( { "binary_cross_entropy", @@ -358,6 +376,31 @@ def _build_input_value(op_name, param, shape, dtype, device, tensor_idx): return _SCALAR_VALUES[key] t = param["type"] + optional_inner = _optional_inner_type(t) + + if optional_inner == "Tensor": + per_op = _TENSOR_SHAPES.get(op_name) + tshape = per_op[tensor_idx] if per_op is not None else shape + + return randn_strided(tshape, None, dtype=dtype, device=device) + + if optional_inner == "ScalarType": + return None + + if optional_inner == "str": + return None + + if optional_inner in {"Scalar", "float"}: + return 0.5 + + if optional_inner in {"int", "SymInt", "bool"}: + return _TYPE_DEFAULTS[optional_inner] + + if t == "float[]?": + return None + + if t.startswith(("int[", "SymInt[")) or t in {"int[]?", "SymInt[]?"}: + return _list_default(t) if t.startswith(("int[", "SymInt[")) or t in {"int[]", "SymInt[]"}: return _list_default(t) @@ -444,6 +487,9 @@ def test_op(op_meta, shape, dtype, device, rtol, atol): if aten_name in _VENDOR_HANG_OPS: pytest.skip(f"`{aten_name}` hangs on at least one vendor kernel") + if (device, aten_name) in _VENDOR_DIVERGENT_OPS: + pytest.skip(f"`{aten_name}` diverges on `{device}` vendor kernel") + if (device, aten_name) in _VENDOR_CRASH_OPS: pytest.skip(f"`{aten_name}` crashes on `{device}` vendor kernel") @@ -468,7 +514,7 @@ def test_op(op_meta, shape, dtype, device, rtol, atol): _build_input_value(aten_name, p, shape, dtype, device, tensor_idx) ) - if p["is_tensor"]: + if p["is_tensor"] or p["type"] == "Tensor?": tensor_idx += 1 # Run the reference to discover output shape(s)/dtype(s).