# File: LexiMind/tests/test_data/test_download_records.py
# Commit ee1a8a3 (OliverPerrin): "Style: Apply ruff formatting"
"""Unit tests for dataset record helpers in scripts.download_data."""
from __future__ import annotations
import importlib.util
import unittest
from pathlib import Path
from typing import Any, Dict, Iterator, List, cast
# parents[2] is the directory two levels above the one containing this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
DOWNLOAD_SCRIPT = PROJECT_ROOT.joinpath("scripts", "download_data.py")

# Build an import spec for the script from its file path, under the module
# name "download_data".
spec = importlib.util.spec_from_file_location(
    name="download_data",
    location=DOWNLOAD_SCRIPT,
)
if spec is None or spec.loader is None:
    # Without a spec and a loader the module under test cannot be executed.
    raise RuntimeError("Unable to load scripts/download_data.py for testing")

download_data = importlib.util.module_from_spec(spec)
# Run the script's top-level code so its helpers become module attributes.
spec.loader.exec_module(download_data)
class DummyDataset:
    """Bare-bones iterable wrapper around a list of record dictionaries.

    Used in place of a real dataset split: the only capability it offers is
    iteration over the records it was constructed with.
    """

    def __init__(self, records: List[Dict[str, object]]) -> None:
        """Store *records* by reference; the list is not copied."""
        self._records = records

    def __iter__(self) -> Iterator[Dict[str, object]]:
        """Return a fresh iterator over the stored records."""
        record_iterator = iter(self._records)
        return record_iterator
class DownloadDataRecordTests(unittest.TestCase):
    """Checks how the ``download_data`` record helpers turn labels into names."""

    def test_emotion_records_handles_out_of_range_labels(self) -> None:
        # Feed one in-range int label, one list containing an out-of-range
        # index, and one string label through ``_emotion_records``.
        rows: List[Dict[str, object]] = [
            {"text": "sample", "label": 1},
            {"text": "multi", "label": [0, 5]},
            {"text": "string", "label": "2"},
        ]
        split = DummyDataset(rows)
        names = ["sadness", "joy", "love"]

        records = list(download_data._emotion_records(cast(Any, split), names))

        expected_emotions = [
            # In-range index is replaced by its label name
            ["joy"],
            # Out-of-range index falls back to string representation
            ["sadness", "5"],
            # Non-int values fall back to string
            ["2"],
        ]
        # Checked in input order; the first mismatch fails the test.
        for position, emotions in enumerate(expected_emotions):
            self.assertEqual(records[position]["emotions"], emotions)

    def test_topic_records_handles_varied_label_inputs(self) -> None:
        # Cover an int label, a single-element list, an out-of-range string
        # and an empty list passed through ``_topic_records``.
        rows: List[Dict[str, object]] = [
            {"text": "news", "label": 3},
            {"text": "list", "label": [1]},
            {"text": "unknown", "label": "5"},
            {"text": "missing", "label": []},
        ]
        split = DummyDataset(rows)
        names = ["World", "Sports", "Business", "Sci/Tech"]

        records = list(download_data._topic_records(cast(Any, split), names))

        expected_topics = [
            # In-range int index is replaced by its label name
            "Sci/Tech",
            # Single-element list resolves to that element's label name
            "Sports",
            # Out-of-range string label falls back to original string value
            "5",
            # Empty list yields empty string
            "",
        ]
        # Checked in input order; the first mismatch fails the test.
        for position, topic in enumerate(expected_topics):
            self.assertEqual(records[position]["topic"], topic)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()