Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions experimental/air/cmd/runsubmit.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ type jobEnvironment struct {
}

// submitTask is the single task air submits: a native ai_runtime_task.
//
// max_retries / retry_on_timeout are intentionally omitted: on the ai_runtime_task
// path retries are driven by the AI Runtime service, not the Jobs task field, so
// setting it here has no effect (matches the Python CLI's native path).
type submitTask struct {
TaskKey string `json:"task_key"`
RunIf string `json:"run_if"`
AiRuntimeTask aiRuntimeTask `json:"ai_runtime_task"`
EnvironmentKey string `json:"environment_key"`
MaxRetries int `json:"max_retries"`
RetryOnTimeout bool `json:"retry_on_timeout,omitempty"`
}

// jobsSubmitRun is the Jobs runs/submit payload.
Expand Down Expand Up @@ -123,11 +125,7 @@ func buildSubmitPayload(cfg *runConfig, commandPath, dlImage string) jobsSubmitR
RunIf: "ALL_SUCCESS",
AiRuntimeTask: task,
EnvironmentKey: aiRuntimeEnvironmentKey,
MaxRetries: cfg.maxRetries(),
}
// max_retries 0 (no retries) is sent explicitly; retry_on_timeout only
// applies when retries are allowed.
st.RetryOnTimeout = st.MaxRetries > 0

return jobsSubmitRun{
RunName: cfg.ExperimentName,
Expand Down
22 changes: 4 additions & 18 deletions experimental/air/cmd/runsubmit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ func TestBuildSubmitPayload(t *testing.T) {
ExperimentName: "exp",
Command: new("python train.py"),
Compute: &computeConfig{AcceleratorType: "GPU_8xH100", NumAccelerators: 16},
MaxRetries: new(2),
TimeoutMinutes: new(30),
MLflowRunName: new("run-v2"),
MLflowExperimentDirectory: new("/Workspace/Users/me/exp"),
Expand All @@ -49,8 +48,6 @@ func TestBuildSubmitPayload(t *testing.T) {
assert.Equal(t, "exp", task.TaskKey)
assert.Equal(t, "ALL_SUCCESS", task.RunIf)
assert.Equal(t, aiRuntimeEnvironmentKey, task.EnvironmentKey)
assert.Equal(t, 2, task.MaxRetries)
assert.True(t, task.RetryOnTimeout)

at := task.AiRuntimeTask
assert.Equal(t, "exp", at.Experiment)
Expand All @@ -59,24 +56,13 @@ func TestBuildSubmitPayload(t *testing.T) {
require.Len(t, at.Deployments, 1)
assert.Equal(t, "/d/command.sh", at.Deployments[0].CommandPath)
assert.Equal(t, aiRuntimeCompute{AcceleratorType: "GPU_8xH100", AcceleratorCount: 16}, at.Deployments[0].Compute)
}

func TestBuildSubmitPayload_NoRetries(t *testing.T) {
cfg := &runConfig{
ExperimentName: "exp",
Command: new("x"),
Compute: &computeConfig{AcceleratorType: "GPU_1xH100", NumAccelerators: 1},
MaxRetries: new(0),
}

task := buildSubmitPayload(cfg, "/d/command.sh", "4").Tasks[0]
assert.Equal(t, 0, task.MaxRetries)
assert.False(t, task.RetryOnTimeout)

// max_retries: 0 must be sent, not omitted, so the server honors "no retries".
// max_retries / retry_on_timeout are not sent: the ai_runtime_task path does
// not honor them (retries are driven by the AI Runtime service).
b, err := json.Marshal(task)
require.NoError(t, err)
assert.Contains(t, string(b), `"max_retries":0`)
assert.NotContains(t, string(b), "max_retries")
assert.NotContains(t, string(b), "retry_on_timeout")
}

func TestSubmitToken(t *testing.T) {
Expand Down
Loading