Add support for specialization constants #2304
View rendered docs @ https://intelpython.github.io/dpctl/pulls/2304/index.html
also removes "v" as a permitted intermediate data type for specialization constants: composite specialization constants are broken into multiple specialization constants, so a struct would be passed as a single constant while the program expects several, and therefore does not work as intended
also adds spec_id, itemsize, and default_value fields
copy inputs to SpecializationConstant to prevent dangling pointers
also adds a test
| "Failed to allocate memory for specialization constants." | ||
| ) | ||
| for i, spconst in enumerate(specializations): | ||
| if not isinstance(spconst, SpecializationConstant): |
The loop variable is declared cdef SpecializationConstant spconst, so for i, spconst in enumerate(specializations) performs the type coercion at the loop assignment.
A non-SpecializationConstant element raises Cython's own TypeError before reaching the isinstance guard at line 555 — so the custom error message and the free(spconsts) cleanup never run, leaking the spconsts allocation.
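One way to avoid the leak is to validate every element before allocating, and to release the buffer in a finally block. The sketch below is plain Python rather than the PR's Cython; FakeSpecConst, TrackingAllocator, and pack_specializations are hypothetical stand-ins used only to show the ordering. In the Cython code, the analogous change would presumably be to iterate with an untyped loop variable and assign to the typed one only after the isinstance check.

```python
class FakeSpecConst:
    """Hypothetical stand-in for SpecializationConstant."""

    def __init__(self, spec_id):
        self.spec_id = spec_id


class TrackingAllocator:
    """Hypothetical allocator that counts live buffers, standing in for malloc/free."""

    def __init__(self):
        self.live = 0

    def malloc(self, n):
        self.live += 1
        return [None] * n

    def free(self, buf):
        self.live -= 1


def pack_specializations(specializations, allocator):
    items = list(specializations)
    # Validate first: nothing has been allocated yet, so raising here cannot leak.
    for i, item in enumerate(items):
        if not isinstance(item, FakeSpecConst):
            raise TypeError(
                f"Expected FakeSpecConst at index {i}, got {type(item).__name__}"
            )
    buf = allocator.malloc(len(items))
    try:
        for i, item in enumerate(items):
            buf[i] = item.spec_id
        return tuple(buf)
    finally:
        # Runs on both success and failure, so the buffer is always released.
        allocator.free(buf)
```

With this ordering the custom error message is always the one the caller sees, and the allocator's live count returns to zero on every path.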
| # add submodules | ||
| __all__ += [ | ||
| "utils", |
"utils" in __all__ is never imported
Tests pass only because they explicitly do from dpctl.program.utils import ...; the package still needs a from . import utils.
| if word_count == 0: | ||
| raise ValueError(f"Invalid SPIR-V instruction at word index {i}") | ||
| if i + word_count > len(words): |
SPIR-V parser raises uncaught IndexError / UnicodeDecodeError on malformed input.
The only bounds check is that the declared word_count fits the buffer; per-opcode operand reads (OpTypeInt needs i+3, OpDecorate needs i+3, OpName decodes UTF-8) are unchecked. A truncated instruction throws raw IndexError, and invalid UTF-8 in OpName throws UnicodeDecodeError — both contradicting the documented ValueError("Invalid SPIR-V binary") contract for a function whose explicit job is parsing untrusted binaries. Neither path is tested.
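A minimal sketch of what per-opcode operand checks could look like, written independently of the PR's parser. parse_spec_ids is a hypothetical name. The opcode and decoration numbers (OpName = 5, OpDecorate = 71, SpecId = 1) and the five-word header come from the SPIR-V specification. The sketch assumes a little-endian module and only collects SpecId decorations with their debug names.

```python
import struct

SPIRV_MAGIC = 0x07230203
OP_NAME = 5
OP_DECORATE = 71
DECORATION_SPEC_ID = 1


def parse_spec_ids(binary):
    """Return {spec_id: debug name or None}; raise ValueError on malformed input."""
    if len(binary) < 20 or len(binary) % 4 != 0:
        raise ValueError("Invalid SPIR-V binary")
    words = struct.unpack(f"<{len(binary) // 4}I", binary)
    if words[0] != SPIRV_MAGIC:
        raise ValueError("Invalid SPIR-V binary")

    names = {}
    spec_ids = {}
    i = 5  # skip the five-word header
    while i < len(words):
        word_count = words[i] >> 16
        opcode = words[i] & 0xFFFF
        if word_count == 0 or i + word_count > len(words):
            raise ValueError(f"Invalid SPIR-V instruction at word index {i}")
        # Slice the operands once; every later read is checked against len(operands).
        operands = words[i + 1 : i + word_count]
        if opcode == OP_DECORATE:
            if len(operands) < 2:
                raise ValueError(f"Truncated OpDecorate at word index {i}")
            if operands[1] == DECORATION_SPEC_ID:
                if len(operands) < 3:
                    raise ValueError(f"Truncated SpecId decoration at word index {i}")
                spec_ids[operands[0]] = operands[2]
        elif opcode == OP_NAME:
            if len(operands) < 2:
                raise ValueError(f"Truncated OpName at word index {i}")
            raw = struct.pack(f"<{len(operands) - 1}I", *operands[1:])
            raw = raw.split(b"\x00", 1)[0]  # literal strings are NUL-terminated
            try:
                names[operands[0]] = raw.decode("utf-8")
            except UnicodeDecodeError as exc:
                raise ValueError(f"Invalid UTF-8 in OpName at word index {i}") from exc
        i += word_count
    return {spec_id: names.get(target) for target, spec_id in spec_ids.items()}
```

Slicing the operands up front turns each operand read into a length check, so truncated instructions and bad UTF-8 both surface as ValueError. Small hand-built word sequences like these also make convenient test fixtures for the malformed paths.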
Should that be covered in the tests?
| } | ||
| ze_module_constants_t ZeSpecConstants = {}; | ||
| ZeSpecConstants.numConstants = 0; | ||
| ZeSpecConstants.numConstants = static_cast<std::uint32_t>(NumSpecConsts); |
numConstants set even when arrays are empty.
ZeSpecConstants.numConstants is set to NumSpecConsts unconditionally, while pConstantIds/pConstantValues become nullptr when the vectors are empty. If a caller passes NumSpecConsts > 0 with SpecConsts == nullptr, L0 gets a nonzero count with null arrays → UB. The OCL path silently ignores the same input, so the two backends diverge. (The Cython layer never produces this input, so this is a hardening/API-consistency issue.)
| f"{len(args) + 1}." | ||
| ) | ||
| self._spec_const.id = <uint32_t>spec_id |
The <uint32_t> cast wraps silently for out-of-range or negative IDs. SPIR-V SpecId is a 32-bit decoration; an out-of-range value should raise, not wrap.
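A small sketch of the range check, in plain Python. check_spec_id and wrap_like_c_cast are hypothetical helper names, not part of dpctl. The second helper only mimics what an unchecked cast to a 32-bit unsigned integer does, to show the value that would silently be used.

```python
import numbers

UINT32_MAX = 0xFFFFFFFF


def wrap_like_c_cast(value):
    """Mimic an unchecked cast to uint32_t: keep only the low 32 bits."""
    return value & UINT32_MAX


def check_spec_id(spec_id):
    """Return spec_id as an int if it fits in 32 unsigned bits, else raise."""
    if not isinstance(spec_id, numbers.Integral):
        raise TypeError(
            f"spec_id must be an integer, got {type(spec_id).__name__}"
        )
    value = int(spec_id)
    if not 0 <= value <= UINT32_MAX:
        raise ValueError(
            f"spec_id must be in [0, {UINT32_MAX}], got {value}"
        )
    return value
```

For example, wrap_like_c_cast(-1) yields 4294967295, which would address a completely different constant, whereas check_spec_id(-1) raises ValueError before any cast happens.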
| def __repr__(self): | ||
| return f"SpecializationConstant({self._spec_const.id})" | ||
| def __eq__(self, other): |
Do we need to add a definition for __hash__, since defining __eq__ without it might drop hashability?
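For reference, this is the plain-Python behaviour behind the question: a class that defines __eq__ without __hash__ gets __hash__ set to None. Both classes below are hypothetical stand-ins with made-up fields. Whether the Cython extension type behaves the same way is an assumption that should be checked against the compiled class.

```python
class ConstNoHash:
    """Hypothetical stand-in: defines __eq__ only."""

    def __init__(self, spec_id, data):
        self.spec_id = spec_id
        self.data = bytes(data)

    def __eq__(self, other):
        if not isinstance(other, ConstNoHash):
            return NotImplemented
        return (self.spec_id, self.data) == (other.spec_id, other.data)


class ConstHashable:
    """Hypothetical stand-in: __hash__ is built from the same fields as __eq__."""

    def __init__(self, spec_id, data):
        self.spec_id = spec_id
        self.data = bytes(data)

    def __eq__(self, other):
        if not isinstance(other, ConstHashable):
            return NotImplemented
        return (self.spec_id, self.data) == (other.spec_id, other.data)

    def __hash__(self):
        # Equal objects must hash equally, so hash exactly what __eq__ compares.
        return hash((self.spec_id, self.data))
```

Here hash(ConstNoHash(1, b"x")) raises TypeError, while two equal ConstHashable instances collapse to a single entry in a set. Hashing is only safe if the compared fields are never mutated after construction.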
| def __eq__(self, other): | ||
| if not isinstance(other, SpecializationConstant): | ||
| return False |
To unblock reflected comparison:
| return False | |
| return NotImplemented |
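A short plain-Python illustration of the difference, using hypothetical classes. When the left operand's __eq__ returns NotImplemented, Python falls back to the right operand's __eq__. Returning False ends the comparison immediately.

```python
class ReturnsFalse:
    def __eq__(self, other):
        if not isinstance(other, ReturnsFalse):
            return False  # the other operand never gets a say
        return True


class ReturnsNotImplemented:
    def __eq__(self, other):
        if not isinstance(other, ReturnsNotImplemented):
            return NotImplemented  # let Python try other.__eq__(self)
        return True


class Wrapper:
    """Hypothetical third-party type that knows how to compare with both."""

    def __eq__(self, other):
        return isinstance(other, (ReturnsFalse, ReturnsNotImplemented, Wrapper))


blocked = ReturnsFalse() == Wrapper()             # Wrapper.__eq__ is not consulted
reflected = ReturnsNotImplemented() == Wrapper()  # falls back to Wrapper.__eq__
```

If neither side handles the comparison, Python falls back to identity, so ReturnsNotImplemented() == 3 still evaluates to False.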
| isinstance(args[0], numbers.Integral) and | ||
| isinstance(args[1], numbers.Integral) | ||
| ): | ||
| target_obj = PyBytes_FromStringAndSize( |
No nbytes >= 0 check.
A negative Py_ssize_t passed to PyBytes_FromStringAndSize makes CPython raise an opaque SystemError; the unsafe-pointer path is inherently risky but should at least range-check nbytes up front.
| case backend::opencl: | ||
| KBRef = _CreateKernelBundleWithIL_ocl_impl(*SyclCtx, *SyclDev, IL, | ||
| length, CompileOpts); | ||
| length, CompileOpts, |
No try/catch across the C ABI boundary in the OpenCL/top-level path.
The new spec-constant build failures enlarge the set of inputs that can throw.
| const char *CompileOpts); | ||
| const char *CompileOpts, | ||
| size_t NumSpecConsts, | ||
| const DPCTLSpecConst *SpecConsts); |
Do we need to add __dpctl_keep?
This PR introduces support for specialization constants in dpctl, including a Cython class SpecializationConstant for constructing constants and support for passing instances of it to create_kernel_bundle_from_spirv via a new specializations keyword argument.
The SpecializationConstant class supports multiple constructors, including from Python buffers, a dtype string and a Python buffer (casting to the dtype via NumPy), and a number of bytes and a pointer as integers.
Also introduces dpctl.program.utils with a parse_spirv_specializations utility function, allowing the user to query a SPIR-V binary directly from Python.