# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import lightning.pytorch as pl
import nemo_run as run
import torch
from lightning.pytorch.loggers import WandbLogger
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule
from nemo.collections.diffusion.data.diffusion_mock_datamodule import MockDataModule
from nemo.collections.diffusion.data.diffusion_taskencoder import RawImageDiffusionTaskEncoder
from nemo.collections.diffusion.models.flux.model import ClipConfig, FluxConfig, FluxModelParams, T5Config
from nemo.collections.diffusion.models.flux_controlnet.model import FluxControlNetConfig, MegatronFluxControlNetModel
from nemo.collections.diffusion.vae.autoencoder import AutoEncoderConfig
from nemo.lightning.pytorch.callbacks.nsys import NsysCallback
from nemo.lightning.pytorch.optim import WarmupHoldPolicyScheduler
from nemo.utils.exp_manager import TimingCallback


@run.cli.factory
@run.autoconvert
def flux_datamodule(dataset_dir: str) -> pl.LightningDataModule:
    """Flux datamodule initialization."""
    data_module = DiffusionDataModule(
        dataset_dir,
        seq_length=4096,
        task_encoder=run.Config(
            RawImageDiffusionTaskEncoder,
        ),
        micro_batch_size=1,
        global_batch_size=8,
        num_workers=23,
        use_train_split_for_val=True,
    )
    return data_module


@run.cli.factory
@run.autoconvert
def flux_mock_datamodule() -> pl.LightningDataModule:
    """Mock datamodule initialization."""
    data_module = MockDataModule(
        image_h=1024,
        image_w=1024,
        micro_batch_size=1,
        global_batch_size=1,
        image_precached=True,
        text_precached=True,
    )
    return data_module


@run.cli.factory(target=llm.train)
def flux_controlnet_training() -> run.Partial:
    """Flux ControlNet training config."""
    return run.Partial(
        llm.train,
        model=run.Config(
            MegatronFluxControlNetModel,
            flux_params=run.Config(FluxModelParams),
            flux_controlnet_config=run.Config(FluxControlNetConfig),
            seed=42,
        ),
        data=flux_mock_datamodule(),
        trainer=run.Config(
            nl.Trainer,
            devices=1,
            num_nodes=int(os.environ.get('SLURM_NNODES', 1)),
            accelerator="gpu",
            strategy=run.Config(
                nl.MegatronStrategy,
                tensor_model_parallel_size=1,
                pipeline_model_parallel_size=1,
                context_parallel_size=1,
                sequence_parallel=False,
                pipeline_dtype=torch.bfloat16,
                ddp=run.Config(
                    DistributedDataParallelConfig,
                    data_parallel_sharding_strategy='optim_grads_params',
                    check_for_nan_in_grad=True,
                    grad_reduce_in_fp32=True,
                    overlap_param_gather=True,
                    overlap_grad_reduce=True,
                ),
                fsdp='megatron',
            ),
            plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
            num_sanity_val_steps=0,
            limit_val_batches=1,
            val_check_interval=1000,
            max_steps=50000,
            log_every_n_steps=1,
            callbacks=[
                run.Config(
                    nl.ModelCheckpoint,
                    monitor='global_step',
                    filename='{global_step}',
                    every_n_train_steps=1000,
                    save_last=False,
                    save_top_k=3,
                    mode='max',
                    save_on_train_epoch_end=True,
                ),
                run.Config(TimingCallback),
            ],
        ),
        log=nl.NeMoLogger(wandb=(WandbLogger() if "WANDB_API_KEY" in os.environ else None)),
        optim=run.Config(
            nl.MegatronOptimizerModule,
            config=run.Config(
                OptimizerConfig,
                lr=1e-4,
                adam_beta1=0.9,
                adam_beta2=0.999,
                use_distributed_optimizer=True,
                bf16=True,
            ),
            lr_scheduler=run.Config(
                WarmupHoldPolicyScheduler,
                warmup_steps=500,
                hold_steps=1_000_000_000_000,  # effectively hold the LR forever after warmup
            ),
        ),
        tokenizer=None,
        resume=run.Config(
            nl.AutoResume,
            resume_if_exists=True,
            resume_ignore_no_checkpoint=True,
            resume_past_end=True,
        ),
        model_transform=None,
    )
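

# Any field of the base recipe above can be overridden from the command line via
# nemo_run's dotted-path syntax. A minimal sketch (the flag spelling may vary
# across nemo_run versions, and the script name here is illustrative):
#
#   python flux_controlnet_training.py --factory flux_controlnet_training \
#       trainer.max_steps=100 trainer.devices=2 optim.config.lr=1e-5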


@run.cli.factory(target=llm.train)
def convergence_test(megatron_fsdp=True) -> run.Partial:
    '''
    A convergence recipe with a real data loader.
    Image and text embeddings are calculated on the fly.
    '''
    recipe = flux_controlnet_training()
    recipe.model.flux_params.t5_params = run.Config(T5Config, version='/ckpts/text_encoder_2')
    recipe.model.flux_params.clip_params = run.Config(ClipConfig, version='/ckpts/text_encoder')
    recipe.model.flux_params.vae_config = run.Config(
        AutoEncoderConfig, ckpt='/ckpts/ae.safetensors', ch_mult=[1, 2, 4, 4], attn_resolutions=[]
    )
    recipe.model.flux_params.device = 'cuda'
    recipe.model.flux_params.flux_config = run.Config(
        FluxConfig,
        ckpt_path='/ckpts/transformer',
        calculate_per_token_loss=False,
        gradient_accumulation_fusion=False,
    )
    recipe.model.flux_params.flux_config.do_convert_from_hf = True
    recipe.trainer.devices = 2
    recipe.data = flux_datamodule('/dataset/fill50k/fill50k_tarfiles/')
    recipe.model.flux_controlnet_config.num_single_layers = 0
    recipe.model.flux_controlnet_config.num_joint_layers = 4
    if megatron_fsdp:
        configure_megatron_fsdp(recipe)
    else:
        configure_ddp(recipe)
    recipe.optim.config.lr = 5e-5
    recipe.data.global_batch_size = 2
    return recipe


@run.cli.factory(target=llm.train)
def fp8_test(megatron_fsdp=True) -> run.Partial:
    '''
    An FP8 ('hybrid') mixed-precision variant of the convergence recipe, with a
    real data loader. Image and text embeddings are calculated on the fly.
    '''
    recipe = flux_controlnet_training()
    recipe.model.flux_params.t5_params = run.Config(T5Config, version='/ckpts/text_encoder_2')
    recipe.model.flux_params.clip_params = run.Config(ClipConfig, version='/ckpts/text_encoder')
    recipe.model.flux_params.vae_config = run.Config(
        AutoEncoderConfig, ckpt='/ckpts/ae.safetensors', ch_mult=[1, 2, 4, 4], attn_resolutions=[]
    )
    recipe.model.flux_params.device = 'cuda'
    recipe.model.flux_params.flux_config = run.Config(
        FluxConfig,
        ckpt_path='/ckpts/nemo_flux_transformer.safetensors',
        guidance_embed=False,
        calculate_per_token_loss=False,
        gradient_accumulation_fusion=False,
    )
    recipe.trainer.devices = 2
    recipe.data = flux_datamodule('/mingyuanm/dataset/fill50k/fill50k_tarfiles/')
    recipe.model.flux_controlnet_config.num_single_layers = 0
    recipe.model.flux_controlnet_config.num_joint_layers = 4
    recipe.model.flux_controlnet_config.guidance_embed = False
    if megatron_fsdp:
        configure_megatron_fsdp(recipe)
    else:
        configure_ddp(recipe)
    recipe.optim.config.lr = 5e-5
    recipe.trainer.plugins = run.Config(
        nl.MegatronMixedPrecision,
        precision="bf16-mixed",
        fp8='hybrid',
        fp8_margin=0,
        fp8_amax_history_len=1024,
        fp8_amax_compute_algo="max",
        fp8_params=False,
    )
    return recipe
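

# The factories above can also be launched programmatically instead of through
# the CLI. A minimal sketch, assuming nemo_run's `run.run` and `run.LocalExecutor`;
# the exact arguments shown here are illustrative and untested:
#
#   recipe = convergence_test(megatron_fsdp=True)
#   run.run(recipe, executor=run.LocalExecutor(), name="flux_controlnet_convergence")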


@run.cli.factory(target=llm.train)
def convergence_tp2() -> run.Partial:
    '''
    A convergence recipe with a real data loader and tensor parallel size 2.
    Image and text embeddings are calculated on the fly.
    '''
    recipe = flux_controlnet_training()
    recipe.model.flux_params.t5_params = run.Config(T5Config, version='/ckpts/text_encoder_2')
    recipe.model.flux_params.clip_params = run.Config(ClipConfig, version='/ckpts/text_encoder')
    recipe.model.flux_params.vae_config = run.Config(
        AutoEncoderConfig, ckpt='/ckpts/ae.safetensors', ch_mult=[1, 2, 4, 4], attn_resolutions=[]
    )
    recipe.model.flux_params.device = 'cuda'
    recipe.model.flux_params.flux_config = run.Config(
        FluxConfig,
        ckpt_path='/ckpts/nemo_dist_ckpt/weights/',
        load_dist_ckpt=True,
        calculate_per_token_loss=False,
        gradient_accumulation_fusion=False,
    )
    recipe.trainer.devices = 2
    recipe.trainer.max_steps = 30000
    recipe.trainer.strategy.tensor_model_parallel_size = 2
    recipe.data = flux_datamodule('/dataset/fill50k/fill50k_tarfiles/')
    recipe.data.global_batch_size = 2
    recipe.model.flux_controlnet_config.num_single_layers = 0
    recipe.model.flux_controlnet_config.num_joint_layers = 4
    return recipe


@run.cli.factory(target=llm.train)
def full_model_tp2_dp4_mock() -> run.Partial:
    '''
    An example recipe using TP=2, DP=4 with the mock dataset.
    '''
    recipe = flux_controlnet_training()
    recipe.model.flux_params.t5_params = None
    recipe.model.flux_params.clip_params = None
    recipe.model.flux_params.vae_config = None
    recipe.model.flux_params.device = 'cuda'
    recipe.trainer.strategy.tensor_model_parallel_size = 2
    recipe.trainer.devices = 8
    recipe.data.global_batch_size = 8
    recipe.trainer.callbacks.append(run.Config(NsysCallback, start_step=10, end_step=11, gen_shape=True))
    recipe.model.flux_controlnet_config.num_single_layers = 10
    recipe.model.flux_controlnet_config.num_joint_layers = 4
    return recipe


@run.cli.factory(target=llm.train)
def unit_test(megatron_fsdp=True) -> run.Partial:
    '''
    Basic functional test with the mock dataset: text/VAE encoders are not
    initialized, the Megatron FSDP or DDP strategy is selected via `megatron_fsdp`,
    and frozen and trainable layer counts are both set to 1.
    '''
    recipe = flux_controlnet_training()
    recipe.model.flux_params.t5_params = None
    recipe.model.flux_params.clip_params = None
    recipe.model.flux_params.vae_config = None
    recipe.model.flux_params.device = 'cuda'
    recipe.model.flux_params.flux_config = run.Config(
        FluxConfig,
        num_joint_layers=1,
        num_single_layers=1,
    )
    recipe.model.flux_controlnet_config.num_single_layers = 1
    recipe.model.flux_controlnet_config.num_joint_layers = 1
    recipe.data.global_batch_size = 1
    if megatron_fsdp:
        configure_megatron_fsdp(recipe)
    else:
        configure_ddp(recipe)
    recipe.trainer.max_steps = 10
    return recipe


def configure_megatron_fsdp(recipe) -> run.Partial:
    """Switch the recipe's data-parallel strategy to Megatron FSDP sharding."""
    recipe.trainer.strategy.ddp = run.Config(
        DistributedDataParallelConfig,
        data_parallel_sharding_strategy='optim_grads_params',  # Custom FSDP
        check_for_nan_in_grad=True,
        grad_reduce_in_fp32=True,
        overlap_param_gather=True,  # Megatron FSDP requires this
        overlap_grad_reduce=True,  # Megatron FSDP requires this
        use_megatron_fsdp=True,
    )
    recipe.trainer.strategy.fsdp = 'megatron'
    return recipe


def configure_ddp(recipe) -> run.Partial:
    """Switch the recipe's data-parallel strategy to plain distributed data parallel."""
    recipe.trainer.strategy.ddp = run.Config(
        DistributedDataParallelConfig,
        check_for_nan_in_grad=True,
        grad_reduce_in_fp32=True,
    )
    recipe.trainer.strategy.fsdp = None
    return recipe


if __name__ == "__main__":
    run.cli.main(llm.train, default_factory=unit_test)
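
# Example invocations (sketches; exact CLI behavior depends on the installed
# nemo_run version). `unit_test` is the default factory, so a bare run uses it;
# other factories, with arguments if needed, can be selected by name:
#
#   python flux_controlnet_training.py
#   python flux_controlnet_training.py --factory convergence_tp2
#   python flux_controlnet_training.py --factory "convergence_test(megatron_fsdp=False)"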