# Source: MagpieTTS_Internal_Demo/tests/lightning/test_megatron_parallel.py
# (uploaded by subhankarg via huggingface_hub; commit 0558aa4, verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from unittest.mock import MagicMock
import pytest
from megatron.core import parallel_state
from torch import nn
from nemo import lightning as nl
from nemo.lightning import megatron_parallel as mp
class TestMegatronParallel:
    """Unit tests for the MegatronParallel class."""

    @staticmethod
    def _patch_parallel_state(mocker):
        """Stub the Megatron parallel-state queries used during construction.

        The pipeline-parallel world size is reported as 1 and model parallelism
        as not initialized; presumably this lets MegatronParallel be built in a
        plain single-process CPU test without real process groups.
        """
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1)
        mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False)

    @pytest.fixture
    def mock_pipeline(self):
        """Fixture to create a minimal real ``nn.Module`` used as the pipeline."""

        class DummyModule(nn.Module):
            def __init__(self, dummy_arg=None):
                # nn.Module.__init__ must run before any attribute assignment;
                # assigning a Parameter or submodule beforehand raises.
                super().__init__()
                self.dummy_arg = dummy_arg

            def forward(self, x):
                return x

        return DummyModule()

    @pytest.fixture
    def mock_precision_plugin(self):
        """Fixture to create a bf16 mixed-precision plugin."""
        return nl.MegatronMixedPrecision(precision="bf16-mixed")

    @pytest.fixture
    def mock_callbacks(self, mocker):
        """Fixture to create a mock callback connector."""
        return mocker.MagicMock(spec=mp.CallbackConnector)

    @pytest.fixture
    def mock_data_step(self, mocker):
        """Fixture to create a mock data step function."""
        return mocker.MagicMock()

    @pytest.fixture
    def mock_forward_step(self, mocker):
        """Fixture to create a mock forward step function."""
        return mocker.MagicMock()

    @pytest.fixture
    def mock_loss_reduction(self, mocker):
        """Fixture to create a mock loss reduction function."""
        return mocker.MagicMock()

    def test_init_with_defaults(self, mocker, mock_pipeline):
        """Test __init__ with default parameters."""
        self._patch_parallel_state(mocker)

        megatron_parallel = mp.MegatronParallel(pipeline=mock_pipeline, cpu=True)

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin is None
        assert isinstance(megatron_parallel.callbacks, mp.CallbackConnector)
        assert megatron_parallel.data_step == mp.default_data_step
        assert megatron_parallel.forward_step == mp.default_forward_step
        assert megatron_parallel.loss_reduction is None

    def test_init_with_custom_parameters(
        self,
        mocker,
        mock_pipeline,
        mock_precision_plugin,
        mock_callbacks,
        mock_data_step,
        mock_forward_step,
        mock_loss_reduction,
    ):
        """Test __init__ with custom parameters."""
        self._patch_parallel_state(mocker)

        megatron_parallel = mp.MegatronParallel(
            pipeline=mock_pipeline,
            precision_plugin=mock_precision_plugin,
            callbacks=mock_callbacks,
            data_step=mock_data_step,
            forward_step=mock_forward_step,
            loss_reduction=mock_loss_reduction,
            cpu=True,
        )

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin == mock_precision_plugin
        assert megatron_parallel.callbacks == mock_callbacks
        assert megatron_parallel.data_step == mock_data_step
        assert megatron_parallel.forward_step == mock_forward_step
        assert megatron_parallel.loss_reduction == mock_loss_reduction
class TestCallbackConnector:
    """Unit tests for CallbackConnector registration, event dispatch and merging."""

    def test_add_callbacks(self) -> None:
        """Adding a callback registers it under the step-start and microbatch-start hooks."""
        connector = mp.CallbackConnector()
        cb = TestCallback()
        connector.add(cb)

        for hook in ("on_megatron_step_start", "on_megatron_microbatch_start"):
            assert cb in connector.callbacks[hook]

    def test_event(self) -> None:
        """Firing an event calls the matching hook on the registered callback exactly once."""
        connector = mp.CallbackConnector()
        cb = TestCallback()
        connector.add(cb)

        hooks = ("on_megatron_step_start", "on_megatron_microbatch_start")
        # The hooks are swapped for MagicMocks only after registration, so this
        # test assumes the connector resolves the method when the event fires.
        for hook in hooks:
            setattr(cb, hook, MagicMock())
        for hook in hooks:
            connector.event(hook)
        for hook in hooks:
            assert getattr(cb, hook).call_count == 1

    def test_add_connector(self) -> None:
        """In-place addition pulls another connector's callbacks into this one."""
        left, right = mp.CallbackConnector(), mp.CallbackConnector()
        first, second = TestCallback(), TestCallback()
        left.add(first)
        right.add(second)

        left += right

        registered = left.callbacks["on_megatron_step_start"]
        assert first in registered
        assert second in registered

    def test_contains(self):
        """Membership (``in``) reports a callback that was added."""
        connector = mp.CallbackConnector()
        cb = TestCallback()
        connector.add(cb)
        assert cb in connector

    def test_add_count_callback(self):
        """A CountCallback can be registered on a connector."""
        connector = mp.CallbackConnector()
        counter = CountCallback()
        connector.add(counter)
        assert counter in connector, "CountCallback should be in the CallbackConnector"

    def test_event_trigger_with_count_callback(self):
        """Firing a hook the CountCallback implements bumps its counter once."""
        connector = mp.CallbackConnector()
        counter = CountCallback()
        connector.add(counter)

        hook = "on_megatron_step_start"
        connector.event(hook)

        assert counter.counts[hook] == 1, "CountCallback's method should have been triggered once"
class TestCallback:
    """No-op callback implementing the step-start and microbatch-start hooks.

    The class name matches pytest's default ``Test*`` collection pattern, so
    ``__test__ = False`` marks it as a helper rather than a test class.
    """

    __test__ = False

    def on_megatron_step_start(self):
        """Intentionally a no-op; its presence is what the connector tests check."""

    def on_megatron_microbatch_start(self):
        """Intentionally a no-op; its presence is what the connector tests check."""
class CountCallback:
    """Callback that counts how many times each Megatron hook is invoked.

    Every hook accepts and ignores arbitrary positional/keyword arguments and
    increments ``counts[<hook method name>]``. ``counts`` is a
    ``defaultdict(int)``, so hooks that were never called read as 0.
    """

    def __init__(self) -> None:
        self.counts = defaultdict(int)

    def _bump(self, hook_name: str) -> None:
        """Increment the counter for ``hook_name`` (always the exact method name)."""
        self.counts[hook_name] += 1

    def on_megatron_step_start(self, *args, **kwargs) -> None:
        self._bump("on_megatron_step_start")

    def on_megatron_microbatch_start(self, *args, **kwargs) -> None:
        self._bump("on_megatron_microbatch_start")

    def on_megatron_microbatch_callback(self, *args, **kwargs) -> None:
        self._bump("on_megatron_microbatch_callback")

    def on_megatron_microbatch_end(self, *args, **kwargs) -> None:
        self._bump("on_megatron_microbatch_end")

    def on_megatron_reduce_microbatches_start(self, *args, **kwargs) -> None:
        self._bump("on_megatron_reduce_microbatches_start")

    def on_megatron_reduce_microbatches_end(self, *args, **kwargs) -> None:
        self._bump("on_megatron_reduce_microbatches_end")

    def on_megatron_log_step_end(self, *args, **kwargs) -> None:
        self._bump("on_megatron_log_step_end")

    def on_megatron_step_end(self, *args, **kwargs) -> None:
        self._bump("on_megatron_step_end")