Spaces:
Running
Running
| """Unit tests for dataset record helpers in scripts.download_data.""" | |
from __future__ import annotations

import importlib.util
import sys
import unittest
from pathlib import Path
from typing import Any, Dict, Iterator, List, cast
# Locate scripts/download_data.py relative to this test file.
# NOTE(review): parents[2] assumes this file sits two directories below the
# project root (e.g. tests/<subdir>/this_file.py) — confirm against the repo.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
DOWNLOAD_SCRIPT = PROJECT_ROOT / "scripts" / "download_data.py"

# Load the script straight from its file path so the test does not depend on
# `scripts` being importable as a package.
spec = importlib.util.spec_from_file_location("download_data", DOWNLOAD_SCRIPT)
if spec is None or spec.loader is None:
    raise RuntimeError("Unable to load scripts/download_data.py for testing")
download_data = importlib.util.module_from_spec(spec)
# Register the module before executing it, as the importlib recipe does.
# Code that resolves its own module through sys.modules while importing
# (e.g. @dataclass with postponed annotations, pickle) breaks otherwise.
sys.modules[spec.name] = download_data
try:
    spec.loader.exec_module(download_data)
except BaseException:
    # Do not leave a half-initialised module cached for later lookups.
    sys.modules.pop(spec.name, None)
    raise
class DummyDataset:
    """Lightweight stand-in for a dataset split: an iterable of record dicts.

    Only iteration is implemented. Presumably that is all the record helpers
    need from a split; verify if new helpers are tested with it.
    """

    def __init__(self, records: List[Dict[str, object]]) -> None:
        # Hold a reference to the caller's list (no copy is made).
        self._records = records

    def __iter__(self) -> Iterator[Dict[str, object]]:
        # Every call begins a fresh pass, so the split can be re-iterated.
        yield from self._records
class DownloadDataRecordTests(unittest.TestCase):
    """Tests for how the record helpers turn raw ``label`` values into names."""

    @staticmethod
    def _collect(
        helper: Any,
        rows: List[Dict[str, object]],
        label_names: List[str],
    ) -> List[Any]:
        """Run ``helper`` over an in-memory split built from ``rows``.

        The helper's output is materialised so each record can be inspected
        by index. ``cast(Any, ...)`` keeps type checkers from rejecting the
        ``DummyDataset`` stand-in in place of a real dataset split.
        """
        return list(helper(cast(Any, DummyDataset(rows)), label_names))

    def test_emotion_records_handles_out_of_range_labels(self) -> None:
        rows: List[Dict[str, object]] = [
            {"text": "sample", "label": 1},
            {"text": "multi", "label": [0, 5]},
            {"text": "string", "label": "2"},
        ]
        label_names = ["sadness", "joy", "love"]

        records = self._collect(download_data._emotion_records, rows, label_names)

        # One output record per input row. Checked first so a dropped or extra
        # record fails with a clear message rather than an IndexError below.
        self.assertEqual(len(records), len(rows))
        # An in-range integer index maps to its label name.
        self.assertEqual(records[0]["emotions"], ["joy"])
        # List entries are mapped one by one; the out-of-range index 5 falls
        # back to its string representation.
        self.assertEqual(records[1]["emotions"], ["sadness", "5"])
        # A string label is passed through unchanged, even though "2" would be
        # a valid index ("love") if it were converted to int.
        self.assertEqual(records[2]["emotions"], ["2"])

    def test_topic_records_handles_varied_label_inputs(self) -> None:
        rows: List[Dict[str, object]] = [
            {"text": "news", "label": 3},
            {"text": "list", "label": [1]},
            {"text": "unknown", "label": "5"},
            {"text": "missing", "label": []},
        ]
        label_names = ["World", "Sports", "Business", "Sci/Tech"]

        records = self._collect(download_data._topic_records, rows, label_names)

        # One output record per input row (see the emotion test for why).
        self.assertEqual(len(records), len(rows))
        # An integer index maps to its label name.
        self.assertEqual(records[0]["topic"], "Sci/Tech")
        # A list label [1] resolves to the name of its element.
        self.assertEqual(records[1]["topic"], "Sports")
        # Out-of-range string label falls back to the original string value.
        # NOTE(review): "5" is both a string and out of range, so this case
        # cannot show whether in-range strings (e.g. "2") are converted.
        self.assertEqual(records[2]["topic"], "5")
        # An empty label list yields an empty topic string.
        self.assertEqual(records[3]["topic"], "")
if __name__ == "__main__":
    # Executing this file directly loads the TestCase classes defined in this
    # module and runs them (unittest.main exits with the run's status).
    unittest.main(module="__main__")