Source code for ydb_sqlalchemy.sqlalchemy.types

import decimal
from typing import Any, Mapping, Type, Union

from sqlalchemy import __version__ as sa_version

if sa_version.startswith("2."):
    from sqlalchemy import ColumnElement
else:
    from sqlalchemy.sql.expression import ColumnElement

from sqlalchemy import ARRAY, exc, types
from sqlalchemy.sql import type_api

from .datetime_types import YqlDate, YqlDateTime, YqlTimestamp, YqlDate32, YqlTimestamp64, YqlDateTime64  # noqa: F401
from .json import YqlJSON  # noqa: F401


class UInt64(types.Integer):
    """Unsigned 64-bit integer column type (compiled via ``visit_uint64``)."""

    __visit_name__ = "uint64"


class UInt32(types.Integer):
    """Unsigned 32-bit integer column type (compiled via ``visit_uint32``)."""

    __visit_name__ = "uint32"


class UInt16(types.Integer):
    """Unsigned 16-bit integer column type (compiled via ``visit_uint16``)."""

    __visit_name__ = "uint16"


class UInt8(types.Integer):
    """Unsigned 8-bit integer column type (compiled via ``visit_uint8``)."""

    __visit_name__ = "uint8"


class Int64(types.Integer):
    """Signed 64-bit integer column type (compiled via ``visit_int64``)."""

    __visit_name__ = "int64"


class Int32(types.Integer):
    """Signed 32-bit integer column type (compiled via ``visit_int32``)."""

    __visit_name__ = "int32"
class Int16(types.Integer):
    """Signed 16-bit integer column type (compiled via ``visit_int16``)."""

    # Must be "int16", matching the naming of every sibling integer type.
    # It was previously "int32", which routed Int16 columns through the
    # Int32 type visitor. NOTE(review): relies on the type compiler
    # defining visit_int16 - confirm.
    __visit_name__ = "int16"
class Int8(types.Integer):
    """Signed 8-bit integer column type (compiled via ``visit_int8``)."""

    __visit_name__ = "int8"
class Decimal(types.DECIMAL):
    """YDB ``Decimal`` column type.

    Precision and scale default to 22 and 9 when omitted. Values bound as
    parameters or rendered as SQL literals are coerced to
    :class:`decimal.Decimal`. Result values come back as ``decimal.Decimal``,
    or as ``float`` when ``asdecimal=False``.
    """

    __visit_name__ = "DECIMAL"

    # YDB supports Decimal(22,9) by default.
    _DEFAULT_PRECISION = 22
    _DEFAULT_SCALE = 9

    def __init__(self, precision=None, scale=None, asdecimal=True):
        if precision is None:
            precision = self._DEFAULT_PRECISION
        if scale is None:
            scale = self._DEFAULT_SCALE
        super().__init__(precision=precision, scale=scale, asdecimal=asdecimal)

    @staticmethod
    def _coerce_to_decimal(value):
        """Return a non-None *value* as a :class:`decimal.Decimal`.

        Strings are parsed directly. Floats and any other non-Decimal values
        go through ``str()`` first, so ``0.1`` becomes ``Decimal("0.1")``
        rather than its exact binary expansion.

        Raises:
            decimal.InvalidOperation: if the value's text is not a number
                (under the default decimal context).
        """
        if isinstance(value, decimal.Decimal):
            return value
        if isinstance(value, str):
            return decimal.Decimal(value)
        return decimal.Decimal(str(value))

    def bind_processor(self, dialect):
        """Return a processor that coerces bound parameters to Decimal.

        ``None`` is passed through unchanged.
        """

        def process(value):
            if value is None:
                return None
            return self._coerce_to_decimal(value)

        return process

    def result_processor(self, dialect, coltype):
        """Return a processor for values read back from the database.

        ``None`` is passed through. With ``asdecimal=False`` values become
        ``float``; otherwise they are returned as ``decimal.Decimal``.
        """

        def process(value):
            if value is None:
                return None
            # Honour asdecimal=False by converting to float.
            if not self.asdecimal:
                return float(value)
            return self._coerce_to_decimal(value)

        return process

    def literal_processor(self, dialect):
        """Return a processor that renders a value as a YQL Decimal literal.

        The output has the form ``Decimal("<value>", <precision>, <scale>)``.
        ``None`` renders as ``NULL``.
        """

        def process(value):
            # Without this guard, None would reach Decimal("None") and raise
            # InvalidOperation.
            if value is None:
                return "NULL"
            value = self._coerce_to_decimal(value)
            # precision/scale are normally set by __init__; fall back to the
            # defaults in case they were cleared afterwards.
            precision = self.precision if self.precision is not None else self._DEFAULT_PRECISION
            scale = self.scale if self.scale is not None else self._DEFAULT_SCALE
            return f'Decimal("{str(value)}", {precision}, {scale})'

        return process
class ListType(ARRAY):
    """List column type built on SQLAlchemy's ``ARRAY``.

    It is compiled via the dialect's ``visit_list_type``.
    """

    __visit_name__ = "list_type"
class HashableDict(dict):
    """A ``dict`` that is hashable, so it can be used inside hashable objects.

    The hash ignores insertion order, just as dict equality does. This keeps
    the ``__eq__``/``__hash__`` contract: equal instances hash equally.

    Every value must itself be hashable. The hash is computed from the
    current contents, so do not mutate an instance after it has been used
    as a key or set member.
    """

    def __hash__(self):
        # Use a frozenset rather than a tuple of items: a tuple would make
        # {"a": 1, "b": 2} and {"b": 2, "a": 1} compare equal but hash
        # differently.
        return hash(frozenset(self.items()))
class StructType(types.TypeEngine[Mapping[str, Any]]):
    """Struct column type made of named, typed fields.

    The field-name to type mapping is stored on ``fields_types`` as a
    :class:`HashableDict`, ordered by field name.
    """

    __visit_name__ = "struct_type"

    def __init__(self, fields_types: Mapping[str, Union[Type[types.TypeEngine], Type[types.TypeDecorator]]]):
        ordered_fields = sorted(fields_types.items())
        self.fields_types = HashableDict(ordered_fields)

    @property
    def python_type(self):
        """Struct values are represented as plain ``dict`` objects in Python."""
        return dict

    def compare_values(self, x, y):
        """Treat two struct values as equal when they compare equal with ``==``."""
        return x == y
class Lambda(ColumnElement):
    """SQL expression element that wraps a Python callable.

    The callable is stored on ``func``, and the element's type is
    ``NULLTYPE``. It is dispatched to the compiler under the visit name
    ``"lambda"``.

    Raises:
        sqlalchemy.exc.ArgumentError: if ``func`` is not callable.
    """

    __visit_name__ = "lambda"

    def __init__(self, func):
        # Reject non-callables up front, before any attribute is set.
        if not callable(func):
            raise exc.ArgumentError("func must be callable")

        self.type = type_api.NULLTYPE
        self.func = func