# MagpieTTS_Internal_Demo / tests / lightning / test_strategy_lib.py
# Uploaded by subhankarg via huggingface_hub (revision 0558aa4, verified).
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import ANY, MagicMock, patch
import pytest
import torch
from torch import nn
from nemo.core.optim import MainParamsOptimizerWrapper
from nemo.lightning import MegatronStrategy, _strategy_lib # , DataConfig
class Identity(nn.Identity):
    """Trivial pass-through module used as a stand-in model in these tests.

    The original defined a no-op ``__init__`` that only called
    ``super().__init__()``; that override added nothing over the inherited
    constructor, so it has been removed. Behavior is unchanged.
    """
class WithCopy(nn.Identity):
    """Identity module that also exposes a ``copy`` factory method.

    Used by tests that need to duplicate a model instance.
    """

    def copy(self):
        """Return a fresh, independent ``WithCopy`` instance."""
        return WithCopy()
class Optimizer:
    """Minimal stand-in for a torch optimizer backed by synthetic CUDA tensors.

    Every call fabricates fresh parameters; there is no real state to persist.
    """

    def state_dict(self):
        """Return a synthetic state with one param group and two empty state slots."""
        weight = torch.nn.Parameter(torch.randn(3, 3, device='cuda', dtype=torch.float32))
        return {"param_groups": [{"params": weight}], "state": {0: {}, 1: {}}}

    def load_state_dict(self, state_dict):
        # Deliberately ignores the incoming dict and hands back a fresh synthetic state.
        return self.state_dict()

    @property
    def param_groups(self):
        """A single expert param group containing one trainable parameter."""
        weight = torch.nn.Parameter(torch.randn(3, 3, device='cuda', dtype=torch.float32))
        weight.requires_grad = True
        return [{'params': [weight], 'is_expert': True}]
class OptimizerWrapper(MainParamsOptimizerWrapper):
    """Thin test wrapper around ``MainParamsOptimizerWrapper``.

    Note(review): the original spelled the method ``__init_`` (missing the
    trailing underscore), so the intended override was dead code and
    construction fell through to the inherited ``__init__`` anyway. The name
    is fixed here; observable behavior is unchanged.
    """

    def __init__(self, optimizer):
        super().__init__(optimizer)
class DummyOptimizer:
    """Fake optimizer that records whether ``step`` ran.

    ``_custom_amp_unscale_grads`` signals to ``GradScaler`` that this
    optimizer handles gradient unscaling itself.
    """

    def __init__(self):
        self._custom_amp_unscale_grads = True
        self.step_called = False

    def unscale_grads(self, *args):
        """Log the call; no actual unscaling is performed."""
        print("Dummy unscale_grads called with:", args)

    def step(self, *args, **kwargs):
        """Mark the step as taken and return a sentinel value."""
        print("Dummy optimizer step called.")
        self.step_called = True
        return "step_result"
class Model:
    """Minimal model stub exposing ``sharded_state_dict`` for strategy-lib tests.

    Note(review): the original stored the second argument as ``self.metadta``
    (typo). Nothing in this file read that attribute, so renaming it to
    ``metadata`` is safe.

    Args:
        prefix: stored verbatim, mirroring the real model API.
        metadata: stored verbatim.
    """

    def __init__(self, prefix="", metadata=None):
        self.prefix = prefix
        self.metadata = metadata

    def sharded_state_dict(self, prefix="", metadata=None):
        """Return a fixed placeholder sharded state dict."""
        return dict(test="test")
def make_optimizer_state():
    """Build a minimal GradScaler-style optimizer state (no infs found on cuda:0).

    Returns:
        A dict mapping ``found_inf_per_device`` to per-device float32 CUDA
        tensors; 0.0 means no inf/NaN was detected for that device.
    """
    per_device_flags = {"cuda:0": 0.0}
    found = {
        dev: torch.tensor(flag, dtype=torch.float32, device="cuda") for dev, flag in per_device_flags.items()
    }
    return {"found_inf_per_device": found}
def test_set_model_parallel_attributes() -> None:
    """set_model_parallel_attributes should copy strategy parallelism onto the model config."""
    strategy = MegatronStrategy(
        pipeline_model_parallel_size=2,
        expert_model_parallel_size=2,
        sequence_parallel=False,
        pipeline_dtype=torch.float32,
    )
    from megatron.core.transformer.transformer_config import TransformerConfig

    class DummyModel:
        def __init__(self):
            self.config = TransformerConfig(
                hidden_size=128, num_attention_heads=2, num_layers=2, num_moe_experts=2, add_bias_linear=False
            )

        def configure_model(self):
            pass

    model = DummyModel()
    # The defaults must differ from the strategy's values so the update is observable.
    assert model.config.pipeline_model_parallel_size != 2
    assert model.config.expert_model_parallel_size != 2
    assert model.config.pipeline_dtype != torch.float32

    _strategy_lib.set_model_parallel_attributes(model, strategy.parallelism)

    expected = {
        "pipeline_model_parallel_size": 2,
        "expert_model_parallel_size": 2,
        "sequence_parallel": False,
        "pipeline_dtype": torch.float32,
    }
    for attr, value in expected.items():
        assert getattr(model.config, attr) == value
def test_init_parallel_ranks() -> None:
    """init_parallel_ranks should populate AppState from the given parallel config."""
    from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
    from megatron.core.parallel_state import destroy_model_parallel
    from nemo.utils import AppState

    app_state = AppState()
    for name, value in {
        "tensor_model_parallel_size": 2,
        "pipeline_model_parallel_size": 3,
        "context_parallel_size": 2,
        "expert_model_parallel_size": 2,
        "global_rank": 1,
        "local_rank": 0,
    }.items():
        setattr(app_state, name, value)

    mock_parallel_config = MagicMock()
    mock_parallel_config.configure_mock(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=3,
        virtual_pipeline_model_parallel_size=4,
        context_parallel_size=2,
        expert_model_parallel_size=2,
        expert_tensor_parallel_size=None,
        tp_comm_overlap=False,
        use_te_rng_tracker=False,
    )

    _strategy_lib.init_parallel_ranks(
        world_size=24,
        global_rank=1,
        local_rank=0,
        parallel_config=mock_parallel_config,
        seed=1234,
        fp8=False,
    )

    expected_app_state = {
        "world_size": 24,
        "global_rank": 1,
        "local_rank": 0,
        "tensor_model_parallel_size": 2,
        "pipeline_model_parallel_size": 3,
        "virtual_pipeline_model_parallel_size": 4,
        "context_parallel_size": 2,
        "expert_model_parallel_size": 2,
        "use_fp8": False,
        "init_mpi_proc_group": False,
    }
    for k, v in expected_app_state.items():
        assert hasattr(app_state, k), f"Expected to find {k} in AppState"
        app_attr = getattr(app_state, k)
        assert app_attr == v, f"{k} in AppState is incorrect, Expected: {v} Actual: {app_attr}"

    # Tear down global state so later tests start clean.
    destroy_model_parallel()
    destroy_num_microbatches_calculator()
@patch('torch.distributed.is_initialized', return_value=True)
@patch('megatron.core.parallel_state')
def test_init_model_parallel(mock_mpu, *args):
    """init_model_parallel should forward AppState settings to initialize_model_parallel."""
    from nemo.utils import AppState

    app_state = AppState()
    for name, value in {
        "model_parallel_size": 1,
        "tensor_model_parallel_size": 2,
        "pipeline_model_parallel_size": 1,
        "pipeline_model_parallel_comm_backend": None,
        "context_parallel_size": 2,
        "expert_model_parallel_size": 2,
        "expert_tensor_parallel_size": 1,
        "expert_tensor_parallel_rank": 0,
        "init_mpi_proc_group": False,
        "tensor_model_parallel_rank": 2,
        "pipeline_model_parallel_rank": 0,
    }.items():
        setattr(app_state, name, value)

    _mpu_tp_2(mock_mpu)
    _strategy_lib.init_model_parallel(nn.Identity())

    # Default rank mapping => "tp-cp-ep-dp-pp" ordering.
    mock_mpu.initialize_model_parallel.assert_called_once_with(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=1,
        virtual_pipeline_model_parallel_size=None,
        pipeline_model_parallel_comm_backend=None,
        context_parallel_size=2,
        expert_model_parallel_size=2,
        expert_tensor_parallel_size=1,
        use_sharp=False,
        order="tp-cp-ep-dp-pp",
        num_distributed_optimizer_instances=1,
        nccl_communicator_config_path=None,
        create_gloo_process_groups=True,
    )
@patch('torch.distributed.is_initialized', return_value=True)
@patch('megatron.core.parallel_state')
def test_init_model_parallel_with_tp_pp_dp(mock_mpu, *args):
    """With use_tp_pp_dp_mapping set, the rank ordering should switch to tp-cp-ep-pp-dp."""
    from nemo.utils import AppState

    app_state = AppState()
    for name, value in {
        "model_parallel_size": 1,
        "tensor_model_parallel_size": 2,
        "pipeline_model_parallel_size": 1,
        "pipeline_model_parallel_comm_backend": None,
        "context_parallel_size": 2,
        "expert_model_parallel_size": 2,
        "expert_tensor_parallel_size": 1,
        "expert_tensor_parallel_rank": 0,
        "init_mpi_proc_group": False,
        "tensor_model_parallel_rank": 2,
        "pipeline_model_parallel_rank": 0,
        "use_tp_pp_dp_mapping": True,  # the one difference from test_init_model_parallel
    }.items():
        setattr(app_state, name, value)

    _mpu_tp_2(mock_mpu)
    _strategy_lib.init_model_parallel(nn.Identity())

    mock_mpu.initialize_model_parallel.assert_called_once_with(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=1,
        virtual_pipeline_model_parallel_size=None,
        pipeline_model_parallel_comm_backend=None,
        context_parallel_size=2,
        expert_model_parallel_size=2,
        expert_tensor_parallel_size=1,
        use_sharp=False,
        order="tp-cp-ep-pp-dp",
        num_distributed_optimizer_instances=1,
        nccl_communicator_config_path=None,
        create_gloo_process_groups=True,
    )
@pytest.mark.run_only_on('GPU')
def test_optimizer_sharded_state_dict():
    """The wrapped optimizer's sharded state dict should report empty fp32 master params."""
    wrapped = OptimizerWrapper(Optimizer())
    sharded = _strategy_lib.optimizer_sharded_state_dict(Model(), wrapped, sharding_type="test")
    assert sharded['fp32_from_fp16_params'] == [[]]
@pytest.mark.run_only_on('GPU')
@patch('torch.distributed.is_initialized', return_value=True)
@patch('megatron.core.parallel_state')
def test_grad_scaler(mock_mpu, *args):
    """Exercise GradScaler's unscale/step/state-dict round trip with a dummy optimizer."""
    scaler = _strategy_lib.GradScaler()
    optimizer = DummyOptimizer()

    scaler._unscale_grads_(optimizer)
    scaler._maybe_opt_step(optimizer, make_optimizer_state())

    snapshot = scaler.state_dict()
    assert type(snapshot) is dict
    scaler.load_state_dict(snapshot)

    # update() may assert outside a fully initialized distributed setup; tolerate that.
    try:
        scaler.update()
    except AssertionError:
        pass
# TODO @chcui uncomment after fabric API is merged
# @patch('nemo.lightning._strategy_lib.DataLoader', return_value=MagicMock())
# @patch('megatron.core.parallel_state')
# def test_process_dataloader(mock_mpu, mock_dataloader) -> None:
# mock_dataloader_instance = MagicMock()
# mock_dataloader_instance.dataset = [1, 2, 3]
# mock_dataloader_instance.num_workers = 4
# mock_dataloader_instance.pin_memory = True
# mock_dataloader_instance.persistent_workers = False
#
# data_config = DataConfig(256)
# data_config.micro_batch_size = 2
# data_config.global_batch_size = 6
# data_config.rampup_batch_size = 3
#
# mock_mpu.get_data_parallel_rank.return_value = 0
# mock_mpu.get_data_parallel_world_size.return_value = 1
#
# out = _strategy_lib.process_dataloader(mock_dataloader_instance, data_config)
# assert isinstance(out.batch_sampler, MagicMock)
# mock_dataloader.assert_called_once_with(
# mock_dataloader_instance.dataset,
# batch_sampler=ANY,
# num_workers=4,
# pin_memory=True,
# persistent_workers=False,
# collate_fn=ANY
# )
# @patch('nemo.lightning._strategy_lib.init_parallel_ranks')
# @patch('megatron.core.parallel_state')
# def test_setup_megatron_parallel_with_trainer(mock_mpu, mock_init_parallel_ranks) -> None:
# _mpu_tp_2(mock_mpu)
# mock_trainer = MagicMock(spec=pl.Trainer)
# mock_trainer.strategy = MegatronStrategy(
# ModelParallelConfig(tensor_model_parallel_size=2),
# DataConfig(256),
# )
# mock_trainer.world_size = 2
# mock_trainer.local_rank = 0
# mock_trainer.global_rank = 1
# result = _strategy_lib.setup_megatron_parallel(mock_trainer, nn.Identity())
# mock_init_parallel_ranks.assert_called_once()
# assert isinstance(result, LightningMegatronParallel)
# assert len(result) == 1
# # Test with function
# assert len(_strategy_lib.setup_megatron_parallel(mock_trainer, lambda: nn.Identity())) == 1
# @patch('nemo.lightning._strategy_lib.init_parallel_ranks')
# @patch('megatron.core.parallel_state')
# def test_setup_megatron_parallel_virtual_pipelining(mock_mpu, mock_init_parallel_ranks) -> None:
# vp_size = 4
# _mpu_tp_2(mock_mpu)
# mock_mpu.get_pipeline_model_parallel_world_size.return_value = 4
# mock_trainer = MagicMock(spec=pl.Trainer)
# mock_trainer.strategy = MegatronStrategy(
# ModelParallelConfig(
# virtual_pipeline_model_parallel_size=vp_size,
# tensor_model_parallel_size=2,
# ),
# DataConfig(256),
# )
# mock_trainer.world_size = 8
# mock_trainer.local_rank = 0
# mock_trainer.global_rank = 1
# result = _strategy_lib.setup_megatron_parallel(mock_trainer, Identity())
# mock_init_parallel_ranks.assert_called_once()
# assert len(result) == vp_size
# # Test with function
# assert len(_strategy_lib.setup_megatron_parallel(mock_trainer, lambda: nn.Identity())) == vp_size
# # Test with a module with a copy method
# assert len(_strategy_lib.setup_megatron_parallel(mock_trainer, WithCopy())) == vp_size
# with pytest.raises(
# ValueError,
# match="Model does not have a copy method. Please implement this or " +
# "pass in a function that returns the model"
# ):
# _strategy_lib.setup_megatron_parallel(mock_trainer, nn.Identity())
# @patch('nemo.lightning._strategy_lib.init_parallel_ranks')
# @patch('megatron.core.parallel_state')
# def test_setup_megatron_parallel_with_fabric(mock_mpu, mock_init_parallel_ranks) -> None:
# _mpu_tp_2(mock_mpu)
# mock_trainer = MagicMock(spec=fl.Fabric)
# mock_trainer.strategy = FabricMegatronStrategy(
# ModelParallelConfig(tensor_model_parallel_size=2),
# DataConfig(256),
# )
# mock_trainer.world_size = 2
# mock_trainer.local_rank = 0
# mock_trainer.global_rank = 1
# result = _strategy_lib.setup_megatron_parallel(mock_trainer, nn.Identity())
# mock_init_parallel_ranks.assert_called_once()
# assert isinstance(result, MegatronParallel)
# assert len(result) == 1
# @patch('nemo.lightning._strategy_lib.init_parallel_ranks')
# @patch('megatron.core.parallel_state')
# def test_setup_megatron_parallel_with_strategy(mock_mpu, mock_init_parallel_ranks) -> None:
# _mpu_tp_2(mock_mpu)
# mock_trainer = MagicMock(spec=FabricMegatronStrategy)
# mock_trainer.configure_mock(
# parallelism=ModelParallelConfig(tensor_model_parallel_size=2),
# data_config=DataConfig(256),
# world_size=2,
# local_rank=0,
# global_rank=1
# )
# result = _strategy_lib.setup_megatron_parallel(mock_trainer, nn.Identity())
# mock_init_parallel_ranks.assert_called_once()
# assert isinstance(result, MegatronParallel)
# assert len(result) == 1
def _mpu_tp_2(mock_mpu) -> None:
mock_mpu.get_tensor_model_parallel_rank.return_value = 2
mock_mpu.get_pipeline_model_parallel_rank.return_value = 0
mock_mpu.get_pipeline_model_parallel_world_size.return_value = 1
mock_mpu.get_pipeline_model_parallel_group.return_value = 0
mock_mpu.get_tensor_model_parallel_group.return_value = 1
mock_mpu.get_expert_tensor_parallel_rank.return_value = 0