test_database.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import tempfile
  2. import unittest
  3. from pathlib import Path
  4. import sqlite3
  5. import pandas as pd
  6. from defect_analysis.database import (
  7. create_import_batch,
  8. init_database,
  9. insert_defects,
  10. list_import_batches,
  11. load_defects,
  12. )
  13. from defect_analysis.schemas import normalize_defect_schema
  14. class DatabaseTest(unittest.TestCase):
  15. def setUp(self):
  16. self.tmpdir = tempfile.TemporaryDirectory()
  17. self.db_path = Path(self.tmpdir.name) / "defects.db"
  18. self.df = normalize_defect_schema(
  19. pd.DataFrame(
  20. {
  21. "defect_id": ["D1", "D2"],
  22. "panel_id": ["P1", "P2"],
  23. "batch_id": ["B1", "B1"],
  24. "equipment_id": ["LAM-A01", "LAM-A01"],
  25. "seat_id": ["R1C1", "R1C2"],
  26. "inspection_station": ["AOI-1", "AOI-1"],
  27. "timestamp": [pd.Timestamp("2026-04-01 08:00:00"), pd.Timestamp("2026-04-01 09:00:00")],
  28. "defect_type": ["划痕", "气泡"],
  29. "severity": ["严重", "轻微"],
  30. "x_mm": [10.0, 20.0],
  31. "y_mm": [30.0, 40.0],
  32. "panel_width_mm": [155.0, 155.0],
  33. "panel_height_mm": [340.0, 340.0],
  34. "hour": [8, 9],
  35. "shift": ["白班", "白班"],
  36. "day": ["2026-04-01", "2026-04-01"],
  37. "lam_fixture_id": ["FIX-1", "FIX-2"],
  38. }
  39. )
  40. )
  41. def tearDown(self):
  42. self.tmpdir.cleanup()
  43. def test_init_database_creates_required_tables(self):
  44. init_database(self.db_path)
  45. batches = list_import_batches(self.db_path)
  46. self.assertEqual([], batches.to_dict("records"))
  47. def test_insert_and_load_defects_round_trips_schema(self):
  48. init_database(self.db_path)
  49. import_id = create_import_batch(self.db_path, source_name="unit-test.csv", row_count=len(self.df))
  50. inserted = insert_defects(self.db_path, self.df, import_id=import_id)
  51. loaded = load_defects(self.db_path)
  52. self.assertEqual(2, inserted)
  53. self.assertEqual(["D1", "D2"], loaded["defect_id"].tolist())
  54. self.assertIn("lam_fixture_id", loaded.columns)
  55. self.assertEqual("FIX-1", loaded.loc[0, "lam_fixture_id"])
  56. self.assertEqual(import_id, loaded.loc[0, "import_id"])
  57. def test_insert_defects_is_idempotent_by_defect_id(self):
  58. init_database(self.db_path)
  59. import_id = create_import_batch(self.db_path, source_name="unit-test.csv", row_count=len(self.df))
  60. first = insert_defects(self.db_path, self.df, import_id=import_id)
  61. second = insert_defects(self.db_path, self.df, import_id=import_id)
  62. self.assertEqual(2, first)
  63. self.assertEqual(0, second)
  64. self.assertEqual(2, len(load_defects(self.db_path)))
  65. def test_foreign_keys_are_enforced(self):
  66. init_database(self.db_path)
  67. with self.assertRaises(sqlite3.IntegrityError):
  68. insert_defects(self.db_path, self.df, import_id=999)
  69. if __name__ == "__main__":
  70. unittest.main()