test_database.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import tempfile
  2. import unittest
  3. from pathlib import Path
  4. import pandas as pd
  5. from defect_analysis.database import (
  6. create_import_batch,
  7. init_database,
  8. insert_defects,
  9. list_import_batches,
  10. load_defects,
  11. )
  12. from defect_analysis.schemas import normalize_defect_schema
  class DatabaseTest(unittest.TestCase):
      """Tests for the database helpers in ``defect_analysis.database``.

      Each test gets its own temporary directory and database file, so no state
      is shared between tests.
      """

      def setUp(self):
          """Create an isolated database path and a normalized two-row defect frame."""
          self.tmpdir = tempfile.TemporaryDirectory()
          # Register cleanup immediately: cleanups run even when the rest of setUp
          # raises (tearDown would not), so the temporary directory never leaks.
          self.addCleanup(self.tmpdir.cleanup)
          self.db_path = Path(self.tmpdir.name) / "defects.db"
          self.df = normalize_defect_schema(
              pd.DataFrame(
                  {
                      "defect_id": ["D1", "D2"],
                      "panel_id": ["P1", "P2"],
                      "batch_id": ["B1", "B1"],
                      "equipment_id": ["LAM-A01", "LAM-A01"],
                      "seat_id": ["R1C1", "R1C2"],
                      "inspection_station": ["AOI-1", "AOI-1"],
                      "timestamp": [
                          pd.Timestamp("2026-04-01 08:00:00"),
                          pd.Timestamp("2026-04-01 09:00:00"),
                      ],
                      # "scratch", "bubble"
                      "defect_type": ["划痕", "气泡"],
                      # "severe", "minor"
                      "severity": ["严重", "轻微"],
                      "x_mm": [10.0, 20.0],
                      "y_mm": [30.0, 40.0],
                      "panel_width_mm": [155.0, 155.0],
                      "panel_height_mm": [340.0, 340.0],
                      "hour": [8, 9],
                      # "day shift"
                      "shift": ["白班", "白班"],
                      "day": ["2026-04-01", "2026-04-01"],
                      "lam_fixture_id": ["FIX-1", "FIX-2"],
                  }
              )
          )

      def _create_batch(self):
          """Register an import batch sized for ``self.df`` and return its id."""
          return create_import_batch(
              self.db_path, source_name="unit-test.csv", row_count=len(self.df)
          )

      def test_init_database_creates_required_tables(self):
          """A freshly initialized database has queryable, empty tables."""
          init_database(self.db_path)
          batches = list_import_batches(self.db_path)
          self.assertEqual([], batches.to_dict("records"))
          # Also read the defects table so this test covers more than the batch
          # table its name implies.
          self.assertEqual(0, len(load_defects(self.db_path)))

      def test_insert_and_load_defects_round_trips_schema(self):
          """Inserted defects load back with their ids, fixture ids and import id."""
          init_database(self.db_path)
          import_id = self._create_batch()
          inserted = insert_defects(self.db_path, self.df, import_id=import_id)
          loaded = load_defects(self.db_path)
          self.assertEqual(2, inserted)
          self.assertEqual(["D1", "D2"], loaded["defect_id"].tolist())
          self.assertIn("lam_fixture_id", loaded.columns)
          self.assertEqual("FIX-1", loaded.loc[0, "lam_fixture_id"])
          self.assertEqual(import_id, loaded.loc[0, "import_id"])

      def test_insert_defects_is_idempotent_by_defect_id(self):
          """Re-inserting the same defect ids adds no rows."""
          init_database(self.db_path)
          import_id = self._create_batch()
          first = insert_defects(self.db_path, self.df, import_id=import_id)
          # Same frame, same defect ids: the second insert must be a no-op.
          second = insert_defects(self.db_path, self.df, import_id=import_id)
          self.assertEqual(2, first)
          self.assertEqual(0, second)
          self.assertEqual(2, len(load_defects(self.db_path)))
  if __name__ == "__main__":
      # Lets the module run directly (``python test_database.py``) as well as
      # being collected by a test runner.
      unittest.main()