feat(torch): expose optional codegen parameters#619
Draft
voltjia wants to merge 1 commit into
Draft
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Tensor?, scalar optionals, optional dtype, optional strings, and optional integer/float lists.src/base/<op>.hoverloads when a base already exists; omitted ATen optional/default parameters are forwarded as typed defaults, while incompatible overloads are reported and skipped.std::optional<T>support to operator cache hashing and update generated torch-op tests to construct optional arguments correctly.Motivation
The PyTorch code generator previously hid optional ATen schema parameters and always forwarded typed
nulloptvalues. That made generated APIs unable to exercise non-default optional behavior and caused drift against operator base headers that intentionally expose optional parameters. This PR makes optional schema handling explicit while keeping existing hand-written bases as the public API source of truth when they are present.Closes # N/A — this is follow-up work from the PyTorch codegen/base drift discussion.
Type of Change
feat— new feature / new operator / new platformfix— bug fix.perf— performance improvement (no behavioral change).refactor— code restructuring without behavior change.test— adding or fixing tests only.docs— documentation only.build/ci— build system or CI configuration.chore— tooling, formatting, or other non-code changes.Platforms Affected
WITH_CPU)WITH_NVIDIA)WITH_ILUVATAR)WITH_METAX)WITH_CAMBRICON)WITH_MOORE)WITH_ASCEND)WITH_TORCH)Test Results on Supported Platforms
pytestResulttests/test_generate_torch_ops.py:5 passed; CPU + PyTorch build succeededWITH_TORCH; full GPU pytest pending.Full `pytest` output (optional)
Benchmark / Performance Impact
N/A — this PR changes generated API/backend plumbing and tests. No runtime performance benchmark was run.
Notes for Reviewers
src/base/<op>.hoverloads are treated as the public API when present. The generator binds compatible overloads to ATen schema parameters and fills omitted optional/default schema parameters at the ATen call site.std::optional<...>. PyTorch-internal optional types without stable InfiniOps representations remain hidden and are forwarded as typed empty optionals.625overloads across507ops. The generated metadata exposes208optional parameters across113ops.Checklist
Title, Branch, and Commits
feat(nvidia): …,fix(cuda/gemm): …).<type>/xxx-yyyy-zzzzwhere<type>matches the PR title's Conventional Commits type and words are joined with hyphens (seeCONTRIBUTING.md§Branches).CONTRIBUTING.md§Pull Requests).master— the branch is rebased cleanly on top of the currentmaster.fixup!/squash!/wipcommits remain.Scope and Design
CONTRIBUTING.md§Code/General).printf/std::cout/print(...)left behind, orTODOwithout an owner and issue link.General Code Hygiene (applies to all languages)
CONTRIBUTING.md§Code/General).CONTRIBUTING.md§Code/General).the `seqlens_k` tensor) (CONTRIBUTING.md§Code/General).CONTRIBUTING.md§Code/General).CONTRIBUTING.md§Code/General; §Python).C++ Specific (if C++ files changed)
clang-formatversion 21 was not available in the current validation container; the touched C++ header was formatted with the available formatter and CI should enforce the project version.clang-tidywas not run locally for this draft PR; no kernel or algorithm implementation path is added.CONTRIBUTING.md§C++).CONTRIBUTING.md§C++).CONTRIBUTING.md§C++).CONTRIBUTING.md§C++).CONTRIBUTING.md§C++).src/base/<op>.hor platform implementation directories.new/delete; RAII / smart pointers / existing allocators are used.Python Specific (if Python files changed)
ruff checkpasses cleanly on CI (see.github/workflows/ruff.yml).ruff format --checkpasses cleanly — if not, runruff formatand commit the result.CONTRIBUTING.md§Python).pytest.skipmessages without terminal punctuation) are honored where applicable (CONTRIBUTING.md§Python).CONTRIBUTING.md§Python).if,for, and similar control-flow statements (CONTRIBUTING.md§Python).return, except when it directly follows a control-flow statement likeiforfor(CONTRIBUTING.md§Python).Testing
pytestis pending; this PR is draft and the current table records partial validation only.tests/followingtests/test_add.py/tests/test_gemm.pypatterns (CONTRIBUTING.md§Adding an Operator).pytest.mark.parametrizecorrectly: dependent parameters share one decorator (e.g.@pytest.mark.parametrize("dtype, rtol, atol", …)); independent parameters use separate decorators ordered by parameter declaration.pytest.mark.auto_act_and_assertis not used by the generator unit tests or generated torch-op harness touched here.dtype/deviceparameterization is relied on, or overridden with an explicitpytest.mark.parametrizewhen necessary.Build, CI, and Tooling
pip install .[dev]on at least one affected platform.compile_commands.jsonstill regenerates through the existing CMake/scikit-build configuration path.CMakeLists.txtis not changed.ruffis green locally/in container;clang-format.ymlis expected to enforce the project formatter version on CI.pyproject.toml's[project.optional-dependencies].Documentation
Security and Safety