# subhankarg's picture
# Upload folder using huggingface_hub
# 0558aa4 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch import nn
from nemo.lightning.io.state import StateDictTransform, TransformCTX, state_transform
class TestStateDictTransform:
    """
    Tests for the StateDictTransform functionality.

    Every test builds a ``StateDictTransform`` from a source key pattern, a target key
    pattern and a transform callable, applies it to a mock ``TransformCTX`` and then
    inspects ``target_state``. The state dictionaries hold small integers instead of
    tensors so that each expected value can be derived by hand from the fixture data.
    """

    @pytest.fixture
    def mock_ctx(self):
        """
        Provides a mock transformation context with predefined source and target states.

        The source state has two layers (0 and 1), each with q/k/v attention projections
        and two experts (0 and 1) carrying gate/up/down projections. The target state has
        the corresponding ``decoder.layers.*`` keys, all initialised to ``-1``.

        Returns
        -------
        TransformCTX: A context object with source and target states.
        """
        source_state = {
            "model.layers.0.self_attn.q_proj.weight": 1,
            "model.layers.0.self_attn.k_proj.weight": 2,
            "model.layers.0.self_attn.v_proj.weight": 3,
            "model.layers.0.mlp.experts.0.gate_proj.weight": 4,
            "model.layers.0.mlp.experts.0.up_proj.weight": 5,
            "model.layers.0.mlp.experts.0.down_proj.weight": 8,
            "model.layers.0.mlp.experts.1.gate_proj.weight": 6,
            "model.layers.0.mlp.experts.1.up_proj.weight": 7,
            "model.layers.0.mlp.experts.1.down_proj.weight": 8,
            "model.layers.1.self_attn.q_proj.weight": 2,
            "model.layers.1.self_attn.k_proj.weight": 3,
            "model.layers.1.self_attn.v_proj.weight": 4,
            "model.layers.1.mlp.experts.0.gate_proj.weight": 5,
            "model.layers.1.mlp.experts.0.up_proj.weight": 6,
            "model.layers.1.mlp.experts.0.down_proj.weight": 9,
            "model.layers.1.mlp.experts.1.gate_proj.weight": 7,
            "model.layers.1.mlp.experts.1.up_proj.weight": 8,
            "model.layers.1.mlp.experts.1.down_proj.weight": 9,
        }
        # -1 is a placeholder: no assertion in this class expects -1, so a slot that is
        # never written by a transform makes the corresponding assertion fail.
        target_state = {
            "decoder.layers.0.self_attention.linear_qkv.weight": -1,
            "decoder.layers.0.self_attention.linear_proj.weight": -1,
            "decoder.layers.0.mlp.experts.linear_fc1.weight0": -1,
            "decoder.layers.0.mlp.experts.linear_fc1.weight1": -1,
            "decoder.layers.0.mlp.experts.linear_fc2.weight": -1,
            "decoder.layers.1.self_attention.linear_qkv.weight": -1,
            "decoder.layers.1.self_attention.linear_proj.weight": -1,
            "decoder.layers.1.mlp.experts.linear_fc1.weight0": -1,
            "decoder.layers.1.mlp.experts.linear_fc1.weight1": -1,
            "decoder.layers.1.mlp.experts.linear_fc2.weight": -1,
        }
        ctx = TransformCTX(
            source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state
        )
        return ctx

    @pytest.fixture
    def mock_multi_target_ctx(self):
        """
        Provides a mock transformation context with a source state that matches the expected source_key
        and a target state prepared with initial values for the expected target_keys.

        Layers 0 and 1 carry a single ``mlp.linear_fc1.weight`` entry; layers 2 and 3 carry
        per-expert ``linear_fc1.weight0`` / ``weight1`` entries whose two-digit values are
        split into gate/up digits by the multi-wildcard test.
        """
        source_state = {
            "decoder.layers.0.mlp.linear_fc1.weight": 1,
            "decoder.layers.1.mlp.linear_fc1.weight": 2,
            "decoder.layers.2.mlp.experts.linear_fc1.weight0": 23,
            "decoder.layers.2.mlp.experts.linear_fc1.weight1": 45,
            "decoder.layers.3.mlp.experts.linear_fc1.weight0": 34,
            "decoder.layers.3.mlp.experts.linear_fc1.weight1": 56,
        }
        # Populate target_state with initial placeholder values for keys expected to be matched and updated
        target_state = {
            "model.layers.0.mlp.gate_proj.weight": -1,
            "model.layers.0.mlp.up_proj.weight": -1,
            "model.layers.1.mlp.gate_proj.weight": -1,
            "model.layers.1.mlp.up_proj.weight": -1,
            "model.layers.2.mlp.experts.0.gate_proj.weight": -1,
            "model.layers.2.mlp.experts.0.up_proj.weight": -1,
            "model.layers.2.mlp.experts.1.gate_proj.weight": -1,
            "model.layers.2.mlp.experts.1.up_proj.weight": -1,
            "model.layers.3.mlp.experts.0.gate_proj.weight": -1,
            "model.layers.3.mlp.experts.0.up_proj.weight": -1,
            "model.layers.3.mlp.experts.1.gate_proj.weight": -1,
            "model.layers.3.mlp.experts.1.up_proj.weight": -1,
        }
        ctx = TransformCTX(
            source=nn.Module(), source_state=source_state, target=nn.Module(), target_state=target_state
        )
        return ctx

    def test_transform_with_single_source_single_target(self, mock_ctx):
        """
        Test transformation when a single source and target key is specified.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.mlp.experts.0.down_proj.weight",
            target_key="decoder.layers.*.mlp.experts.linear_fc2.weight",
            transform=lambda ctx, x: x * 100,
        )
        transform(mock_ctx)
        # Expert-0 down_proj is 8 in layer 0 and 9 in layer 1; the transform scales by 100.
        assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc2.weight"] == 800
        assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc2.weight"] == 900

    def test_transform_with_multiple_sources(self, mock_ctx):
        """
        Test transformation when multiple source keys are specified.
        """
        transform = StateDictTransform(
            source_key=(
                "model.layers.*.self_attn.q_proj.weight",
                "model.layers.*.self_attn.k_proj.weight",
                "model.layers.*.self_attn.v_proj.weight",
            ),
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, q, k, v: q * 100 + k * 10 + v,
        )
        transform(mock_ctx)
        # q, k and v are encoded as separate decimal digits, so the expected values also
        # pin which source value reaches which parameter (layer 0: 1, 2, 3; layer 1: 2, 3, 4).
        assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 123
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 234

    def test_transform_with_multiple_mapped_sources(self, mock_ctx):
        """
        Test transformation with a dictionary mapping for source keys.
        """
        transform = StateDictTransform(
            source_key={
                "q": "model.layers.*.self_attn.q_proj.weight",
                "k": "model.layers.*.self_attn.k_proj.weight",
                "v": "model.layers.*.self_attn.v_proj.weight",
            },
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, q, k, v: q * 100 + k * 10 + v,
        )
        transform(mock_ctx)
        # Same digit encoding as the tuple-based test; here the dict keys name the parameters.
        assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 123
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 234

    def test_transform_with_variable_arguments(self, mock_ctx):
        """
        Test transformation with a wildcard pattern and variable arguments.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.self_attn.*_proj.weight",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, *args: sum(args),
        )
        transform(mock_ctx)
        # Summation is order-independent: layer 0 -> 1 + 2 + 3, layer 1 -> 2 + 3 + 4.
        assert mock_ctx.target_state["decoder.layers.0.self_attention.linear_qkv.weight"] == 6
        assert mock_ctx.target_state["decoder.layers.1.self_attention.linear_qkv.weight"] == 9

    def test_transform_with_no_matching_source(self, mock_ctx):
        """
        Test transformation when no source keys match the pattern.
        """
        transform = StateDictTransform(
            source_key="non.existent.pattern",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx, *args: sum(args),
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)

    def test_transform_with_multiple_targets(self, mock_multi_target_ctx):
        """
        Test transformation where the target_key is a tuple and the transform function
        returns multiple values that are then unrolled to these target keys.
        """

        # Define a transformation that splits the input into two parts
        def split_transform(ctx, x):
            return x - 1, x + 1

        # Apply the transformation
        transform = StateDictTransform(
            source_key="decoder.layers.*.mlp.linear_fc1.weight",
            target_key=(
                "model.layers.*.mlp.gate_proj.weight",
                "model.layers.*.mlp.up_proj.weight",
            ),
            transform=split_transform,
        )
        transform(mock_multi_target_ctx)
        # Check that the target state has been updated correctly
        # (source 1 -> (0, 2) for layer 0, source 2 -> (1, 3) for layer 1)
        assert mock_multi_target_ctx.target_state["model.layers.0.mlp.gate_proj.weight"] == 0
        assert mock_multi_target_ctx.target_state["model.layers.0.mlp.up_proj.weight"] == 2
        assert mock_multi_target_ctx.target_state["model.layers.1.mlp.gate_proj.weight"] == 1
        assert mock_multi_target_ctx.target_state["model.layers.1.mlp.up_proj.weight"] == 3

    def test_transform_with_multiple_sources_multiple_wildcards(self, mock_ctx):
        """
        Test transformation when multiple source keys are specified, each with more than 1 wildcard.
        """
        transform = StateDictTransform(
            source_key=(
                "model.layers.*.mlp.experts.*.gate_proj.weight",
                "model.layers.*.mlp.experts.*.up_proj.weight",
            ),
            target_key="decoder.layers.*.mlp.experts.linear_fc1.weight*",
            transform=lambda ctx, gate, up: gate * 10 + up,
        )
        transform(mock_ctx)
        # Each expected value is "<gate><up>" for the matching (layer, expert) pair,
        # e.g. layer 0 / expert 0 has gate=4, up=5 -> 45.
        assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc1.weight0"] == 45
        assert mock_ctx.target_state["decoder.layers.0.mlp.experts.linear_fc1.weight1"] == 67
        assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc1.weight0"] == 56
        assert mock_ctx.target_state["decoder.layers.1.mlp.experts.linear_fc1.weight1"] == 78

    def test_transform_with_multiple_targets_multiple_wildcards(self, mock_multi_target_ctx):
        """
        Test transformation when multiple target keys are specified, each with more than 1 wildcard.
        """

        def split_transform(ctx, x):
            # Split a two-digit value into (tens, units) -> (gate, up)
            return x // 10, x % 10

        transform = StateDictTransform(
            source_key="decoder.layers.*.mlp.experts.linear_fc1.weight*",
            target_key=(
                "model.layers.*.mlp.experts.*.gate_proj.weight",
                "model.layers.*.mlp.experts.*.up_proj.weight",
            ),
            transform=split_transform,
        )
        transform(mock_multi_target_ctx)
        # Layer 2 sources are 23 and 45, layer 3 sources are 34 and 56.
        assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.0.gate_proj.weight"] == 2
        assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.0.up_proj.weight"] == 3
        assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.1.gate_proj.weight"] == 4
        assert mock_multi_target_ctx.target_state["model.layers.2.mlp.experts.1.up_proj.weight"] == 5
        assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.0.gate_proj.weight"] == 3
        assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.0.up_proj.weight"] == 4
        assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.1.gate_proj.weight"] == 5
        assert mock_multi_target_ctx.target_state["model.layers.3.mlp.experts.1.up_proj.weight"] == 6

    def test_transform_with_no_matching_target(self, mock_ctx):
        """
        Test transformation when no target keys match the pattern.
        """
        # The source pattern matches real keys; only the target pattern is bogus.
        transform = StateDictTransform(
            source_key="model.layers.*.mlp.experts.0.down_proj.weight",
            target_key="non.existent.pattern",
            transform=lambda ctx, *args: sum(args),
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)

    def test_transform_with_invalid_transform_function(self, mock_ctx):
        """
        Test transformation with a transform function that does not match expected signature.
        """
        transform = StateDictTransform(
            source_key="model.layers.*.self_attn.q_proj.weight",
            target_key="decoder.layers.*.self_attention.linear_qkv.weight",
            transform=lambda ctx: 0,  # Invalid signature: no parameter to receive the source value
        )
        with pytest.raises(ValueError):
            transform(mock_ctx)
class TestStateTransformDecorator:
    """
    Exercises transforms that were created with the @state_transform decorator.
    """

    @pytest.fixture
    def mock_ctx(self):
        """
        Build a context with layer-1 q/k/v source values and zero-valued target entries.
        """
        projections = {
            'model.layers.1.self_attn.q_proj.weight': 1,
            'model.layers.1.self_attn.k_proj.weight': 2,
            'model.layers.1.self_attn.v_proj.weight': 3,
        }
        # Target entries start at 0; the tests below expect the transforms to replace them
        placeholders = {
            "decoder.layers.1.self_attention.linear_q.weight": 0,
            "decoder.layers.1.self_attention.linear_k.weight": 0,
            "decoder.layers.1.self_attention.linear_v.weight": 0,
        }
        return TransformCTX(
            source=nn.Module(),
            source_state=projections,
            target=nn.Module(),
            target_state=placeholders,
        )

    def test_single_transform(self, mock_ctx):
        """
        A decorated transform with one source key and one target key writes its result.
        """
        single_transform(mock_ctx)

        # q_proj holds 1 and the transform adds 10
        updated = mock_ctx.target_state
        assert updated["decoder.layers.1.self_attention.linear_q.weight"] == 11

    def test_multiple_outputs_transform(self, mock_ctx):
        """
        A decorated transform with a wildcard source key fills several target keys.
        """
        multiple_outputs_transform(mock_ctx)

        # Each target entry must now carry the value of its matching source projection
        expected_by_key = {
            "decoder.layers.1.self_attention.linear_q.weight": 1,
            "decoder.layers.1.self_attention.linear_k.weight": 2,
            "decoder.layers.1.self_attention.linear_v.weight": 3,
        }
        for target_name, expected_value in expected_by_key.items():
            assert mock_ctx.target_state[target_name] == expected_value
@state_transform(
    source_key="model.layers.*.self_attn.q_proj.weight",
    target_key="decoder.layers.1.self_attention.linear_q.weight",
)
def single_transform(ctx, x):
    """
    Return the given source value increased by 10.

    Registered through ``state_transform`` with the q_proj source pattern and the
    layer-1 ``linear_q`` target key.
    """
    shifted = x + 10
    return shifted
@state_transform(
    source_key="model.layers.1.self_attn.*_proj.weight",
    target_key=(
        "decoder.layers.1.self_attention.linear_*.weight",
    ),
)
def multiple_outputs_transform(ctx, *args):
    """
    Hand back every received source value unchanged, as a tuple, one per target key.
    """
    passthrough = args
    return passthrough