# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib from unittest.mock import MagicMock from lightning.pytorch.callbacks import Callback as PTLCallback from nemo.lightning.base_callback import BaseCallback def _fresh_group_module(): """Reset the CallbackGroup singleton and stub OneLoggerNeMoCallback safely. This avoids deleting modules from sys.modules. We import the module, replace the OneLoggerNeMoCallback symbol with a lightweight stub, and reset the internal singleton so a new instance is built. """ mod = importlib.import_module('nemo.lightning.callback_group') class _StubOneLoggerCallback(BaseCallback): def __init__(self, *args, **kwargs): pass def update_config(self, *args, **kwargs): pass setattr(mod, 'OneLoggerNeMoCallback', _StubOneLoggerCallback) # Reset the singleton so the next get_instance() uses the stubbed class mod.CallbackGroup._instance = None return mod def test_base_callback_noops_do_not_raise(): """Test BaseCallback hooks are no-ops and do not raise exceptions.""" cb = BaseCallback() cb.on_app_start() cb.on_app_end() cb.on_model_init_start() cb.on_model_init_end() cb.on_dataloader_init_start() cb.on_dataloader_init_end() cb.on_optimizer_init_start() cb.on_optimizer_init_end() cb.on_load_checkpoint_start() cb.on_load_checkpoint_end() cb.on_save_checkpoint_start() cb.on_save_checkpoint_end() cb.on_save_checkpoint_success() cb.update_config() def test_base_callback_is_ptl_callback(): """Test BaseCallback derives from Lightning PTL Callback.""" assert isinstance(BaseCallback(), PTLCallback) def test_callback_group_singleton_identity(): """Test CallbackGroup returns the same singleton instance.""" mod = _fresh_group_module() a = mod.CallbackGroup.get_instance() b = mod.CallbackGroup.get_instance() assert a is b def test_callback_group_update_config_fanout_and_attach(monkeypatch): """Test update_config fans out to callbacks and attaches them to trainer.""" mod = _fresh_group_module() group = mod.CallbackGroup.get_instance() class _StubCallback(BaseCallback): def __init__(self): self.called = False self.kwargs = None def update_config(self, *args, **kwargs): self.called = True self.kwargs = kwargs stub_cb = _StubCallback() group._callbacks = [stub_cb] class Trainer: def __init__(self): self.callbacks = [] trainer = Trainer() marker = object() group.update_config('v2', trainer, data=marker) assert stub_cb.called kwargs = stub_cb.kwargs assert kwargs['nemo_version'] == 'v2' assert kwargs['trainer'] is trainer assert kwargs['data'] is marker assert trainer.callbacks[0] is stub_cb def test_callback_group_dynamic_dispatch_calls_when_present(): """Test dynamic dispatch calls methods when present on callbacks.""" mod = _fresh_group_module() group = mod.CallbackGroup.get_instance() mock_cb = MagicMock() group._callbacks = [mock_cb] group.on_app_start() assert mock_cb.on_app_start.called def test_callback_group_dynamic_dispatch_ignores_missing_methods(): """Test dynamic dispatch ignores missing methods without raising.""" mod = _fresh_group_module() group = mod.CallbackGroup.get_instance() class Dummy: pass group._callbacks = [Dummy()] # Should not raise even if method not present group.on_nonexistent_method() def test_hook_class_init_with_callbacks_wraps_and_emits(monkeypatch): """Test inheritance-based hook via __init_subclass__ emits start/end once (e2e-style).""" mod = _fresh_group_module() group = mod.CallbackGroup.get_instance() start = MagicMock() end = MagicMock() monkeypatch.setattr(group, 'on_model_init_start', start) monkeypatch.setattr(group, 'on_model_init_end', end) class Base: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) # Mirror IOMixin: hook subclasses at definition time mod.hook_class_init_with_callbacks(cls, 'on_model_init_start', 'on_model_init_end') class Child(Base): def __init__(self): self.x = 1 class GrandChild(Child): def __init__(self): self.y = 2 super().__init__() c = Child() assert c.x == 1 # Flag indicating wrapping applied on the subclass assert getattr(Child.__init__, '_init_wrapped_for_callbacks', False) is True d = GrandChild() assert d.x == 1 assert d.y == 2 assert start.call_count == 2 assert end.call_count == 2 # Flag indicating wrapping applied on the subclass assert getattr(GrandChild.__init__, '_init_wrapped_for_callbacks', False) is True def test_hook_class_init_with_callbacks_idempotent(): """Test inheritance-based hook is idempotent and does not re-wrap on repeated calls.""" mod = _fresh_group_module() class Base: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) mod.hook_class_init_with_callbacks(cls, 'on_model_init_start', 'on_model_init_end') class Child(Base): def __init__(self): pass # Hook was applied via __init_subclass__ at class creation time first = Child.__init__ # Attempt to apply again explicitly; should be a no-op mod.hook_class_init_with_callbacks(Child, 'on_model_init_start', 'on_model_init_end') second = Child.__init__ assert first is second def test_on_app_end_is_idempotent(monkeypatch): """Test on_app_end fans out only once even if called multiple times.""" mod = _fresh_group_module() group = mod.CallbackGroup.get_instance() mock_cb = MagicMock() group._callbacks = [mock_cb] group.on_app_end() group.on_app_end() assert mock_cb.on_app_end.call_count == 1