# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from torch import nn

from nemo.lightning.io.state import StateDictTransform, TransformCTX, state_transform


class TestStateDictTransform:
    """
    Tests for the StateDictTransform functionality.
    """

    @pytest.fixture
    def mock_ctx(self):
        """
        Provides a mock transformation context with predefined source and target states.

        Returns
        -------
        TransformCTX: A context object with source and target states.
        """
        source_state = {
            "model.layers.0.self_attn.q_proj.weight": 1,
            "model.layers.0.self_attn.k_proj.weight": 2,
            "model.layers.0.self_attn.v_proj.weight": 3,
            "model.layers.0.mlp.experts.0.gate_proj.weight": 4,
            "model.layers.0.mlp.experts.0.up_proj.weight": 5,
            "model.layers.0.mlp.experts.0.down_proj.weight": 8,
            "model.layers.0.mlp.experts.1.gate_proj.weight": 6,
            "model.layers.0.mlp.experts.1.up_proj.weight": 7,
            "model.layers.0.mlp.experts.1.down_proj.weight": 8,
            "model.layers.1.self_attn.q_proj.weight": 2,
            "model.layers.1.self_attn.k_proj.weight": 3,
            "model.layers.1.self_attn.v_proj.weight": 4,
            "model.layers.1.mlp.experts.0.gate_proj.weight": 5,
            "model.layers.1.mlp.experts.0.up_proj.weight": 6,
            "model.layers.1.mlp.experts.0.down_proj.weight": 9,
            "model.layers.1.mlp.experts.1.gate_proj.weight": 7,
            "model.layers.1.mlp.experts.1.up_proj.weight": 8,
            "model.layers.1.mlp.experts.1.down_proj.weight": 9,
        }
        target_state = {
            "decoder.layers.0.self_attention.linear_qkv.weight": -1,
            "decoder.layers.0.self_attention.linear_proj.weight": -1,
            "decoder.layers.0.mlp.experts.linear_fc1.weight0": -1,
            "decoder.layers.0.mlp.experts.linear_fc1.weight1": -1,
            "decoder.layers.0.mlp.experts.linear_fc2.weight": -1,
            "decoder.layers.1.self_attention.linear_qkv.weight": -1,
            "decoder.layers.1.self_attention.linear_proj.weight": -1,
            "decoder.layers.1.mlp.experts.linear_fc1.weight0": -1,
            "decoder.layers.1.mlp.experts.linear_fc1.weight1": -1,
            "decoder.layers.1.mlp.experts.linear_fc2.weight": -1,
        }
        ctx = TransformCTX(
            source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state
        )
        return ctx

    @pytest.fixture
    def mock_multi_target_ctx(self):
        """
        Provides a mock transformation context with a source state that matches the expected source_key
        and a target state prepared with initial values for the expected target_keys.
        """
        source_state = {
            "decoder.layers.0.mlp.linear_fc1.weight": 1,
            "decoder.layers.1.mlp.linear_fc1.weight": 2,
            "decoder.layers.2.mlp.experts.linear_fc1.weight0": 23,
            "decoder.layers.2.mlp.experts.linear_fc1.weight1": 45,
            "decoder.layers.3.mlp.experts.linear_fc1.weight0": 34,
            "decoder.layers.3.mlp.experts.linear_fc1.weight1": 56,
        }
        # Populate target_state with initial placeholder values for keys expected to be matched and updated
        target_state = {
            "model.layers.0.mlp.gate_proj.weight": -1,
            "model.layers.0.mlp.up_proj.weight": -1,
            "model.layers.1.mlp.gate_proj.weight": -1,
            "model.layers.1.mlp.up_proj.weight": -1,
            "model.layers.2.mlp.experts.0.gate_proj.weight": -1,
            "model.layers.2.mlp.experts.0.up_proj.weight": -1,
            "model.layers.2.mlp.experts.1.gate_proj.weight": -1,
            "model.layers.2.mlp.experts.1.up_proj.weight": -1,
            "model.layers.3.mlp.experts.0.gate_proj.weight": -1,
            "model.layers.3.mlp.experts.0.up_proj.weight": -1,
            "model.layers.3.mlp.experts.1.gate_proj.weight": -1,
            "model.layers.3.mlp.experts.1.up_proj.weight": -1,
        }
        ctx = TransformCTX(
            source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state
        )
        return ctx

    def test_transform_with_single_source_single_target(self, mock_ctx):
        """
        Test transformation when a single source and target key is specified.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.mlp.experts.0.down_proj.weight",
            target_key="decoder.layers.*.mlp.experts.linear_fc2.weight",
            transform=lambda ctx, x: x * 100,
        )
        transform(mock_ctx)
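        # layer 0: experts.0.down_proj == 8 -> 800; layer 1: experts.0.down_proj == 9 -> 900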
| assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc2.weight"] == 800 | |
| assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc2.weight"] == 900 | |

    def test_transform_with_multiple_sources(self, mock_ctx):
        """
        Test transformation when multiple source keys are specified.
        """
        transform = StateDictTransform(
            source_key=(
                "model.layers.*.self_attn.q_proj.weight",
                "model.layers.*.self_attn.k_proj.weight",
                "model.layers.*.self_attn.v_proj.weight",
            ),
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, q, k, v: q * 100 + k * 10 + v,
        )
        transform(mock_ctx)
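        # layer 0: q=1, k=2, v=3 -> 123; layer 1: q=2, k=3, v=4 -> 234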
| assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 123 | |
| assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 234 | |

    def test_transform_with_multiple_mapped_sources(self, mock_ctx):
        """
        Test transformation with a dictionary mapping for source keys.
        """
        transform = StateDictTransform(
            source_key={
                "q": "model.layers.*.self_attn.q_proj.weight",
                "k": "model.layers.*.self_attn.k_proj.weight",
                "v": "model.layers.*.self_attn.v_proj.weight",
            },
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, q, k, v: q * 100 + k * 10 + v,
        )
        transform(mock_ctx)
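        # The dictionary keys line up with the transform's q/k/v parameters, so the
        # expected values match test_transform_with_multiple_sources above.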
| assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 123 | |
| assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 234 | |

    def test_transform_with_variable_arguments(self, mock_ctx):
        """
        Test transformation with a wildcard pattern and variable arguments.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.self_attn.*_proj.weight",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, *args: sum(args),
        )
        transform(mock_ctx)
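        # sum of q, k, v per layer: layer 0 -> 1 + 2 + 3 = 6; layer 1 -> 2 + 3 + 4 = 9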
| assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 6 | |
| assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 9 | |

    def test_transform_with_no_matching_source(self, mock_ctx):
        """
        Test transformation when no source keys match the pattern.
        """
        transform = StateDictTransform(
            source_key="non.existent.pattern",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, *args: sum(args),
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)

    def test_transform_with_multiple_targets(self, mock_multi_target_ctx):
        """
        Test transformation where the target_key is a tuple and the transform function
        returns multiple values that are then unrolled to these target keys.
        """

        # Define a transformation that splits the input into two parts
        def split_transform(ctx, x):
            return x - 1, x + 1

        # Apply the transformation
        transform = StateDictTransform(
            source_key="decoder.layers.*.mlp.linear_fc1.weight",
            target_key=(
                "model.layers.*.mlp.gate_proj.weight",
                "model.layers.*.mlp.up_proj.weight",
            ),
            transform=split_transform,
        )
        transform(mock_multi_target_ctx)
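        # split_transform maps x -> (x - 1, x + 1): layer 0 (x=1) -> (0, 2); layer 1 (x=2) -> (1, 3)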
        # Check that the target state has been updated correctly
        assert mock_multi_target_ctx.target_state["model.layers.0.mlp.gate_proj.weight"] == 0
        assert mock_multi_target_ctx.target_state["model.layers.0.mlp.up_proj.weight"] == 2
        assert mock_multi_target_ctx.target_state["model.layers.1.mlp.gate_proj.weight"] == 1
        assert mock_multi_target_ctx.target_state["model.layers.1.mlp.up_proj.weight"] == 3

    def test_transform_with_multiple_sources_multiple_wildcards(self, mock_ctx):
        """
        Test transformation when multiple source keys are specified, each with more than 1 wildcard.
        """
        transform = StateDictTransform(
            source_key=(
                "model.layers.*.mlp.experts.*.gate_proj.weight",
                "model.layers.*.mlp.experts.*.up_proj.weight",
            ),
            target_key="decoder.layers.*.mlp.experts.linear_fc1.weight*",
            transform=lambda ctx, gate, up: gate * 10 + up,
        )
        transform(mock_ctx)
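        # gate * 10 + up per (layer, expert): layer 0 -> 45 and 67; layer 1 -> 56 and 78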
| assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc1.weight0"] == 45 | |
| assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc1.weight1"] == 67 | |
| assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc1.weight0"] == 56 | |
| assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc1.weight1"] == 78 | |

    def test_transform_with_multiple_targets_multiple_wildcards(self, mock_multi_target_ctx):
        """
        Test transformation when multiple target keys are specified, each with more than 1 wildcard.
        """

        def split_transform(ctx, x):
            return x // 10, x % 10

        transform = StateDictTransform(
            source_key="decoder.layers.*.mlp.experts.linear_fc1.weight*",
            target_key=(
                "model.layers.*.mlp.experts.*.gate_proj.weight",
                "model.layers.*.mlp.experts.*.up_proj.weight",
            ),
            transform=split_transform,
        )
        transform(mock_multi_target_ctx)
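        # split_transform maps x -> (x // 10, x % 10): 23 -> (2, 3), 45 -> (4, 5), 34 -> (3, 4), 56 -> (5, 6)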
| assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.0.gate_proj.weight"] == 2 | |
| assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.0.up_proj.weight"] == 3 | |
| assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.1.gate_proj.weight"] == 4 | |
| assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.1.up_proj.weight"] == 5 | |
| assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.0.gate_proj.weight"] == 3 | |
| assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.0.up_proj.weight"] == 4 | |
| assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.1.gate_proj.weight"] == 5 | |
| assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.1.up_proj.weight"] == 6 | |

    def test_transform_with_no_matching_target(self, mock_ctx):
        """
        Test transformation when no target keys match the pattern.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.mlp.experts.0.down_proj.weight",
            target_key="non.existent.pattern",
            transform=lambda ctx, *args: sum(args),
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)

    def test_transform_with_invalid_transform_function(self, mock_ctx):
        """
        Test transformation with a transform function that does not match the expected signature.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.self_attn.q_proj.weight",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx: 0,  # Invalid signature: accepts no value arguments
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)


class TestStateTransformDecorator:
    """
    Tests for the @state_transform decorator functionality.
    """

    @pytest.fixture
    def mock_ctx(self):
        """
        Provides a mock transformation context with predefined source and target states.
        """
        source_state = {
            "model.layers.1.self_attn.q_proj.weight": 1,
            "model.layers.1.self_attn.k_proj.weight": 2,
            "model.layers.1.self_attn.v_proj.weight": 3,
        }
        # Pre-populate target_state with initial values or placeholders
        target_state = {
            "decoder.layers.1.self_attention.linear_q.weight": 0,
            "decoder.layers.1.self_attention.linear_k.weight": 0,
            "decoder.layers.1.self_attention.linear_v.weight": 0,
        }
        ctx = TransformCTX(
            source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state
        )
        return ctx

    def test_single_transform(self, mock_ctx):
        """
        Test the @state_transform decorator with a single source and target key.
        """
        # Apply the transformation
        single_transform(mock_ctx)
        # Verify the target state is updated correctly
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_q.weight"] == 11

    def test_multiple_outputs_transform(self, mock_ctx):
        """
        Test the @state_transform decorator with a single source key and multiple target keys.
        """
        # Apply the transformation
        multiple_outputs_transform(mock_ctx)
        # Verify the target state is updated correctly for each key
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_q.weight"] == 1
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_k.weight"] == 2
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_v.weight"] == 3
def single_transform(ctx, x):
    """
    A single transformation function that adds 10 to the input value.
    """
    return x + 10
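

# Assumed keys: one wildcard source fanning out to the three literal q/k/v target keys.
@state_transform(
    source_key="model.layers.1.self_attn.*_proj.weight",
    target_key=(
        "decoder.layers.1.self_attention.linear_q.weight",
        "decoder.layers.1.self_attention.linear_k.weight",
        "decoder.layers.1.self_attention.linear_v.weight",
    ),
)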
def multiple_outputs_transform(ctx, *args):
    """
    A transformation function that returns multiple values for multiple target keys.
    """
    return args