Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from collections import defaultdict | |
| from unittest.mock import MagicMock | |
| import pytest | |
| from megatron.core import parallel_state | |
| from torch import nn | |
| from nemo import lightning as nl | |
| from nemo.lightning import megatron_parallel as mp | |
class TestMegatronParallel:
    """Unit tests for the MegatronParallel class."""

    @pytest.fixture
    def mock_pipeline(self, mocker):
        """Fixture providing a minimal ``nn.Module`` whose forward returns its input."""

        class DummyModule(nn.Module):
            def __init__(self, dummy_arg=None):
                # nn.Module.__init__ must run before any attribute is assigned
                # on the module, so call it first.
                super().__init__()
                self.dummy_arg = dummy_arg

            def forward(self, x):
                return x

        return DummyModule()

    @pytest.fixture
    def mock_precision_plugin(self, mocker):
        """Fixture providing a bf16-mixed ``MegatronMixedPrecision`` plugin."""
        return nl.MegatronMixedPrecision(precision="bf16-mixed")

    @pytest.fixture
    def mock_callbacks(self, mocker):
        """Fixture providing a MagicMock spec'd as ``mp.CallbackConnector``."""
        return mocker.MagicMock(spec=mp.CallbackConnector)

    @pytest.fixture
    def mock_data_step(self, mocker):
        """Fixture providing a mock data step function."""
        return mocker.MagicMock()

    @pytest.fixture
    def mock_forward_step(self, mocker):
        """Fixture providing a mock forward step function."""
        return mocker.MagicMock()

    @pytest.fixture
    def mock_loss_reduction(self, mocker):
        """Fixture providing a mock loss reduction function."""
        return mocker.MagicMock()

    def test_init_with_defaults(self, mocker, mock_pipeline):
        """Test __init__ with default parameters."""
        # Patch parallel_state so construction does not need an initialized
        # model-parallel process group.
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1)
        mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False)

        megatron_parallel = mp.MegatronParallel(pipeline=mock_pipeline, cpu=True)

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin is None
        assert isinstance(megatron_parallel.callbacks, mp.CallbackConnector)
        assert megatron_parallel.data_step == mp.default_data_step
        assert megatron_parallel.forward_step == mp.default_forward_step
        assert megatron_parallel.loss_reduction is None

    def test_init_with_custom_parameters(
        self,
        mocker,
        mock_pipeline,
        mock_precision_plugin,
        mock_callbacks,
        mock_data_step,
        mock_forward_step,
        mock_loss_reduction,
    ):
        """Test __init__ with custom parameters."""
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1)
        mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False)

        megatron_parallel = mp.MegatronParallel(
            pipeline=mock_pipeline,
            precision_plugin=mock_precision_plugin,
            callbacks=mock_callbacks,
            data_step=mock_data_step,
            forward_step=mock_forward_step,
            loss_reduction=mock_loss_reduction,
            cpu=True,
        )

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin == mock_precision_plugin
        assert megatron_parallel.callbacks == mock_callbacks
        assert megatron_parallel.data_step == mock_data_step
        assert megatron_parallel.forward_step == mock_forward_step
        assert megatron_parallel.loss_reduction == mock_loss_reduction
class TestCallbackConnector:
    """Unit tests for ``mp.CallbackConnector`` registration and event dispatch."""

    def test_add_callbacks(self) -> None:
        connector = mp.CallbackConnector()
        cb = TestCallback()
        connector.add(cb)

        registry = connector.callbacks
        for hook in ("on_megatron_step_start", "on_megatron_microbatch_start"):
            assert cb in registry[hook]

    def test_event(self) -> None:
        connector = mp.CallbackConnector()
        cb = TestCallback()
        connector.add(cb)

        hooks = ("on_megatron_step_start", "on_megatron_microbatch_start")
        # Swap the hook methods for MagicMocks after registration (instead of
        # using mocker.spy) so the number of dispatches can be counted.
        for hook in hooks:
            setattr(cb, hook, MagicMock())
        for hook in hooks:
            connector.event(hook)
        for hook in hooks:
            assert getattr(cb, hook).call_count == 1

    def test_add_connector(self) -> None:
        first, second = mp.CallbackConnector(), mp.CallbackConnector()
        cb_a, cb_b = TestCallback(), TestCallback()
        first.add(cb_a)
        second.add(cb_b)

        first += second

        merged = first.callbacks["on_megatron_step_start"]
        assert cb_a in merged
        assert cb_b in merged

    def test_contains(self):
        connector = mp.CallbackConnector()
        cb = TestCallback()
        connector.add(cb)
        assert cb in connector

    def test_add_count_callback(self):
        """Test adding a CountCallback to the CallbackConnector."""
        connector = mp.CallbackConnector()
        counter = CountCallback()
        connector.add(counter)
        assert counter in connector, "CountCallback should be in the CallbackConnector"

    def test_event_trigger_with_count_callback(self):
        """Test if the event triggers the method in CountCallback."""
        connector = mp.CallbackConnector()
        counter = CountCallback()
        connector.add(counter)

        # Fire one event that CountCallback implements, then check its tally.
        connector.event('on_megatron_step_start')

        fired = counter.counts["on_megatron_step_start"]
        assert fired == 1, "CountCallback's method should have been triggered once"
class TestCallback:
    """Minimal callback defining two no-op Megatron hooks, used by the connector tests."""

    def on_megatron_step_start(self):
        """Do nothing; present so the hook name is defined on this callback."""
        return None

    def on_megatron_microbatch_start(self):
        """Do nothing; present so the hook name is defined on this callback."""
        return None
class CountCallback:
    """Callback that counts how many times each Megatron hook is invoked.

    Every hook accepts and ignores arbitrary positional and keyword arguments.
    Each call increments ``self.counts[<hook method name>]``, so the counter key
    always matches the name of the hook that fired.
    """

    def __init__(self) -> None:
        # defaultdict(int): a hook that never fired reads back as 0.
        self.counts = defaultdict(int)

    def on_megatron_step_start(self, *args, **kwargs) -> None:
        self.counts["on_megatron_step_start"] += 1

    def on_megatron_microbatch_start(self, *args, **kwargs) -> None:
        self.counts["on_megatron_microbatch_start"] += 1

    def on_megatron_microbatch_callback(self, *args, **kwargs) -> None:
        # Key previously read "on_megatron_microbatches_callback", which did not
        # match this hook's name.
        self.counts["on_megatron_microbatch_callback"] += 1

    def on_megatron_microbatch_end(self, *args, **kwargs) -> None:
        # Key previously read "on_megatron_microbatches_end", which did not
        # match this hook's name.
        self.counts["on_megatron_microbatch_end"] += 1

    def on_megatron_reduce_microbatches_start(self, *args, **kwargs) -> None:
        self.counts["on_megatron_reduce_microbatches_start"] += 1

    def on_megatron_reduce_microbatches_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_reduce_microbatches_end"] += 1

    def on_megatron_log_step_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_log_step_end"] += 1

    def on_megatron_step_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_step_end"] += 1