Add L2 score mod distributed attention shape #3147
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
"L2": [(4, 16, 4, 64)],
I think it should be the following (assuming you want this to run as an L1 test):
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
"L2": [],
}
What you have will run the same tests for both L1 and L2, thereby duplicating effort.
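The duplication described above can be checked mechanically. The following is a minimal sketch, assuming each level's list is looked up independently by its key; the helper name duplicated_shapes is hypothetical and is not part of the repository.

```python
# The table as originally proposed in this PR (L1 and L2 share one shape).
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
    "L0": [],
    "L1": [(4, 16, 4, 64)],
    "L2": [(4, 16, 4, 64)],
}


def duplicated_shapes(table):
    """Return {shape: [levels]} for every shape listed under more than one level.

    Hypothetical helper for illustration only.
    """
    levels_by_shape = {}
    for level, shapes in table.items():
        for shape in shapes:
            levels_by_shape.setdefault(shape, []).append(level)
    # Keep only the shapes that appear under two or more levels.
    return {s: lv for s, lv in levels_by_shape.items() if len(lv) > 1}


duplicates = duplicated_shapes(DISTRIBUTED_SCORE_MOD_DATA_SHAPES)
print(duplicates)
```

With the table above, the single shape is reported under both L1 and L2, which is the repeated work the comment points out.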
Please urgently launch a pipeline manually with a JAX build for the L0, L1, and L2 levels, and confirm that it runs successfully before merging.
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
286f9be to 6e9fb06
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
"L0": [],
"L1": [(4, 16, 4, 64)],
"L2": [],
L2 level resolves to zero test cases
"L2": [] is passed to pytest_parametrize_wrapper, which calls get_parameters_for_test_level and gets back the empty list. That list is forwarded directly to pytest.mark.parametrize("data_shape", []). With an empty parameter set, pytest skips the test by default, or marks it xfail or fails at collection, depending on the empty_parameter_set_mark ini option. In every case, when NVTE_JAX_UNITTEST_LEVEL=L2 is used in CI, no TestDistributedScoreModSelfAttn cases will execute. The PR description says this change "fixes L2 tests", but for that the L2 entry needs at least one concrete shape tuple, following the pattern used by DISTRIBUTED_SELF_ATTN_DATA_SHAPES, where L2 carries [(32, 512, 12, 64)].
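A small guard can make this visible before CI runs. The sketch below assumes the level lookup is a plain dictionary access by level name, as described above; the helpers levels_without_cases and require_cases are hypothetical and not part of the repository.

```python
# The table as it stands after the force-push (L2 is empty).
DISTRIBUTED_SCORE_MOD_DATA_SHAPES = {
    "L0": [],
    "L1": [(4, 16, 4, 64)],
    "L2": [],
}


def levels_without_cases(table):
    """List the levels whose parameter list is empty (hypothetical helper)."""
    return [level for level, shapes in table.items() if not shapes]


def require_cases(table, level):
    """Return the shapes for a level, raising if parametrization would be empty."""
    shapes = table[level]
    if not shapes:
        raise ValueError(f"{level} resolves to zero test cases")
    return shapes


empty_levels = levels_without_cases(DISTRIBUTED_SCORE_MOD_DATA_SHAPES)
print(empty_levels)
```

Here L0 is empty by intent, so a real check would probably apply require_cases only to the levels that are expected to exercise this test.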
Description
Add L2 score mod distributed attention shape