Source code for camcops_server.cc_modules.tests.cc_sqla_coltypes_tests

#!/usr/bin/env python

"""
camcops_server/cc_modules/tests/cc_sqla_coltypes_tests.py

===============================================================================

    Copyright (C) 2012, University of Cambridge, Department of Psychiatry.
    Created by Rudolf Cardinal (rnc1001@cam.ac.uk).

    This file is part of CamCOPS.

    CamCOPS is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    CamCOPS is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with CamCOPS. If not, see <https://www.gnu.org/licenses/>.

===============================================================================
"""

import datetime
from typing import Union
import unittest

from cardinal_pythonlib.datetimefunc import coerce_to_pendulum
from cardinal_pythonlib.sqlalchemy.session import SQLITE_MEMORY_URL
import pendulum
from pendulum import DateTime as Pendulum, Duration
import phonenumbers
from semantic_version import Version
from sqlalchemy.engine import create_engine
from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import func
from sqlalchemy.sql.schema import Column, MetaData, Table
from sqlalchemy.sql.sqltypes import DateTime, Integer

from camcops_server.cc_modules.cc_sqla_coltypes import (
    isotzdatetime_to_utcdatetime,
    PendulumDateTimeAsIsoTextColType,
    PendulumDurationAsIsoTextColType,
    PhoneNumberColType,
    SemanticVersionColType,
    unknown_field_to_utcdatetime,
)


# =============================================================================
# Unit testing
# =============================================================================


[docs]class SqlaColtypesTest(unittest.TestCase): """ Unit tests. """ # don't inherit from ExtendedTestCase; circular import
[docs] def setUp(self) -> None: super().setUp() engine = create_engine(SQLITE_MEMORY_URL, echo=True) self.meta = MetaData() self.meta.bind = engine # adds execute() method to select() etc.
# ... https://docs.sqlalchemy.org/en/latest/core/connections.html @staticmethod def _assert_dt_equal( a: Union[datetime.datetime, Pendulum], b: Union[datetime.datetime, Pendulum], ) -> None: # Accept that one may have been truncated or rounded to milliseconds. a = coerce_to_pendulum(a) b = coerce_to_pendulum(b) diff = a - b assert diff.microseconds < 1000, f"{a!r} != {b!r}" def test_iso_datetime_field(self) -> None: id_colname = "id" dt_local_colname = "dt_local" dt_utc_colname = "dt_utc" iso_colname = "iso" id_col = Column(id_colname, Integer, primary_key=True) dt_local_col = Column(dt_local_colname, DateTime) dt_utc_col = Column(dt_utc_colname, DateTime) iso_col = Column(iso_colname, PendulumDateTimeAsIsoTextColType) table = Table( "testtable", self.meta, id_col, dt_local_col, dt_utc_col, iso_col ) table.create() now = Pendulum.now() now_utc = now.in_tz(pendulum.UTC) yesterday = now.subtract(days=1) yesterday_utc = yesterday.in_tz(pendulum.UTC) table.insert().values( [ { id_colname: 1, dt_local_colname: now, dt_utc_colname: now_utc, iso_colname: now, }, { id_colname: 2, dt_local_colname: yesterday, dt_utc_colname: yesterday_utc, iso_colname: yesterday, }, ] ).execute() select_fields = [ id_col, dt_local_col, dt_utc_col, iso_col, func.length(dt_local_col).label("len_dt_local_col"), func.length(dt_utc_col).label("len_dt_utc_col"), func.length(iso_col).label("len_iso_col"), isotzdatetime_to_utcdatetime(iso_col).label("iso_to_utcdt"), unknown_field_to_utcdatetime(dt_utc_col).label( "uk_utcdt_to_utcdt" ), unknown_field_to_utcdatetime(iso_col).label("uk_iso_to_utc_dt"), ] rows = list( select(select_fields).select_from(table).order_by(id_col).execute() ) self._assert_dt_equal(rows[0][dt_local_col], now) self._assert_dt_equal(rows[0][dt_utc_col], now_utc) self._assert_dt_equal(rows[0][iso_colname], now) self._assert_dt_equal(rows[0]["iso_to_utcdt"], now_utc) self._assert_dt_equal(rows[0]["uk_utcdt_to_utcdt"], now_utc) self._assert_dt_equal(rows[0]["uk_iso_to_utc_dt"], now_utc) self._assert_dt_equal(rows[1][dt_local_col], yesterday) self._assert_dt_equal(rows[1][dt_utc_col], yesterday_utc) self._assert_dt_equal(rows[1][iso_colname], yesterday) self._assert_dt_equal(rows[1]["iso_to_utcdt"], yesterday_utc) self._assert_dt_equal(rows[1]["uk_utcdt_to_utcdt"], yesterday_utc) self._assert_dt_equal(rows[1]["uk_iso_to_utc_dt"], yesterday_utc) @staticmethod def _assert_duration_equal(a: Duration, b: Duration) -> None: assert a.total_seconds() == b.total_seconds(), f"{a!r} != {b!r}" def test_iso_duration_field(self) -> None: id_colname = "id" duration_colname = "duration_iso" id_col = Column(id_colname, Integer, primary_key=True) duration_col = Column( duration_colname, PendulumDurationAsIsoTextColType ) table = Table("testtable", self.meta, id_col, duration_col) table.create() d1 = Duration(years=1, months=3, seconds=3, microseconds=4) d2 = Duration(seconds=987.654321) d3 = Duration(days=-5) table.insert().values( [ {id_colname: 1, duration_colname: d1}, {id_colname: 2, duration_colname: d2}, {id_colname: 3, duration_colname: d3}, ] ).execute() select_fields = [id_col, duration_col] rows = list( select(select_fields).select_from(table).order_by(id_col).execute() ) self._assert_duration_equal(rows[0][duration_col], d1) self._assert_duration_equal(rows[1][duration_col], d2) self._assert_duration_equal(rows[2][duration_col], d3) @staticmethod def _assert_version_equal(a: Version, b: Version) -> None: assert a == b, f"{a!r} != {b!r}" def test_semantic_version_field(self) -> None: id_colname = "id" version_colname = "version" id_col = Column(id_colname, Integer, primary_key=True) version_col = Column(version_colname, SemanticVersionColType) table = Table("testtable", self.meta, id_col, version_col) table.create() v1 = Version("1.1.0") v2 = Version("2.0.1") v3 = Version("14.0.0") table.insert().values( [ {id_colname: 1, version_colname: v1}, {id_colname: 2, version_colname: v2}, {id_colname: 3, version_colname: v3}, ] ).execute() select_fields = [id_col, version_col] rows = list( select(select_fields).select_from(table).order_by(id_col).execute() ) self._assert_version_equal(rows[0][version_col], v1) self._assert_version_equal(rows[1][version_col], v2) self._assert_version_equal(rows[2][version_col], v3) def test_phone_number_field(self) -> None: id_colname = "id" phone_number_colname = "phone_number" id_col = Column(id_colname, Integer, primary_key=True) phone_number_col = Column(phone_number_colname, PhoneNumberColType) table = Table("testtable", self.meta, id_col, phone_number_col) table.create() # https://en.wikipedia.org/wiki/Fictitious_telephone_number p1 = phonenumbers.parse("+44 (0)113 496 0123") p2 = phonenumbers.parse("+33 1 99 00 12 34 56") p3 = phonenumbers.parse("07700 900123", "GB") p4 = None table.insert().values( [ {id_colname: 1, phone_number_colname: p1}, {id_colname: 2, phone_number_colname: p2}, {id_colname: 3, phone_number_colname: p3}, {id_colname: 4, phone_number_colname: p4}, ] ).execute() select_fields = [id_col, phone_number_col] rows = list( select(select_fields).select_from(table).order_by(id_col).execute() ) self.assertEqual(rows[0][phone_number_col], p1) self.assertEqual(rows[1][phone_number_col], p2) self.assertEqual(rows[2][phone_number_col], p3) self.assertIsNone(rows[3][phone_number_col])