# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
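

# Each test calls set_env() before building a model so that Transformer Engine
# applies query-key layer scaling (NVTE_APPLY_QK_LAYER_SCALING=1).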
def set_env():
    os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'


from pathlib import Path

import lightning.pytorch as pl
import pytest
import torch

import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning.io.pl import MegatronCheckpointIO
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO, AsyncFinalizerCallback


def _get_strategy():
    strategy = nl.MegatronStrategy(
        enable_nemo_ckpt_io=False,
        ckpt_async_save=False,
    )
    return strategy
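

# NeMo writes checkpoint directories named 'epoch=<E>-step=<S>' (no '.ckpt'
# suffix), so the last checkpoint of a run ends in step = max_steps - 1.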
def _get_last_checkpoint_dir(model: pl.LightningModule, suffix: str = '') -> str:
    return f'epoch={model.trainer.current_epoch - 1}-step={model.trainer.max_steps - 1}{suffix}'
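

# A deliberately tiny GPT config (2 layers, hidden size 64, short sequences)
# keeps these tests fast while still exercising the full save/load path.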
def get_model_and_data(mbs=2, gbs=2):
    seq_length = 128
    data = llm.MockDataModule(seq_length=seq_length, micro_batch_size=mbs, global_batch_size=gbs)
    config = llm.GPTConfig(
        num_layers=2,
        hidden_size=64,
        ffn_hidden_size=256,
        num_attention_heads=4,
        seq_length=seq_length,
        apply_query_key_layer_scaling=True,
    )
    return llm.GPTModel(config, tokenizer=data.tokenizer), data
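

# Tests covering NeMo's distributed checkpoint IO (MegatronCheckpointIO) and the
# asynchronous save path (AsyncFinalizableCheckpointIO + AsyncFinalizerCallback).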
class TestDistCkptIO:
    @pytest.mark.run_only_on('GPU')
    def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path):
        set_env()
        assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '1'

        gbs, mbs = 2, 2
        model, data = get_model_and_data(mbs, gbs)

        from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager

        with reconfigure_num_microbatches_calculator_manager(0, None, gbs, mbs, data_parallel_size=1):
            strategy = _get_strategy()

            trainer = nl.Trainer(
                devices=1,
                accelerator="gpu",
                strategy=strategy,
                enable_checkpointing=True,
                max_steps=2,
                default_root_dir=str(tmp_path),
                logger=False,
            )

            trainer.fit(model, data)

        assert isinstance(trainer.strategy.checkpoint_io, MegatronCheckpointIO)
        # Ckpt path doesn't contain the .ckpt suffix
        ckpts = os.listdir(tmp_path / "checkpoints")
        assert len(ckpts) == 1
        ckpt = ckpts[0]
        assert str(ckpt) == _get_last_checkpoint_dir(model)

        trainer._teardown()
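
    # Train the same tiny model twice with identical configs, once saving
    # checkpoints synchronously and once asynchronously, then verify that the
    # two runs produce identical checkpoints.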
    @pytest.mark.run_only_on('GPU')
    def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path):
        set_env()
        assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '1'

        gbs, mbs = 2, 2
        model, data = get_model_and_data(mbs, gbs)

        from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager

        with reconfigure_num_microbatches_calculator_manager(0, None, gbs, mbs, data_parallel_size=1):
            sync_ckpt_dir = tmp_path / 'sync_checkpoints'
            async_ckpt_dir = tmp_path / 'async_checkpoints'

            sync_checkpoint_io = MegatronCheckpointIO('torch_dist')
            async_checkpoint_io = AsyncFinalizableCheckpointIO(MegatronCheckpointIO('torch_dist', async_save=True))

            # dummy_trainer just to initialize NCCL
            dummy_trainer = pl.Trainer(
                devices=1,
                logger=False,
                max_steps=2,
                strategy=_get_strategy(),
            )
            dummy_trainer.fit(model, data)
            strategy = _get_strategy()

            ## reset the model and data and train with sync checkpointing
            model, data = get_model_and_data(mbs, gbs)
            sync_test_trainer = pl.Trainer(
                devices=1,
                enable_checkpointing=True,
                logger=False,
                max_steps=2,
                strategy=_get_strategy(),
                plugins=[sync_checkpoint_io],
                default_root_dir=str(sync_ckpt_dir),
            )
            sync_test_trainer.fit(model, data)

            ## reset the model and data and train with async checkpointing
            model, data = get_model_and_data(mbs, gbs)
            async_test_trainer = pl.Trainer(
                devices=1,
                enable_checkpointing=True,
                logger=False,
                max_steps=2,
                strategy=_get_strategy(),
                plugins=[async_checkpoint_io],
                callbacks=AsyncFinalizerCallback(),
                default_root_dir=str(async_ckpt_dir),
            )
            async_test_trainer.fit(model, data)

        sync_last_ckpt = f"{sync_ckpt_dir}/checkpoints/{_get_last_checkpoint_dir(model)}"
        async_last_ckpt = f"{async_ckpt_dir}/checkpoints/{_get_last_checkpoint_dir(model)}"

        sharded_state_dict_metadata = sync_checkpoint_io.load_content_metadata(sync_last_ckpt)
        assert sharded_state_dict_metadata == async_checkpoint_io.checkpoint_io.load_content_metadata(async_last_ckpt)

        ## NOTE: the model no longer has a `sharded_state_dict` attribute at this point:
        ## after MegatronStrategy teardown, the model class's __getattr__ has been
        ## restored to the original __getattr__, so we go through `model.module` instead.
        checkpoint = {'sharded_state_dict': model.module.sharded_state_dict(metadata=sharded_state_dict_metadata)}
        sync_state_dict = sync_checkpoint_io.load_checkpoint(Path(sync_last_ckpt), sharded_state_dict=checkpoint)
        async_state_dict = async_checkpoint_io.load_checkpoint(Path(async_last_ckpt), sharded_state_dict=checkpoint)

        ## one of the values is an _io.BytesIO object, so only compare tensor entries
        for k in sync_state_dict['sharded_state_dict'].keys():
            if isinstance(sync_state_dict['sharded_state_dict'][k], torch.Tensor):
                assert torch.all(sync_state_dict['sharded_state_dict'][k] == async_state_dict['sharded_state_dict'][k])

        dummy_trainer._teardown()
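
    # Verify that MegatronStrategy forwards its ckpt_* options (save format,
    # parallel save, async save, load placement) to the underlying checkpoint IO.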
    def test_sharded_strategies(self):
        set_env()
        assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '1'

        model_checkpoint = nl.ModelCheckpoint()
        strategy = nl.MegatronStrategy(
            enable_nemo_ckpt_io=False,
            save_ckpt_format='torch_dist',
            ckpt_parallel_save=True,
            ckpt_load_directly_on_device=False,
            ckpt_async_save=True,
        )
        trainer = nl.Trainer(
            callbacks=[model_checkpoint],
            strategy=strategy,
        )

        assert isinstance(strategy.checkpoint_io, AsyncFinalizableCheckpointIO)
        assert isinstance(strategy.checkpoint_io._checkpoint_io, MegatronCheckpointIO)

        base_checkpoint_io = strategy.checkpoint_io._checkpoint_io
        assert base_checkpoint_io.save_ckpt_format == 'torch_dist'
        assert base_checkpoint_io.parallel_save
        assert not base_checkpoint_io.load_directly_on_device

        trainer._teardown()