"""Unit tests for the SQLite persistence helpers in ``defect_analysis.database``."""

import sqlite3
import tempfile
import unittest
from pathlib import Path

import pandas as pd

from defect_analysis.database import (
    create_import_batch,
    init_database,
    insert_defects,
    list_import_batches,
    load_defects,
)
from defect_analysis.schemas import normalize_defect_schema


class DatabaseTest(unittest.TestCase):
    """Exercise schema creation, round-tripping, idempotency and FK enforcement.

    Every test works against a fresh database file inside a temporary
    directory, so tests are isolated from each other.
    """

    def setUp(self):
        """Create a temp database path and a two-row normalized defect frame."""
        self.tmpdir = tempfile.TemporaryDirectory()
        # Register cleanup immediately: unlike tearDown, cleanups still run
        # when a later statement in setUp raises, so the directory never leaks.
        self.addCleanup(self.tmpdir.cleanup)
        self.db_path = Path(self.tmpdir.name) / "defects.db"
        # The categorical values below are Chinese labels kept verbatim as test
        # data. In English: defect_type = scratch / bubble, severity =
        # severe / minor, shift = day shift.
        self.df = normalize_defect_schema(
            pd.DataFrame(
                {
                    "defect_id": ["D1", "D2"],
                    "panel_id": ["P1", "P2"],
                    "batch_id": ["B1", "B1"],
                    "equipment_id": ["LAM-A01", "LAM-A01"],
                    "seat_id": ["R1C1", "R1C2"],
                    "inspection_station": ["AOI-1", "AOI-1"],
                    "timestamp": [
                        pd.Timestamp("2026-04-01 08:00:00"),
                        pd.Timestamp("2026-04-01 09:00:00"),
                    ],
                    "defect_type": ["划痕", "气泡"],
                    "severity": ["严重", "轻微"],
                    "x_mm": [10.0, 20.0],
                    "y_mm": [30.0, 40.0],
                    "panel_width_mm": [155.0, 155.0],
                    "panel_height_mm": [340.0, 340.0],
                    "hour": [8, 9],
                    "shift": ["白班", "白班"],
                    "day": ["2026-04-01", "2026-04-01"],
                    "lam_fixture_id": ["FIX-1", "FIX-2"],
                }
            )
        )

    def test_init_database_creates_required_tables(self):
        """A freshly initialized database lists no import batches."""
        init_database(self.db_path)

        batches = list_import_batches(self.db_path)

        self.assertEqual([], batches.to_dict("records"))

    def test_insert_and_load_defects_round_trips_schema(self):
        """Inserted rows come back with their columns and import_id intact."""
        init_database(self.db_path)
        import_id = create_import_batch(
            self.db_path, source_name="unit-test.csv", row_count=len(self.df)
        )

        inserted = insert_defects(self.db_path, self.df, import_id=import_id)
        loaded = load_defects(self.db_path)

        self.assertEqual(2, inserted)
        self.assertEqual(["D1", "D2"], loaded["defect_id"].tolist())
        self.assertIn("lam_fixture_id", loaded.columns)
        # Positional access (iloc) matches the positional defect_id check above
        # and does not depend on load_defects returning a default RangeIndex.
        self.assertEqual("FIX-1", loaded["lam_fixture_id"].iloc[0])
        self.assertEqual(import_id, loaded["import_id"].iloc[0])

    def test_insert_defects_is_idempotent_by_defect_id(self):
        """Re-inserting the same defect_ids inserts nothing and adds no rows."""
        init_database(self.db_path)
        import_id = create_import_batch(
            self.db_path, source_name="unit-test.csv", row_count=len(self.df)
        )

        first = insert_defects(self.db_path, self.df, import_id=import_id)
        second = insert_defects(self.db_path, self.df, import_id=import_id)

        self.assertEqual(2, first)
        self.assertEqual(0, second)
        self.assertEqual(2, len(load_defects(self.db_path)))

    def test_foreign_keys_are_enforced(self):
        """Inserting defects for a non-existent import batch raises IntegrityError."""
        init_database(self.db_path)

        # No batch with id 999 was created, so the FK reference must be rejected.
        with self.assertRaises(sqlite3.IntegrityError):
            insert_defects(self.db_path, self.df, import_id=999)


if __name__ == "__main__":
    unittest.main()