Spaces:
Runtime error
Runtime error
File size: 10,830 Bytes
8018595 3fc6f6d 8018595 3fc6f6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 |
"""Tests for the documentation generation script."""
import pytest
from scripts.generate_documentation import (
_normalise_cancer_label,
_unique_qcancer_sites,
build_field_usage_map,
cancer_types_for_model,
discover_risk_models,
extract_field_attributes,
extract_model_requirements,
format_field_path,
gather_spec_details,
group_fields_by_requirements,
prettify_field_name,
traverse_user_input_structure,
)
class TestUtilityFunctions:
    """Unit tests for the small helper functions of the documentation generator."""

    def test_prettify_field_name(self):
        """Snake-case names become title-cased words and a ``[]`` suffix is dropped."""
        cases = (
            ("female_specific", "Female Specific"),
            ("family_history[]", "Family History"),
            ("age_years", "Age Years"),
            ("test", "Test"),
        )
        for raw_name, expected in cases:
            assert prettify_field_name(raw_name) == expected

    def test_format_field_path(self):
        """Dotted paths render as a heading plus indented child lines."""
        cases = (
            ("demographics.age_years", "Demographics\n - Age Years"),
            ("family_history[].relation", "Family History\n - Relation"),
            ("simple_field", "Simple Field"),
        )
        for dotted_path, rendered in cases:
            assert format_field_path(dotted_path) == rendered

    def test_normalise_cancer_label(self):
        """Cancer labels are reduced to a capitalised site name."""
        cases = (
            ("Lung Cancer", "Lung"),
            ("breast-cancer", "Breast"),
            ("colorectal_cancer", "Colorectal"),
            ("Prostate", "Prostate"),
        )
        for label, expected in cases:
            assert _normalise_cancer_label(label) == expected

    def test_unique_qcancer_sites(self):
        """QCancer sites form a non-empty list of already-normalised labels."""
        qcancer_sites = _unique_qcancer_sites()
        assert isinstance(qcancer_sites, list)
        assert qcancer_sites

        # Normalised labels carry neither the word "cancer" nor separators.
        for site_label in qcancer_sites:
            assert "cancer" not in site_label.lower()
            assert not any(sep in site_label for sep in ("_", "-"))

    def test_cancer_types_for_model(self):
        """Cancer-type labels are derived from a model's ``cancer_type()``."""

        class MockModel:
            """Minimal stand-in exposing ``name`` and ``cancer_type()``."""

            def __init__(self, name, cancer_type):
                """Store the model name and the cancer type to report.

                Args:
                    name: Model name.
                    cancer_type: Cancer type string returned by ``cancer_type()``.
                """
                self.name = name
                self._cancer_type = cancer_type

            def cancer_type(self):
                """Report the configured cancer type.

                Returns:
                    str: Cancer type string.
                """
                return self._cancer_type

        # A single-cancer model maps to exactly one normalised label.
        assert cancer_types_for_model(MockModel("gail", "breast")) == ["Breast"]

        # The "qcancer" model is expected to yield a non-empty list of labels.
        multi_labels = cancer_types_for_model(MockModel("qcancer", "multiple"))
        assert isinstance(multi_labels, list)
        assert multi_labels

    def test_group_fields_by_requirements(self):
        """Fields are grouped under a prettified heading of their top-level section."""
        requirements = [
            ("demographics.age_years", int, True),
            ("demographics.sex", str, True),
            ("family_history.relation", str, False),
            ("family_history.cancer_type", str, False),
        ]
        grouped = group_fields_by_requirements(requirements)
        assert len(grouped) == 2

        # Each expected heading must exist and hold both of its fields.
        for heading, expected_count in (("Demographics", 2), ("Family History", 2)):
            match = next((grp for grp in grouped if grp[0] == heading), None)
            assert match is not None
            assert len(match[1]) == expected_count

    def test_gather_spec_details_regular(self):
        """A free-text note passes through and other columns use defaults."""
        note_text, required_text, unit_text, range_text = gather_spec_details(
            None, None, "Test note"
        )
        assert note_text == "Test note"
        assert (required_text, unit_text, range_text) == ("Optional", "-", "-")

    def test_gather_spec_details_clinical_observation(self):
        """A known clinical observation note is expanded and its range extracted."""
        note_text, required_text, unit_text, range_text = gather_spec_details(
            None, None, "multivitamin - Yes/No"
        )
        assert "Multivitamin usage status" in note_text
        assert (required_text, unit_text, range_text) == ("Optional", "-", "Yes/No")

    def test_gather_spec_details_unknown_observation(self):
        """An unrecognised observation note gets a generic description."""
        note_text, required_text, unit_text, range_text = gather_spec_details(
            None, None, "unknown_obs - Some values"
        )
        assert "Clinical observation: unknown_obs" in note_text
        assert (required_text, unit_text) == ("Optional", "-")
        assert range_text == "Some values"
class TestMainFunctionality:
    """End-to-end checks of the documentation generator's entry points."""

    def test_discover_risk_models(self):
        """Discovery returns a non-empty list of fully-described risk models."""
        discovered = discover_risk_models()
        assert isinstance(discovered, list)
        assert len(discovered) > 0

        # Every discovered model must expose the attributes the docs render.
        expected_attributes = (
            "name",
            "cancer_type",
            "description",
            "interpretation",
            "references",
        )
        for risk_model in discovered:
            for attribute in expected_attributes:
                assert hasattr(risk_model, attribute)

    def test_main_function_import(self):
        """The script's ``main`` entry point is importable and callable."""
        from scripts.generate_documentation import main as entry_point

        assert callable(entry_point)
class TestEdgeCases:
    """Test edge cases and error handling."""

    @staticmethod
    def _assert_placeholder_spec_details(details):
        """Assert that ``details`` is the all-default spec-detail result.

        Leading underscore keeps pytest from collecting this helper as a test.

        Args:
            details: Return value of ``gather_spec_details``; it must unpack
                into exactly four items (note, required, unit, range).
        """
        note_text, required_text, unit_text, range_text = details
        assert note_text == "-"
        assert required_text == "Optional"
        assert unit_text == "-"
        assert range_text == "-"

    def test_empty_field_grouping(self):
        """Test field grouping with empty input."""
        grouped = group_fields_by_requirements([])
        assert grouped == []

    def test_single_segment_path(self):
        """Test field path formatting with single segment."""
        result = format_field_path("single_field")
        assert result == "Single Field"

    def test_empty_cancer_label(self):
        """Test cancer label normalization with empty input."""
        result = _normalise_cancer_label("")
        assert result == ""

    def test_none_cancer_label(self):
        """Test that cancer label normalization rejects None input.

        This pins the current behaviour: passing None raises AttributeError
        (presumably because a ``str`` method is called on it) instead of
        returning a value.
        """
        with pytest.raises(AttributeError):
            _normalise_cancer_label(None)

    def test_gather_spec_details_none_inputs(self):
        """Test spec details gathering when the first two arguments are None."""
        # NOTE(review): this passes the same arguments as
        # test_gather_spec_details_empty_note, so the two tests are redundant.
        # Consider passing None as the note once gather_spec_details is
        # confirmed to accept it.
        self._assert_placeholder_spec_details(gather_spec_details(None, None, ""))

    def test_gather_spec_details_empty_note(self):
        """Test spec details gathering with an empty note string."""
        self._assert_placeholder_spec_details(gather_spec_details(None, None, ""))
class TestUserInputStructureExtraction:
    """Checks for walking the UserInput model and deriving field metadata."""

    def test_traverse_user_input_structure(self):
        """Traversal yields (path, name, model_class) triples for nodes and leaves."""
        from sentinel.user_input import UserInput

        entries = traverse_user_input_structure(UserInput)
        assert isinstance(entries, list)
        assert len(entries) > 0

        # Nested models carry a model class in slot 2; leaf fields carry None.
        assert any(entry[2] is not None for entry in entries)
        assert any(entry[2] is None for entry in entries)

        for dotted_path, display_name, nested_model in entries:
            assert isinstance(dotted_path, str)
            assert isinstance(display_name, str)
            assert nested_model is None or hasattr(nested_model, "model_fields")

    def test_extract_model_requirements(self):
        """Requirements are (field_path, field_type, is_required) triples."""
        from sentinel.risk_models.gail import GailRiskModel

        reqs = extract_model_requirements(GailRiskModel())
        assert isinstance(reqs, list)
        assert len(reqs) > 0

        for path_str, declared_type, required_flag in reqs:
            assert isinstance(path_str, str)
            # declared_type may be an Annotated type, so only require presence.
            assert declared_type is not None
            assert isinstance(required_flag, bool)

    def test_build_field_usage_map(self):
        """The usage map links each field path to (model_name, is_required) pairs."""
        from sentinel.risk_models.claus import ClausRiskModel
        from sentinel.risk_models.gail import GailRiskModel

        usage = build_field_usage_map([GailRiskModel(), ClausRiskModel()])
        assert isinstance(usage, dict)
        assert len(usage) > 0

        for path_key, users in usage.items():
            assert isinstance(path_key, str)
            assert isinstance(users, list)
            for user_model, required_flag in users:
                assert isinstance(user_model, str)
                assert isinstance(required_flag, bool)

    def test_extract_field_attributes(self):
        """Attribute extraction returns four strings plus an optional enum class."""
        from sentinel.user_input import UserInput

        demographics_info = UserInput.model_fields["demographics"]
        attributes = extract_field_attributes(
            demographics_info, demographics_info.annotation
        )
        description, examples, constraints, used_by, enum_class = attributes

        for text_value in (description, examples, constraints, used_by):
            assert isinstance(text_value, str)
        assert enum_class is None or isinstance(enum_class, type)
|