# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
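"""Tests for distributed checkpoint I/O with NeMo's MegatronStrategy.

Covers three scenarios: MegatronCheckpointIO being used to checkpoint Megatron-core
models, asynchronous checkpoint saving producing the same checkpoint contents as
synchronous saving, and the strategy's ckpt_* options being propagated to the
checkpoint IO plugins it builds.
"""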
import os


def set_env():
    # Force Transformer Engine to apply query/key layer scaling
    # (read from the NVTE_APPLY_QK_LAYER_SCALING environment variable).
    os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'


from pathlib import Path

import lightning.pytorch as pl
import pytest
import torch

import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning.io.pl import MegatronCheckpointIO
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO, AsyncFinalizerCallback


def _get_strategy():
    strategy = nl.MegatronStrategy(
        enable_nemo_ckpt_io=False,
        ckpt_async_save=False,
    )
    return strategy


def _get_last_checkpoint_dir(model: pl.LightningModule, suffix: str = '') -> str:
    return f'epoch={model.trainer.current_epoch - 1}-step={model.trainer.max_steps - 1}{suffix}'


def get_model_and_data(mbs=2, gbs=2):
    seq_length = 128
    data = llm.MockDataModule(seq_length=seq_length, micro_batch_size=mbs, global_batch_size=gbs)

    # Tiny GPT config so the tests only need a couple of training steps;
    # apply_query_key_layer_scaling mirrors the NVTE_APPLY_QK_LAYER_SCALING env var set in set_env().
    config = llm.GPTConfig(
        num_layers=2,
        hidden_size=64,
        ffn_hidden_size=256,
        num_attention_heads=4,
        seq_length=seq_length,
        apply_query_key_layer_scaling=1,
    )
    return llm.GPTModel(config, tokenizer=data.tokenizer), data


class TestDistCkptIO:
    def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path):
        """Training with MegatronStrategy should route saving through MegatronCheckpointIO
        and write exactly one distributed checkpoint directory (no .ckpt suffix)."""
        set_env()
        assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '1'
        gbs, mbs = 2, 2
        model, data = get_model_and_data(mbs, gbs)

        from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager

        with reconfigure_num_microbatches_calculator_manager(0, None, gbs, mbs, data_parallel_size=1):
            strategy = _get_strategy()

            trainer = nl.Trainer(
                devices=1,
                accelerator="gpu",
                strategy=strategy,
                enable_checkpointing=True,
                max_steps=2,
                default_root_dir=str(tmp_path),
                logger=False,
            )

            trainer.fit(model, data)

            assert isinstance(trainer.strategy.checkpoint_io, MegatronCheckpointIO)
            # Ckpt path doesn't contain the .ckpt suffix
            ckpts = os.listdir(Path(tmp_path / "checkpoints"))
            assert len(ckpts) == 1
            ckpt = ckpts[0]
            assert str(ckpt) == _get_last_checkpoint_dir(model)

        trainer._teardown()

    def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path):
        """Async checkpoint saving should produce the same checkpoint contents as sync saving."""
        set_env()
        assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '1'
        gbs, mbs = 2, 2
        model, data = get_model_and_data(mbs, gbs)

        from tests.lightning.mcore_microbatch_utils import reconfigure_num_microbatches_calculator_manager

        with reconfigure_num_microbatches_calculator_manager(0, None, gbs, mbs, data_parallel_size=1):
            sync_ckpt_dir = tmp_path / 'sync_checkpoints'
            async_ckpt_dir = tmp_path / 'async_checkpoints'

            sync_checkpoint_io = MegatronCheckpointIO('torch_dist')
            async_checkpoint_io = AsyncFinalizableCheckpointIO(MegatronCheckpointIO('torch_dist', async_save=True))

            # dummy_trainer just to initialize NCCL
            dummy_trainer = pl.Trainer(
                devices=1,
                logger=False,
                max_steps=2,
                strategy=_get_strategy(),
            )
            dummy_trainer.fit(model, data)
            strategy = _get_strategy()

            ## reset the model and data and train with sync checkpointing
            model, data = get_model_and_data(mbs, gbs)
            sync_test_trainer = pl.Trainer(
                devices=1,
                enable_checkpointing=True,
                logger=False,
                max_steps=2,
                strategy=_get_strategy(),
                plugins=[sync_checkpoint_io],
                default_root_dir=str(sync_ckpt_dir),
            )
            sync_test_trainer.fit(model, data)

            ## reset the model and data and train with async checkpointing
            model, data = get_model_and_data(mbs, gbs)
            async_test_trainer = pl.Trainer(
                devices=1,
                enable_checkpointing=True,
                logger=False,
                max_steps=2,
                strategy=_get_strategy(),
                plugins=[async_checkpoint_io],
                callbacks=AsyncFinalizerCallback(),
                default_root_dir=str(async_ckpt_dir),
            )
            async_test_trainer.fit(model, data)

        sync_last_ckpt = f"{sync_ckpt_dir}/checkpoints/{_get_last_checkpoint_dir(model)}"
        async_last_ckpt = f"{async_ckpt_dir}/checkpoints/{_get_last_checkpoint_dir(model)}"

        sharded_state_dict_metadata = sync_checkpoint_io.load_content_metadata(sync_last_ckpt)
        assert sharded_state_dict_metadata == async_checkpoint_io.checkpoint_io.load_content_metadata(async_last_ckpt)

        ## NOTE: model does not have a `sharded_state_dict` attribute here because this
        ## runs after MegatronStrategy teardown, so the model class' __getattr__ has been
        ## restored to the original __getattr__
        checkpoint = {'sharded_state_dict': model.module.sharded_state_dict(metadata=sharded_state_dict_metadata)}
        sync_state_dict = sync_checkpoint_io.load_checkpoint(Path(sync_last_ckpt), sharded_state_dict=checkpoint)
        async_state_dict = async_checkpoint_io.load_checkpoint(Path(async_last_ckpt), sharded_state_dict=checkpoint)

        ## one of the keys is an _io.BytesIO object, so only compare tensor entries
        for k in sync_state_dict['sharded_state_dict'].keys():
            if isinstance(sync_state_dict['sharded_state_dict'][k], torch.Tensor):
                assert torch.all(
                    sync_state_dict['sharded_state_dict'][k] == async_state_dict['sharded_state_dict'][k]
                )

        dummy_trainer._teardown()

    def test_sharded_strategies(self):
        """MegatronStrategy's ckpt_* options should be reflected in the checkpoint IO plugins it builds."""
        set_env()
        assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '1'
        model_checkpoint = nl.ModelCheckpoint()

        strategy = nl.MegatronStrategy(
            enable_nemo_ckpt_io=False,
            save_ckpt_format='torch_dist',
            ckpt_parallel_save=True,
            ckpt_load_directly_on_device=False,
            ckpt_async_save=True,
        )

        trainer = nl.Trainer(
            callbacks=[model_checkpoint],
            strategy=strategy,
        )

        # With ckpt_async_save=True, the MegatronCheckpointIO is wrapped in an AsyncFinalizableCheckpointIO.
        assert isinstance(strategy.checkpoint_io, AsyncFinalizableCheckpointIO)
        assert isinstance(strategy.checkpoint_io._checkpoint_io, MegatronCheckpointIO)

        base_checkpoint_io = strategy.checkpoint_io._checkpoint_io
        assert base_checkpoint_io.save_ckpt_format == 'torch_dist'
        assert base_checkpoint_io.parallel_save
        assert base_checkpoint_io.load_directly_on_device is False

        trainer._teardown()
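

# NOTE: these tests assume a CUDA-capable GPU (devices=1, accelerator="gpu") and are
# meant to be run with pytest from the NeMo repository; the relative import of
# tests.lightning.mcore_microbatch_utils above requires the repo's tests package.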