# File: sentinel/tests/test_risk_models/test_base_model_enum_validation.py
# Hosting-page metadata captured with this file:
#   jeuko's picture | Sync from GitHub (main) | 7638cbd verified
"""Tests for enum subset validation in REQUIRED_INPUTS.
This module tests the enhanced validation logic that supports restricting
enum fields to specific subsets using Literal types.
"""
from typing import Any, Literal
from sentinel.risk_models.base import RiskModel
from sentinel.user_input import (
Anthropometrics,
Demographics,
Ethnicity,
Lifestyle,
PersonalMedicalHistory,
Sex,
SmokingHistory,
SmokingStatus,
UserInput,
)
class EnumValidationTestModel(RiskModel):
    """Stub ``RiskModel`` whose inputs are narrowed with ``Literal`` types.

    ``REQUIRED_INPUTS`` limits ``demographics.sex`` to one enum member and
    ``demographics.ethnicity`` to a three-member subset (or ``None``). All
    remaining methods return fixed placeholder values.
    """

    # Dotted field path -> (accepted type, flag).
    # NOTE(review): the boolean is presumably the "required" marker read by
    # ``RiskModel.validate_inputs`` -- confirm against the base class.
    REQUIRED_INPUTS: dict[str, tuple[type | Any, bool]] = {
        # Restriction to exactly one enum member.
        "demographics.sex": (Literal[Sex.FEMALE], True),
        # Restriction to a subset of enum members; ``None`` is also allowed.
        "demographics.ethnicity": (
            Literal[Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN] | None,
            False,
        ),
    }

    def __init__(self) -> None:
        """Pass ``"test_enum_validation"`` to the ``RiskModel`` initialiser."""
        super().__init__("test_enum_validation")

    def compute_score(self, user: UserInput) -> str:
        """Return a constant score; the profile is not inspected.

        Args:
            user: The user profile to score (unused).

        Returns:
            The fixed string ``"test_score"``.
        """
        return "test_score"

    def cancer_type(self) -> str:
        """Return the placeholder string ``"test"``."""
        return "test"

    def description(self) -> str:
        """Return the placeholder string ``"Test model"``."""
        return "Test model"

    def interpretation(self) -> str:
        """Return the placeholder string ``"Test interpretation"``."""
        return "Test interpretation"

    def references(self) -> list[str]:
        """Return a one-element list holding ``"Test reference"``."""
        return ["Test reference"]

    def time_horizon_years(self) -> float | None:
        """Return ``None``; this stub defines no time horizon."""
        return None
class TestEnumSubsetValidation:
    """Check ``validate_inputs`` against ``Literal``-restricted enum fields."""

    def setup_method(self):
        """Create the model under test (pytest runs this before each test)."""
        self.model = EnumValidationTestModel()

    def _create_user_input(
        self, sex: Sex, ethnicity: Ethnicity | None = None
    ) -> UserInput:
        """Build a ``UserInput`` in which only sex and ethnicity vary.

        Args:
            sex: Value stored in ``demographics.sex``.
            ethnicity: Value stored in ``demographics.ethnicity``; defaults
                to ``None``.

        Returns:
            A ``UserInput`` with fixed age, height, weight and a
            never-smoker smoking history.
        """
        body = Anthropometrics(height_cm=165.0, weight_kg=65.0)
        demographics = Demographics(
            age_years=40,
            sex=sex,
            ethnicity=ethnicity,
            anthropometrics=body,
        )
        never_smoker = SmokingHistory(status=SmokingStatus.NEVER)
        return UserInput(
            demographics=demographics,
            lifestyle=Lifestyle(smoking=never_smoker),
            personal_medical_history=PersonalMedicalHistory(),
        )

    def test_single_enum_value_restriction_valid(self):
        """The one permitted ``Sex`` member passes validation."""
        profile = self._create_user_input(Sex.FEMALE, Ethnicity.WHITE)

        ok, problems = self.model.validate_inputs(profile)

        assert ok
        assert len(problems) == 0

    def test_single_enum_value_restriction_invalid(self):
        """A ``Sex`` member outside the ``Literal`` yields one clear error."""
        # The model only accepts Sex.FEMALE.
        profile = self._create_user_input(Sex.MALE, Ethnicity.WHITE)

        ok, problems = self.model.validate_inputs(profile)

        assert not ok
        assert len(problems) == 1
        assert "Field 'demographics.sex': must be FEMALE" in problems[0]

    def test_multiple_enum_value_restriction_valid(self):
        """Every member of the permitted ethnicity subset passes validation."""
        for ethnicity in (Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN):
            profile = self._create_user_input(Sex.FEMALE, ethnicity)
            ok, problems = self.model.validate_inputs(profile)
            assert ok, f"Failed for ethnicity: {ethnicity}"
            assert len(problems) == 0

    def test_multiple_enum_value_restriction_invalid(self):
        """Members outside the ethnicity subset each yield one clear error."""
        rejected = (
            Ethnicity.HISPANIC,
            Ethnicity.ASHKENAZI_JEWISH,
            Ethnicity.NATIVE_AMERICAN,
            Ethnicity.PACIFIC_ISLANDER,
            Ethnicity.OTHER,
            Ethnicity.UNKNOWN,
        )
        # The error text is expected to list every permitted member.
        permitted_names = ("WHITE", "BLACK", "ASIAN")

        for ethnicity in rejected:
            profile = self._create_user_input(Sex.FEMALE, ethnicity)
            ok, problems = self.model.validate_inputs(profile)
            assert not ok, f"Should have failed for ethnicity: {ethnicity}"
            assert len(problems) == 1
            message = problems[0]
            assert "Field 'demographics.ethnicity': Input should be" in message
            assert all(name in message for name in permitted_names)

    def test_optional_enum_field_with_none(self):
        """``None`` is accepted for the ``... | None`` ethnicity field."""
        profile = self._create_user_input(Sex.FEMALE, None)

        ok, problems = self.model.validate_inputs(profile)

        assert ok
        assert len(problems) == 0

    def test_missing_required_enum_field(self):
        """A listed field absent from the user input is reported as missing."""

        class MissingFieldModel(RiskModel):
            """Stub model that lists a field path ``UserInput`` does not have."""

            REQUIRED_INPUTS: dict[str, tuple[Any, bool]] = {
                "demographics.sex": (Literal[Sex.FEMALE], True),
                "demographics.ethnicity": (Ethnicity | None, False),
                # No such attribute exists on the demographics object.
                "demographics.nonexistent_field": (str, True),
            }

            def __init__(self):
                super().__init__("missing_field_test")

            def compute_score(self, user: UserInput) -> str:
                return "test"

            def cancer_type(self) -> str:
                return "test"

            def description(self) -> str:
                return "test"

            def interpretation(self) -> str:
                return "test"

            def references(self) -> list[str]:
                return ["test"]

            def time_horizon_years(self) -> float | None:
                return None

        stub = MissingFieldModel()
        profile = self._create_user_input(Sex.FEMALE, Ethnicity.WHITE)

        ok, problems = stub.validate_inputs(profile)

        assert not ok
        assert len(problems) == 1
        assert (
            "Required field 'demographics.nonexistent_field' is missing"
            in problems[0]
        )

    def test_multiple_validation_errors(self):
        """Two violating fields produce two errors, one per field."""
        # Both sex and ethnicity fall outside their Literal restrictions.
        profile = self._create_user_input(Sex.MALE, Ethnicity.HISPANIC)

        ok, problems = self.model.validate_inputs(profile)

        assert not ok
        assert len(problems) == 2

        # Search the joined text so the order of the two errors is irrelevant.
        combined = " ".join(problems)
        for fragment in ("must be FEMALE", "Input should be"):
            assert fragment in combined
        assert all(name in combined for name in ("WHITE", "BLACK", "ASIAN"))

    def test_literal_enum_type_detection(self):
        """``_is_literal_enum_type`` is truthy only for enum ``Literal`` types."""
        is_literal_enum = self.model._is_literal_enum_type

        # Literal forms with one and with several enum members.
        assert is_literal_enum(Literal[Sex.FEMALE])
        assert is_literal_enum(Literal[Ethnicity.WHITE, Ethnicity.BLACK])

        # A bare enum class and builtin types are not Literal types.
        for plain_type in (Sex, int, str):
            assert not is_literal_enum(plain_type)

    def test_extract_literal_enum_values(self):
        """``_extract_literal_enum_values`` yields the expected value strings."""
        extract = self.model._extract_literal_enum_values

        # Single member: the exact list is checked.
        assert extract(Literal[Sex.FEMALE]) == ["FEMALE"]

        # Several members: compared as a set, so order is not asserted.
        subset = extract(
            Literal[Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN]
        )
        assert set(subset) == {"WHITE", "BLACK", "ASIAN"}

    def test_backward_compatibility_unrestricted_enum(self):
        """Plain (non-``Literal``) enum types accept any member."""

        class UnrestrictedModel(RiskModel):
            """Stub model that declares bare enum types with no restriction."""

            REQUIRED_INPUTS: dict[str, tuple[type | Any, bool]] = {
                "demographics.sex": (Sex, True),
                "demographics.ethnicity": (Ethnicity | None, False),
            }

            def __init__(self):
                super().__init__("unrestricted_test")

            def compute_score(self, user: UserInput) -> str:
                return "test"

            def cancer_type(self) -> str:
                return "test"

            def description(self) -> str:
                return "test"

            def interpretation(self) -> str:
                return "test"

            def references(self) -> list[str]:
                return ["test"]

            def time_horizon_years(self) -> float | None:
                return None

        stub = UnrestrictedModel()
        # These members are rejected by EnumValidationTestModel but must
        # pass here, where no Literal restriction is declared.
        profile = self._create_user_input(Sex.MALE, Ethnicity.HISPANIC)

        ok, problems = stub.validate_inputs(profile)

        assert ok
        assert len(problems) == 0