diff --git a/kubernetes/base/leaderelection/leaderelection.py b/kubernetes/base/leaderelection/leaderelection.py index 4fff8d0790..ea20f0a570 100644 --- a/kubernetes/base/leaderelection/leaderelection.py +++ b/kubernetes/base/leaderelection/leaderelection.py @@ -55,8 +55,7 @@ def run(self): logger.info("{} successfully acquired lease".format(self.election_config.lock.identity)) # Start leading and call OnStartedLeading() - threading.daemon = True - threading.Thread(target=self.election_config.onstarted_leading).start() + threading.Thread(target=self.election_config.onstarted_leading, daemon=True).start() self.renew_loop() diff --git a/kubernetes/base/leaderelection/leaderelection_test.py b/kubernetes/base/leaderelection/leaderelection_test.py index 9fb6d9bcf4..0cbcf00c88 100644 --- a/kubernetes/base/leaderelection/leaderelection_test.py +++ b/kubernetes/base/leaderelection/leaderelection_test.py @@ -22,6 +22,7 @@ import json import time import pytest +from unittest.mock import patch thread_lock = threading.RLock() @@ -194,6 +195,33 @@ def on_stopped_leading(): self.assert_history(leadership_history, ["get leadership", "start leading", "stop leading"]) + def test_onstarted_leading_runs_in_daemon_thread(self): + captured = {} + real_thread = threading.Thread + + def record_thread(*args, **kwargs): + thread = real_thread(*args, **kwargs) + captured["daemon"] = thread.daemon + return thread + + started = threading.Event() + + mock_lock = MockResourceLock("mock", "mock_namespace", "mock", thread_lock, + lambda: None, lambda: None, lambda: None, None) + mock_lock.renew_count_max = 1 + + config = electionconfig.Config(lock=mock_lock, lease_duration=2, + renew_deadline=1.5, retry_period=1, + onstarted_leading=started.set, + onstopped_leading=lambda: None) + + with patch.object(leaderelection.threading, "Thread", new=record_thread): + leaderelection.LeaderElection(config).run() + + self.assertTrue(started.wait(1), "onstarted_leading callback did not run") + self.assertIn("daemon", captured) + self.assertTrue(captured["daemon"]) + def assert_history(self, history, expected): self.assertIsNotNone(expected) self.assertIsNotNone(history)