Spaces:
Runtime error
Runtime error
| """Tests for enum subset validation in REQUIRED_INPUTS. | |
| This module tests the enhanced validation logic that supports restricting | |
| enum fields to specific subsets using Literal types. | |
| """ | |
| from typing import Any, Literal | |
| from sentinel.risk_models.base import RiskModel | |
| from sentinel.user_input import ( | |
| Anthropometrics, | |
| Demographics, | |
| Ethnicity, | |
| Lifestyle, | |
| PersonalMedicalHistory, | |
| Sex, | |
| SmokingHistory, | |
| SmokingStatus, | |
| UserInput, | |
| ) | |
class EnumValidationTestModel(RiskModel):
    """Stub risk model whose REQUIRED_INPUTS restrict enum fields via ``Literal``.

    Two restriction patterns are declared:

    * ``demographics.sex`` is required and limited to the single member
      ``Sex.FEMALE``.
    * ``demographics.ethnicity`` is not required and limited to the subset
      ``WHITE``/``BLACK``/``ASIAN`` of ``Ethnicity``, with ``None`` also
      accepted through the ``| None`` union.

    The scoring and metadata hooks return fixed placeholder values; only the
    input validation inherited from ``RiskModel`` is exercised by the tests.
    """

    # Values are ``(expected type, required flag)`` pairs.
    REQUIRED_INPUTS: dict[str, tuple[type | Any, bool]] = {
        # Restriction to exactly one enum member.
        "demographics.sex": (Literal[Sex.FEMALE], True),
        # Restriction to a subset of enum members; ``None`` is also allowed.
        "demographics.ethnicity": (
            Literal[Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN] | None,
            False,
        ),
    }

    def __init__(self):
        """Initialise the base model under the name ``test_enum_validation``."""
        super().__init__("test_enum_validation")

    def compute_score(self, user: UserInput) -> str:
        """Return a constant placeholder score.

        Args:
            user: The user profile to score (not inspected).

        Returns:
            Always the string ``"test_score"``.
        """
        return "test_score"

    def cancer_type(self) -> str:
        """Return the placeholder cancer type ``"test"``."""
        return "test"

    def description(self) -> str:
        """Return the placeholder description ``"Test model"``."""
        return "Test model"

    def interpretation(self) -> str:
        """Return the placeholder interpretation text."""
        return "Test interpretation"

    def references(self) -> list[str]:
        """Return a single placeholder reference."""
        return ["Test reference"]

    def time_horizon_years(self) -> float | None:
        """Return ``None``: the stub declares no time horizon."""
        return None
class TestEnumSubsetValidation:
    """Test enum subset validation functionality."""

    # Mirrors the Literal subset declared for ``demographics.ethnicity`` in
    # EnumValidationTestModel.REQUIRED_INPUTS; keep the two in sync.
    _ALLOWED_ETHNICITIES = (Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN)

    def setup_method(self):
        """Create a fresh EnumValidationTestModel before each test."""
        self.model = EnumValidationTestModel()

    @staticmethod
    def _build_stub_model(
        model_name: str,
        required_inputs: dict[str, tuple[type | Any, bool]],
    ) -> RiskModel:
        """Instantiate a minimal concrete RiskModel with the given requirements.

        Only ``REQUIRED_INPUTS`` varies between the ad-hoc models used by these
        tests, so the remaining hooks return fixed placeholder values.

        Args:
            model_name: Name passed to ``RiskModel.__init__``.
            required_inputs: Mapping installed as the model's ``REQUIRED_INPUTS``.

        Returns:
            An instance of a freshly defined stub model class.
        """

        class _StubModel(RiskModel):
            """Placeholder risk model used only for input-validation tests."""

            REQUIRED_INPUTS: dict[str, tuple[type | Any, bool]] = required_inputs

            def __init__(self):
                super().__init__(model_name)

            def compute_score(self, user: UserInput) -> str:
                return "test"

            def cancer_type(self) -> str:
                return "test"

            def description(self) -> str:
                return "test"

            def interpretation(self) -> str:
                return "test"

            def references(self) -> list[str]:
                return ["test"]

            def time_horizon_years(self) -> float | None:
                return None

        return _StubModel()

    def _create_user_input(
        self, sex: Sex, ethnicity: Ethnicity | None = None
    ) -> UserInput:
        """Create a valid UserInput instance for testing.

        Args:
            sex: The biological sex for the user.
            ethnicity: The ethnicity for the user (optional).

        Returns:
            A valid UserInput instance for testing.
        """
        return UserInput(
            demographics=Demographics(
                age_years=40,
                sex=sex,
                ethnicity=ethnicity,
                anthropometrics=Anthropometrics(height_cm=165.0, weight_kg=65.0),
            ),
            lifestyle=Lifestyle(
                smoking=SmokingHistory(status=SmokingStatus.NEVER),
            ),
            personal_medical_history=PersonalMedicalHistory(),
        )

    def test_single_enum_value_restriction_valid(self):
        """Test that valid single enum value passes validation."""
        user = self._create_user_input(Sex.FEMALE, Ethnicity.WHITE)
        is_valid, errors = self.model.validate_inputs(user)
        assert is_valid
        assert len(errors) == 0

    def test_single_enum_value_restriction_invalid(self):
        """Test that invalid single enum value fails validation with clear message."""
        user = self._create_user_input(Sex.MALE, Ethnicity.WHITE)  # Should be FEMALE
        is_valid, errors = self.model.validate_inputs(user)
        assert not is_valid
        assert len(errors) == 1
        assert "Field 'demographics.sex': must be FEMALE" in errors[0]

    def test_multiple_enum_value_restriction_valid(self):
        """Test that valid enum values from subset pass validation."""
        for ethnicity in self._ALLOWED_ETHNICITIES:
            user = self._create_user_input(Sex.FEMALE, ethnicity)
            is_valid, errors = self.model.validate_inputs(user)
            assert is_valid, f"Failed for ethnicity: {ethnicity}"
            assert len(errors) == 0, f"Unexpected errors for {ethnicity}: {errors}"

    def test_multiple_enum_value_restriction_invalid(self):
        """Test that every enum value outside the subset fails with a clear message."""
        # Derive the rejected members from the enum itself so members added to
        # Ethnicity later are covered without editing this test.
        invalid_ethnicities = [
            ethnicity
            for ethnicity in Ethnicity
            if ethnicity not in self._ALLOWED_ETHNICITIES
        ]
        assert invalid_ethnicities, "Expected Ethnicity members outside the subset"
        allowed_names = [member.name for member in self._ALLOWED_ETHNICITIES]

        for ethnicity in invalid_ethnicities:
            user = self._create_user_input(Sex.FEMALE, ethnicity)
            is_valid, errors = self.model.validate_inputs(user)
            assert not is_valid, f"Should have failed for ethnicity: {ethnicity}"
            assert len(errors) == 1, f"Unexpected errors for {ethnicity}: {errors}"
            assert "Field 'demographics.ethnicity': Input should be" in errors[0]
            assert all(name in errors[0] for name in allowed_names), errors[0]

    def test_optional_enum_field_with_none(self):
        """Test that None values are handled correctly for optional enum fields."""
        user = self._create_user_input(Sex.FEMALE, None)  # Optional field
        is_valid, errors = self.model.validate_inputs(user)
        assert is_valid
        assert len(errors) == 0

    def test_missing_required_enum_field(self):
        """Test that a missing required field is reported next to enum checks.

        The model restricts ``demographics.sex`` to FEMALE and additionally
        requires a (non-enum) field that does not exist; with otherwise valid
        input, only the missing field should be reported.
        """
        model = self._build_stub_model(
            "missing_field_test",
            {
                "demographics.sex": (Literal[Sex.FEMALE], True),
                "demographics.ethnicity": (Ethnicity | None, False),
                # This field doesn't exist on the user input.
                "demographics.nonexistent_field": (str, True),
            },
        )
        user = self._create_user_input(Sex.FEMALE, Ethnicity.WHITE)
        is_valid, errors = model.validate_inputs(user)
        assert not is_valid
        assert len(errors) == 1
        assert "Required field 'demographics.nonexistent_field' is missing" in errors[0]

    def test_multiple_validation_errors(self):
        """Test that multiple validation errors are reported."""
        user = self._create_user_input(Sex.MALE, Ethnicity.HISPANIC)  # Both wrong
        is_valid, errors = self.model.validate_inputs(user)
        assert not is_valid
        assert len(errors) == 2
        # Check that both errors are present
        error_messages = " ".join(errors)
        assert "must be FEMALE" in error_messages
        assert "Input should be" in error_messages
        assert all(
            member.name in error_messages for member in self._ALLOWED_ETHNICITIES
        ), error_messages

    def test_literal_enum_type_detection(self):
        """Test the _is_literal_enum_type helper method."""
        # Test Literal with single enum value
        single_literal = Literal[Sex.FEMALE]
        assert self.model._is_literal_enum_type(single_literal)
        # Test Literal with multiple enum values
        multi_literal = Literal[Ethnicity.WHITE, Ethnicity.BLACK]
        assert self.model._is_literal_enum_type(multi_literal)
        # Test non-Literal types
        assert not self.model._is_literal_enum_type(Sex)
        assert not self.model._is_literal_enum_type(int)
        assert not self.model._is_literal_enum_type(str)

    def test_extract_literal_enum_values(self):
        """Test the _extract_literal_enum_values helper method."""
        # Test single enum value
        single_literal = Literal[Sex.FEMALE]
        values = self.model._extract_literal_enum_values(single_literal)
        assert values == ["FEMALE"]
        # Test multiple enum values
        multi_literal = Literal[Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN]
        values = self.model._extract_literal_enum_values(multi_literal)
        assert set(values) == {"WHITE", "BLACK", "ASIAN"}

    def test_backward_compatibility_unrestricted_enum(self):
        """Test that unrestricted enum types still work (backward compatibility)."""
        model = self._build_stub_model(
            "unrestricted_test",
            {
                "demographics.sex": (Sex, True),
                "demographics.ethnicity": (Ethnicity | None, False),
            },
        )
        # Any member of an unrestricted enum should be accepted.
        user = self._create_user_input(Sex.MALE, Ethnicity.HISPANIC)
        is_valid, errors = model.validate_inputs(user)
        assert is_valid
        assert len(errors) == 0