# Source code for camcops_server.cc_modules.tests.merge_db_tests

"""
camcops_server/cc_modules/tests/merge_db_tests.py

===============================================================================

    Copyright (C) 2012, University of Cambridge, Department of Psychiatry.
    Created by Rudolf Cardinal (rnc1001@cam.ac.uk).

    This file is part of CamCOPS.

    CamCOPS is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    CamCOPS is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with CamCOPS. If not, see <https://www.gnu.org/licenses/>.

===============================================================================
"""

from configparser import ConfigParser
from dataclasses import dataclass, replace
from io import StringIO
import logging
import os.path
from tempfile import TemporaryDirectory
from typing import Any, Dict
from unittest import TestCase

from cardinal_pythonlib.classes import all_subclasses
from cardinal_pythonlib.httpconst import MimeType
from cardinal_pythonlib.sqlalchemy.session import make_sqlite_url
from pendulum import now as pendulum_now
from sqlalchemy.orm.session import Session, sessionmaker

from camcops_server.cc_modules.cc_alembic import (
    create_database_from_scratch,
    upgrade_database_to_head,
)
from camcops_server.cc_modules.cc_blob import Blob
from camcops_server.cc_modules.cc_config import CamcopsConfig, get_demo_config
from camcops_server.cc_modules.cc_constants import (
    CONFIG_FILE_SITE_SECTION,
    ConfigParamSite,
    ERA_NOW,
)
from camcops_server.cc_modules.cc_group import Group
from camcops_server.cc_modules.cc_idnumdef import IdNumDefinition
from camcops_server.cc_modules.cc_patient import Patient
from camcops_server.cc_modules.cc_patientidnum import PatientIdNum
from camcops_server.cc_modules.cc_task import Task, TaskHasPatientMixin
from camcops_server.cc_modules.cc_testfactories import (
    AnyIdNumGroupFactory,
    BaseFactory,
    DeviceFactory,
    IpUseFactory,
    UserFactory,
)
from camcops_server.cc_modules.merge_db import merge_camcops_db
from camcops_server.cc_modules.cc_unittest import DEMO_PNG_BYTES
from camcops_server.tasks.photo import Photo

log = logging.getLogger(__name__)


# =============================================================================
# Unit testing
# =============================================================================


@dataclass
class MergeTestDbInfo:
    """
    Identifiers for the reference data (ID number definitions and group)
    held in one of the two test databases.

    One instance describes the source database and another the destination
    database; the tests pair the corresponding fields of the two instances
    to build the "which_idnum" and group-number maps passed to the merge
    function.
    """

    # ID numbers:
    nhs_which_idnum: int  # which_idnum of the "NHS number" ID definition
    nhs_desc: str  # its description
    nhs_shortdesc: str  # its short description
    rio_which_idnum: int  # which_idnum of the "RiO number" ID definition
    rio_desc: str  # its description
    rio_shortdesc: str  # its short description

    # Groups:
    groupnum: int  # primary key (id) of the group


class MergeDbTests(TestCase):
    """
    Tests for merging one CamCOPS database into another.

    Each test gets two fresh SQLite databases in a temporary directory: a
    populated "source" database and a "destination" database containing only
    ID number definitions and a group.
    """

    @staticmethod
    def _get_config(dburl: str) -> CamcopsConfig:
        """
        Build a config from the demo config text, with its database URL
        replaced.

        Args:
            dburl: SQLAlchemy URL to store as the site's database URL.

        Returns:
            a :class:`CamcopsConfig` created from the modified config text
            (no config file is involved; the filename is blank).
        """
        base_cfg_txt = get_demo_config()
        parser = ConfigParser()
        parser.read_string(base_cfg_txt)
        parser.set(CONFIG_FILE_SITE_SECTION, ConfigParamSite.DB_URL, dburl)
        with StringIO() as buffer:
            parser.write(buffer)
            # Must be read before the buffer is closed.
            config_text = buffer.getvalue()
        return CamcopsConfig(config_filename="", config_text=config_text)

    @staticmethod
    def _populate_src_db(session: Session) -> MergeTestDbInfo:
        """
        Insert linked test data into the source database.

        Creates a group, a device, a user, two ID number definitions, two
        patients (one with one ID number, one with two), and two instances
        of every task class (plus a BLOB each for the Photo task).

        Args:
            session: SQLAlchemy session bound to the source database.

        Returns:
            a :class:`MergeTestDbInfo` describing the ID number definitions
            and group that were created.
        """
        # I have failed to achieve this with the test factory classes in
        # full, e.g. NHSIdNumDefinitionFactory, PatientFactory, etc. The
        # ordering is currently such that multiple spurious groups seem to
        # get generated.
        log.info("Inserting data into source database...")

        # NOTE(review): this rebinds the session on every factory class and
        # is not undone afterwards; presumably other tests set their own
        # session before using the factories -- confirm.
        for factory in all_subclasses(BaseFactory):
            factory._meta.sqlalchemy_session = session

        # Objects created via the factory classes will be auto-added to the
        # session. But manually add groups and ID number definitions.

        # Group
        group = Group(
            id=777,
            name="src_group",
            description="source group",
            upload_policy=AnyIdNumGroupFactory.upload_policy,
            finalize_policy=AnyIdNumGroupFactory.finalize_policy,
            ip_use=IpUseFactory(),
        )
        session.add(group)

        # Device
        device = DeviceFactory()

        # Demo user
        user = UserFactory()

        # All sorts of things need this (see GenericTabletRecordMixin):
        now = pendulum_now()
        default_tablet_args = dict(
            _adding_user=user,
            _current=True,
            _device=device,
            _era=ERA_NOW,
            _group=group,
            _group_id=group.id,
            _when_added_batch_utc=now,
            _when_added_exact=now,
        )

        # ID definitions
        nhs_iddef = IdNumDefinition(
            which_idnum=1000,
            short_description="NHS#",
            description="NHS number",
        )
        session.add(nhs_iddef)
        rio_iddef = IdNumDefinition(
            which_idnum=1001,
            short_description="RiO",
            description="RiO number",
        )
        session.add(rio_iddef)

        # Patients with ID numbers
        patient_with_one_idnum = Patient(
            id=1,
            forename="Forename1",
            surname="Surname1",
            sex="M",
            **default_tablet_args,
        )
        session.add(patient_with_one_idnum)
        session.add(
            PatientIdNum(
                id=1,
                patient_id=patient_with_one_idnum.id,
                which_idnum=nhs_iddef.which_idnum,
                idnum_value=555,
                **default_tablet_args,
            )
        )
        patient_with_two_idnums = Patient(
            id=2,
            forename="Forename2",
            surname="Surname2",
            sex="F",
            **default_tablet_args,
        )
        session.add(patient_with_two_idnums)
        session.add(
            PatientIdNum(
                id=2,
                patient_id=patient_with_two_idnums.id,
                which_idnum=nhs_iddef.which_idnum,
                idnum_value=666,
                **default_tablet_args,
            )
        )
        session.add(
            PatientIdNum(
                id=3,
                patient_id=patient_with_two_idnums.id,
                which_idnum=rio_iddef.which_idnum,
                idnum_value=222,
                **default_tablet_args,
            )
        )
        session.commit()  # just in case

        blobargs = dict(
            tablename=Photo.tablename,
            fieldname="photo_blobid",
            filename="some_picture.png",
            mimetype=MimeType.PNG,
            image_rotation_deg_cw=0,
            theblob=DEMO_PNG_BYTES,
            **default_tablet_args,
        )

        default_task_args = default_tablet_args.copy()
        default_task_args.update(when_created=now)

        for cls in Task.all_subclasses_by_tablename():
            # Make one task from each class for both test patients.
            # Beware the factory classes, because they are making new groups.
            t1_kwargs: Dict[str, Any] = default_task_args.copy()
            t2_kwargs: Dict[str, Any] = default_task_args.copy()
            t1_kwargs.update(id=1)
            t2_kwargs.update(id=2)
            if issubclass(cls, TaskHasPatientMixin):
                t1_kwargs.update(patient_id=patient_with_one_idnum.id)
                t2_kwargs.update(patient_id=patient_with_two_idnums.id)

            # Only the Photo task gets BLOBs attached here.
            is_photo = cls is Photo
            if is_photo:
                # We'll set tablepk=0 temporarily, then fix it in a moment.
                blob1 = Blob(id=1, tablepk=0, **blobargs)
                session.add(blob1)
                blob2 = Blob(id=2, tablepk=0, **blobargs)
                session.add(blob2)

            task1 = cls(**t1_kwargs)
            task2 = cls(**t2_kwargs)
            session.add(task1)
            session.add(task2)

            if is_photo:
                # Cross-link each task with its BLOB.
                task1.photo_blobid = blob1.id
                task2.photo_blobid = blob2.id
                blob1.tablepk = task1.id
                blob2.tablepk = task2.id

        session.commit()  # just in case

        log.info("... source data inserted")
        return MergeTestDbInfo(
            nhs_which_idnum=nhs_iddef.which_idnum,
            nhs_desc=nhs_iddef.description,
            nhs_shortdesc=nhs_iddef.short_description,
            rio_which_idnum=rio_iddef.which_idnum,
            rio_desc=rio_iddef.description,
            rio_shortdesc=rio_iddef.short_description,
            groupnum=group.id,
        )

    @staticmethod
    def _populate_dst_db(
        session: Session, srcinfo: MergeTestDbInfo
    ) -> MergeTestDbInfo:
        """
        Insert only ID number definitions and a group into the destination
        database.

        The definitions reuse the source descriptions but are given
        different ``which_idnum`` values, and the group a different ID, so
        that the merge has to apply its mappings.

        Args:
            session: SQLAlchemy session bound to the destination database.
            srcinfo: details of the reference data in the source database.

        Returns:
            a :class:`MergeTestDbInfo` describing the destination's ID
            number definitions and group.
        """
        log.info("Inserting basic data only into destination database...")
        dstinfo = replace(
            srcinfo, nhs_which_idnum=1, rio_which_idnum=2, groupnum=9
        )
        assert dstinfo.nhs_which_idnum != dstinfo.rio_which_idnum  # Hacky...
        dst_iddef_nhs = IdNumDefinition(
            which_idnum=dstinfo.nhs_which_idnum,
            description=dstinfo.nhs_desc,
            short_description=dstinfo.nhs_shortdesc,
        )
        session.add(dst_iddef_nhs)
        dst_iddef_rio = IdNumDefinition(
            which_idnum=dstinfo.rio_which_idnum,
            description=dstinfo.rio_desc,
            short_description=dstinfo.rio_shortdesc,
        )
        session.add(dst_iddef_rio)
        dst_group = Group(
            id=dstinfo.groupnum,
            name="dest_group",
            description="destination group",
        )
        session.add(dst_group)
        session.commit()
        log.info("... destination data inserted")
        return dstinfo

    def setUp(self) -> None:
        """
        Create and populate the source and destination test databases.

        Resources created here (temporary directory, database sessions) are
        released via :meth:`addCleanup`, so they are also released if this
        method fails part-way through.
        """
        # ---------------------------------------------------------------------
        # Basic setup
        # ---------------------------------------------------------------------
        # Create database URLs
        self.tempdir = TemporaryDirectory()
        # Cleanups run in reverse order of registration, so the directory is
        # removed last, after the sessions registered below are closed.
        self.addCleanup(self.tempdir.cleanup)
        self.dbfilename1 = os.path.join(self.tempdir.name, "camcops_test_1.db")
        self.dbfilename2 = os.path.join(self.tempdir.name, "camcops_test_2.db")
        self.dburl1 = make_sqlite_url(self.dbfilename1)
        self.dburl2 = make_sqlite_url(self.dbfilename2)
        log.info(f"Test database URLs:\n- {self.dburl1}\n- {self.dburl2}")

        # Create configs
        self.cfg1 = self._get_config(self.dburl1)
        self.cfg2 = self._get_config(self.dburl2)

        # Create table structure in each.
        log.info("Creating source database (from-scratch method)...")
        create_database_from_scratch(self.cfg1)
        log.info("... source database created")

        alembic_currently_working_with_this_sqla_version = False  # todo: fix
        # 2025-02-27: There is an internal Alembic failure with alembic==1.4.2
        # and SQLAlchemy==1.4.49.
        if alembic_currently_working_with_this_sqla_version:
            log.info("Creating destination database (incremental method)...")
            upgrade_database_to_head(self.cfg2)
        else:
            log.info("Creating destination database (from-scratch method)...")
            create_database_from_scratch(self.cfg2)
        log.info("... destination database created")

        # Create a session for each database
        self.db1session: Session = sessionmaker()(
            bind=self.cfg1.get_sqla_engine()
        )
        self.addCleanup(self.db1session.close)
        self.db2session: Session = sessionmaker()(
            bind=self.cfg2.get_sqla_engine()
        )
        self.addCleanup(self.db2session.close)

        # ---------------------------------------------------------------------
        # Data setup
        # ---------------------------------------------------------------------
        self.srcinfo = self._populate_src_db(self.db1session)
        self.dstinfo = self._populate_dst_db(self.db2session, self.srcinfo)

    def test_merge_db(self) -> None:
        """
        Merge the source database into the destination database.

        This is a smoke test: it passes if the merge completes without
        raising. It makes no assertions about the merged content.
        """
        # Map source ID number types and group numbers to their destination
        # equivalents.
        whichidnum_map = {
            self.srcinfo.nhs_which_idnum: self.dstinfo.nhs_which_idnum,
            self.srcinfo.rio_which_idnum: self.dstinfo.rio_which_idnum,
        }
        groupnum_map = {
            self.srcinfo.groupnum: self.dstinfo.groupnum,
        }
        log.info("Merging test databases...")
        merge_camcops_db(
            src=self.dburl1,
            dst_url=self.dburl2,
            echo=False,
            dummy_run=False,
            info_only=False,
            default_group_id=1,
            default_group_name="testgroup",
            groupnum_map=groupnum_map,
            whichidnum_map=whichidnum_map,
            skip_export_logs=True,
            skip_audit_logs=True,
        )
        log.info("... test databases merged.")