Skip to content

feat(torch): expose optional codegen parameters#619

Draft
voltjia wants to merge 1 commit into
masterfrom
feat/torch-codegen-optional-overloads
Draft

feat(torch): expose optional codegen parameters#619
voltjia wants to merge 1 commit into
masterfrom
feat/torch-codegen-optional-overloads

Conversation

@voltjia
Copy link
Copy Markdown
Collaborator

@voltjia voltjia commented May 20, 2026

Summary

  • Expose ATen optional parameters with stable InfiniOps representations in generated PyTorch operator bases, including Tensor?, scalar optionals, optional dtype, optional strings, and optional integer/float lists.
  • Bind generated PyTorch backends to existing src/base/<op>.h overloads when a base already exists; omitted ATen optional/default parameters are forwarded as typed defaults, while incompatible overloads are reported and skipped.
  • Add std::optional<T> support to operator cache hashing and update generated torch-op tests to construct optional arguments correctly.
  • Add generator tests covering optional parameter exposure and existing-base overload binding.

Motivation

The PyTorch code generator previously hid optional ATen schema parameters and always forwarded typed nullopt values. That made generated APIs unable to exercise non-default optional behavior and caused drift against operator base headers that intentionally expose optional parameters. This PR makes optional schema handling explicit while keeping existing hand-written bases as the public API source of truth when they are present.

Closes # N/A — this is follow-up work from the PyTorch codegen/base drift discussion.

Type of Change

  • feat — new feature / new operator / new platform
  • N/A — fix — bug fix.
  • N/A — perf — performance improvement (no behavioral change).
  • N/A — refactor — code restructuring without behavior change.
  • N/A — test — adding or fixing tests only.
  • N/A — docs — documentation only.
  • N/A — build / ci — build system or CI configuration.
  • N/A — chore — tooling, formatting, or other non-code changes.
  • N/A — Breaking change; this PR is draft pending review of the generated public API impact.

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • N/A — Build system / CMake / CI; no CMake or CI files are changed.
  • Python bindings / user-facing API

Test Results on Supported Platforms

Platform Built pytest Result Notes / Hardware
NVIDIA Partial tests/test_generate_torch_ops.py: 5 passed; CPU + PyTorch build succeeded Container-only validation for generator, generated wrappers, generated dispatch, and WITH_TORCH; full GPU pytest pending.
Iluvatar N/A N/A Full-platform validation pending; draft PR opened for code review first.
MetaX N/A N/A Full-platform validation pending; draft PR opened for code review first.
Cambricon N/A N/A Full-platform validation pending; draft PR opened for code review first.
Moore N/A N/A Full-platform validation pending; draft PR opened for code review first.
Ascend N/A N/A Full-platform validation pending; draft PR opened for code review first.
Full `pytest` output (optional)
python -m pytest tests/test_generate_torch_ops.py -q
Running 5 items in this shard
.....                                                                    [100%]
5 passed in 0.04s
python scripts/generate_torch_ops.py
generated 625 overloads across 507 ops
python -m pip install --no-build-isolation .[dev] \
  -C cmake.define.WITH_CPU=ON \
  -C cmake.define.WITH_TORCH=ON
Successfully built InfiniOps
Successfully installed InfiniOps-0.1.0

Benchmark / Performance Impact

N/A — this PR changes generated API/backend plumbing and tests. No runtime performance benchmark was run.

Notes for Reviewers

  • Existing src/base/<op>.h overloads are treated as the public API when present. The generator binds compatible overloads to ATen schema parameters and fills omitted optional/default schema parameters at the ATen call site.
  • Generated fresh bases now expose supported optional types as std::optional<...>. PyTorch-internal optional types without stable InfiniOps representations remain hidden and are forwarded as typed empty optionals.
  • A full codegen pass currently reports 625 overloads across 507 ops. The generated metadata exposes 208 optional parameters across 113 ops.
  • This is intentionally a draft until full-platform, full-test validation is completed.

Checklist

Title, Branch, and Commits

  • PR title follows Conventional Commits (e.g. feat(nvidia): …, fix(cuda/gemm): …).
  • Branch name follows <type>/xxx-yyyy-zzzz where <type> matches the PR title's Conventional Commits type and words are joined with hyphens (see CONTRIBUTING.md §Branches).
  • Each commit message follows Conventional Commits.
  • Small PR is a single squashable commit; or, for a large PR, every commit is meaningful, well-formed, and independently reviewable (see CONTRIBUTING.md §Pull Requests).
  • No stray merge commits from master — the branch is rebased cleanly on top of the current master.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal — nothing unrelated to the stated motivation was added (CONTRIBUTING.md §Code/General).
  • No dead code, commented-out blocks, debug prints, printf/std::cout/print(...) left behind, or TODO without an owner and issue link.
  • No unrelated formatting churn that would obscure the diff.
  • Public API changes (if any) are intentional, documented, and reflected in affected callers/tests.

General Code Hygiene (applies to all languages)

  • The code is self-explanatory; comments were added only where the why is non-obvious (CONTRIBUTING.md §Code/General).
  • Every modified or added file ends with a single trailing newline (CONTRIBUTING.md §Code/General).
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • Identifiers in comments and error messages are wrapped in backticks (e.g. the `seqlens_k` tensor) (CONTRIBUTING.md §Code/General).
  • All comments and error messages are in English (CONTRIBUTING.md §Code/General).
  • Comments and error messages are complete sentences — capitalized first letter, terminal punctuation — unless the language/framework convention says otherwise (CONTRIBUTING.md §Code/General; §Python).

C++ Specific (if C++ files changed)

  • Code follows the Google C++ Style Guide strictly.
  • N/A — clang-format version 21 was not available in the current validation container; the touched C++ header was formatted with the available formatter and CI should enforce the project version.
  • N/A — clang-tidy was not run locally for this draft PR; no kernel or algorithm implementation path is added.
  • Operator parameter order is inputs first, outputs last; attributes are between inputs and outputs; naming follows PyTorch → ONNX → CUDA API precedence (CONTRIBUTING.md §C++).
  • No exceptions are thrown. No new C++ error path was added.
  • N/A — No new C++ error or warning message was added.
  • N/A — No kernel files are added or renamed.
  • N/A — No kernel launcher files are added or changed.
  • Constructor initializer list order matches member declaration order (CONTRIBUTING.md §C++).
  • Exactly one blank line between classes, between classes and functions, and between functions (CONTRIBUTING.md §C++).
  • Exactly one blank line between members (functions and variables) within a class (CONTRIBUTING.md §C++).
  • Exactly one blank line before and after the contents of a namespace (CONTRIBUTING.md §C++).
  • N/A — No new hand-written operator implementation is added under src/base/<op>.h or platform implementation directories.
  • No raw new/delete; RAII / smart pointers / existing allocators are used.

Python Specific (if Python files changed)

  • Code is PEP 8 compliant; ruff check passes cleanly on CI (see .github/workflows/ruff.yml).
  • ruff format --check passes cleanly — if not, run ruff format and commit the result.
  • Comments are complete English sentences, starting with a capital letter and ending with punctuation; Markdown backticks are used for code references (CONTRIBUTING.md §Python).
  • Framework-specific conventions (e.g. lowercase pytest.skip messages without terminal punctuation) are honored where applicable (CONTRIBUTING.md §Python).
  • No blank line between the function signature and the body when there is no docstring or comment (CONTRIBUTING.md §Python).
  • A blank line is present before and after if, for, and similar control-flow statements (CONTRIBUTING.md §Python).
  • A blank line appears before each return, except when it directly follows a control-flow statement like if or for (CONTRIBUTING.md §Python).
  • Docstrings (if any) follow PEP 257 conventions.
  • Type hints are added / kept consistent with the surrounding code.

Testing

  • N/A — Full-platform pytest is pending; this PR is draft and the current table records partial validation only.
  • N/A — No platform was unreachable; full-platform validation has not been started for this draft yet.
  • New functionality has matching tests under tests/ following tests/test_add.py / tests/test_gemm.py patterns (CONTRIBUTING.md §Adding an Operator).
  • Tests use pytest.mark.parametrize correctly: dependent parameters share one decorator (e.g. @pytest.mark.parametrize("dtype, rtol, atol", …)); independent parameters use separate decorators ordered by parameter declaration.
  • N/A — pytest.mark.auto_act_and_assert is not used by the generator unit tests or generated torch-op harness touched here.
  • Default dtype / device parameterization is relied on, or overridden with an explicit pytest.mark.parametrize when necessary.
  • N/A — No new test is known to be flaky under parallelism.
  • N/A — This is a feature PR rather than a bug-fix regression test PR.

Build, CI, and Tooling

  • The project builds cleanly from a fresh directory with pip install .[dev] on at least one affected platform.
  • compile_commands.json still regenerates through the existing CMake/scikit-build configuration path.
  • N/A — No new backend or device auto-detection is added.
  • Only one CUDA-like GPU backend is selectable at a time — the existing mutual-exclusion check in CMakeLists.txt is not changed.
  • ruff is green locally/in container; clang-format.yml is expected to enforce the project formatter version on CI.
  • No new runtime dependency was added without updating pyproject.toml's [project.optional-dependencies].

Documentation

  • N/A — No README, CONTRIBUTING, build flag, or developer workflow change is introduced.
  • N/A — No new operator, dispatch helper, or public utility is added outside generated code behavior.
  • N/A — No user-visible breaking change is intentionally introduced.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, or personal hardware identifiers have been committed.
  • N/A — No third-party code is added.
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks were introduced.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant