Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import signal | |
| from unittest.mock import MagicMock, PropertyMock, patch | |
| import pytest | |
| import torch | |
| from lightning.pytorch import Trainer | |
| from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback | |
class TestPreemptionCallback:
    """Unit tests for ``PreemptionCallback``.

    Each test patches the callback's collaborators (the preemption-support
    probe, the signal-handler context, ``torch.distributed.broadcast`` and the
    CUDA queries, or the ``interrupted`` / ``preemption_supported`` properties)
    so no process group or GPU is required.
    """

    @pytest.fixture
    def callback(self):
        """Return a ``PreemptionCallback`` built with its default arguments."""
        return PreemptionCallback()

    @pytest.fixture
    def mock_trainer(self):
        """Return a ``Trainer``-spec'd mock whose ``should_stop`` starts as ``False``."""
        trainer = MagicMock(spec=Trainer)
        trainer.should_stop = False
        return trainer

    def test_init(self, callback):
        # Defaults: listens for SIGTERM, not interrupted, no handler context yet.
        assert callback.sig == signal.SIGTERM
        assert not callback._interrupted
        assert callback._handler_context is None

    def test_custom_signal(self):
        custom_callback = PreemptionCallback(sig=signal.SIGUSR1)
        assert custom_callback.sig == signal.SIGUSR1

    @pytest.mark.parametrize(
        "initially_supported, becomes_supported",
        [(True, True), (False, True), (False, False)],
    )
    def test_on_train_batch_start_distributed_init(
        self, callback, mock_trainer, initially_supported, becomes_supported
    ):
        """The handler is installed once if support is seen at train start or first batch."""
        with (
            patch.object(PreemptionCallback, '_check_preemption_support') as mock_check,
            patch.object(callback, '_preemption_handler') as mock_handler,
        ):
            # Successive support probes answer with these values, in order.
            mock_check.side_effect = [initially_supported, becomes_supported]
            callback.on_train_start(mock_trainer, None)
            callback.on_train_batch_start(mock_trainer, None, None, 0)
            if initially_supported or becomes_supported:
                mock_handler.assert_called_once_with()
            else:
                mock_handler.assert_not_called()

    @pytest.mark.parametrize(
        "is_supported, interrupted, expected",
        [
            (True, False, False),
            (True, True, True),
            # NOTE(review): these two rows assume ``interrupted`` reports False
            # whenever preemption is unsupported, regardless of the local flag.
            (False, False, False),
            (False, True, False),
        ],
    )
    def test_interrupted_property(self, callback, is_supported, interrupted, expected):
        # Build the real tensor *before* ``torch.tensor`` is patched; the
        # patched factory then hands back this pre-built value.
        fake_flag = torch.tensor(interrupted)
        with (
            patch.object(PreemptionCallback, '_check_preemption_support', return_value=is_supported),
            patch('torch.distributed.broadcast'),
            patch('torch.tensor', return_value=fake_flag),
            patch('torch.cuda.is_available', return_value=True),
            patch('torch.cuda.current_device', return_value=0),
        ):
            callback._interrupted = interrupted
            assert callback.interrupted == expected

    def test_on_train_start(self, callback, mock_trainer):
        with (
            patch.object(PreemptionCallback, 'preemption_supported', new_callable=PropertyMock) as mock_supported,
            patch.object(callback, '_preemption_handler') as mock_handler,
        ):
            # Supported: the handler is created at train start.
            mock_supported.return_value = True
            callback.on_train_start(mock_trainer, None)
            mock_handler.assert_called_once()
            mock_handler.reset_mock()

            # Not supported: no handler is created.
            mock_supported.return_value = False
            callback.on_train_start(mock_trainer, None)
            mock_handler.assert_not_called()

    def test_on_train_end(self, callback, mock_trainer):
        # A stored handler context must be exited with empty exception info.
        mock_context = MagicMock()
        callback._handler_context = mock_context
        callback.on_train_end(mock_trainer, None)
        mock_context.__exit__.assert_called_once_with(None, None, None)

    @pytest.mark.parametrize("interrupted", [True, False])
    def test_on_train_batch_end(self, callback, mock_trainer, interrupted):
        """An interrupted run sets ``should_stop`` and exits; otherwise training continues."""
        with patch.object(
            PreemptionCallback, 'interrupted', new_callable=PropertyMock, return_value=interrupted
        ):
            if interrupted:
                with pytest.raises(SystemExit):
                    callback.on_train_batch_end(mock_trainer, None, None, None, 0)
            else:
                callback.on_train_batch_end(mock_trainer, None, None, None, 0)
        assert mock_trainer.should_stop == interrupted