Source code for ydb_sqlalchemy.sqlalchemy.types

import decimal
from typing import Any, Mapping, Type, Union

from sqlalchemy import __version__ as sa_version

if sa_version.startswith("2."):
    from sqlalchemy import ColumnElement
else:
    from sqlalchemy.sql.expression import ColumnElement

from sqlalchemy import ARRAY, exc, Table, types
from sqlalchemy.sql import type_api

from .datetime_types import YqlDate, YqlDateTime, YqlTimestamp, YqlDate32, YqlTimestamp64, YqlDateTime64  # noqa: F401
from .json import YqlJSON  # noqa: F401


class UInt64(types.Integer):
    """Integer type for YDB ``Uint64``; dispatched to the ``uint64`` visitor."""

    __visit_name__ = "uint64"


class UInt32(types.Integer):
    """Integer type for YDB ``Uint32``; dispatched to the ``uint32`` visitor."""

    __visit_name__ = "uint32"


class UInt16(types.Integer):
    """Integer type for YDB ``Uint16``; dispatched to the ``uint16`` visitor."""

    __visit_name__ = "uint16"


class UInt8(types.Integer):
    """Integer type for YDB ``Uint8``; dispatched to the ``uint8`` visitor."""

    __visit_name__ = "uint8"


class Int64(types.Integer):
    """Integer type for YDB ``Int64``; dispatched to the ``int64`` visitor."""

    __visit_name__ = "int64"


class Int32(types.Integer):
    """Integer type for YDB ``Int32``; dispatched to the ``int32`` visitor."""

    __visit_name__ = "int32"
class Int16(types.Integer):
    """Integer type for YDB ``Int16``; dispatched to the ``int16`` visitor."""

    # Each integer class dispatches to the visitor matching its own width.
    # NOTE(review): assumes the dialect's type compiler defines ``visit_int16``
    # (not visible in this file) — confirm.
    __visit_name__ = "int16"
class Int8(types.Integer):
    """Integer type for YDB ``Int8``; dispatched to the ``int8`` visitor."""

    __visit_name__ = "int8"
class Decimal(types.DECIMAL):
    """YDB ``Decimal`` type.

    Precision and scale default to 22 and 9 when not supplied. Bound values
    are coerced to :class:`decimal.Decimal`. Result values come back as
    :class:`decimal.Decimal`, or as ``float`` when ``asdecimal=False``.
    """

    __visit_name__ = "DECIMAL"

    # YDB supports Decimal(22, 9) by default.
    _DEFAULT_PRECISION = 22
    _DEFAULT_SCALE = 9

    def __init__(self, precision=None, scale=None, asdecimal=True):
        if precision is None:
            precision = self._DEFAULT_PRECISION
        if scale is None:
            scale = self._DEFAULT_SCALE
        super().__init__(precision=precision, scale=scale, asdecimal=asdecimal)

    @staticmethod
    def _to_decimal(value):
        """Return *value* as a :class:`decimal.Decimal`.

        Decimal instances are returned unchanged. Anything else is converted
        via ``str()`` first, so floats keep their shortest repr
        (``0.1`` -> ``Decimal("0.1")``) rather than their exact binary
        expansion.
        """
        if isinstance(value, decimal.Decimal):
            return value
        return decimal.Decimal(str(value))

    def bind_processor(self, dialect):
        """Return a processor that coerces non-None bind values to Decimal."""
        to_decimal = self._to_decimal

        def process(value):
            if value is None:
                return None
            return to_decimal(value)

        return process

    def result_processor(self, dialect, coltype):
        """Return a processor that converts fetched values to Decimal or float."""
        to_decimal = self._to_decimal

        def process(value):
            if value is None:
                return None
            # asdecimal=False requests plain floats instead of Decimals.
            if not self.asdecimal:
                return float(value)
            return to_decimal(value)

        return process

    def literal_processor(self, dialect):
        """Return a processor that renders values as ``Decimal("<v>", p, s)`` literals.

        ``None`` renders as ``NULL``. Without this, ``Decimal("None")`` would
        raise ``decimal.InvalidOperation`` if SQLAlchemy passes ``None`` through
        to the literal processor.
        """
        to_decimal = self._to_decimal

        def process(value):
            if value is None:
                return "NULL"
            # Fall back to the defaults if precision/scale were cleared after init.
            precision = self.precision if self.precision is not None else self._DEFAULT_PRECISION
            scale = self.scale if self.scale is not None else self._DEFAULT_SCALE
            return f'Decimal("{str(to_decimal(value))}", {precision}, {scale})'

        return process
class ListType(ARRAY):
    """YDB List type, built on SQLAlchemy's ``ARRAY``."""

    __visit_name__ = "list_type"

    def bind_processor(self, dialect):
        """Apply the item type's bind processor to every non-None element.

        Returns ``None`` (no processing) when the item type has no bind
        processor of its own. A non-None bound value is always returned as a
        new list.
        """
        item_proc = self.item_type.bind_processor(dialect)
        if not item_proc:
            return None

        def process(value):
            if value is None:
                return None
            return [None if item is None else item_proc(item) for item in value]

        return process
class HashableDict(dict):
    """A ``dict`` subclass that can be hashed (e.g. used as a set member or dict key).

    The hash is computed over the items as an unordered set. Two instances
    that compare equal (same items, any insertion order) therefore always
    hash equal, as the hash/eq contract requires. Every value must itself be
    hashable. The instance stays mutable, so mutating it after it has been
    hashed changes its hash; treat instances as read-only once built.
    """

    def __hash__(self):
        # frozenset, not tuple: a tuple of items depends on insertion order,
        # which would give equal dicts different hashes.
        return hash(frozenset(self.items()))
class Optional(types.TypeEngine):
    """Wrapper for the YDB Optional type.

    Mainly used inside ``StructType`` to mark a field as nullable. The wrapped
    type may be given either as a TypeEngine class or as an instance.
    """

    __visit_name__ = "optional"

    def __init__(self, element_type: Union[Type[types.TypeEngine], types.TypeEngine]):
        # Stored exactly as given (class or instance); no normalisation here.
        wrapped = element_type
        self.element_type = wrapped
class StructType(types.TypeEngine[Mapping[str, Any]]):
    """
    YDB Struct type.

    Represents a structured data type with named fields, mapped to a Python dictionary.
    """

    __visit_name__ = "struct_type"

    def __init__(
        self,
        fields_types: Mapping[
            str,
            Union[Type[types.TypeEngine], types.TypeEngine, Optional],
        ],
    ):
        # Fields are stored sorted by name in a HashableDict, so the field order
        # given by the caller does not affect the stored mapping.
        self.fields_types = HashableDict(dict(sorted(fields_types.items())))

    @classmethod
    def from_table(cls, table: Table) -> "StructType":
        """
        Create a StructType definition from a SQLAlchemy Table.
        Automatically wraps nullable columns in Optional.

        :param table: SQLAlchemy Table object
        :return: StructType instance
        """
        fields = {}
        for col in table.columns:
            t = col.type
            if col.nullable:
                fields[col.name] = Optional(t)
            else:
                fields[col.name] = t
        return cls(fields)

    @property
    def python_type(self):
        return dict

    def compare_values(self, x, y):
        return x == y

    def bind_processor(self, dialect):
        """Return a processor that applies each field type's bind processor.

        ``Optional`` wrappers are unwrapped to their element type. Returns
        ``None`` when no field needs bind processing. The processor returns a
        new ``dict``; fields that are missing or ``None`` are left as they are.
        """
        processors = {}
        for name, type_ in self.fields_types.items():
            if isinstance(type_, Optional):
                type_ = type_.element_type
            proc = type_api.to_instance(type_).bind_processor(dialect)
            if proc:
                processors[name] = proc

        if not processors:
            return None

        def process(value):
            if value is None:
                return None
            # dict(...) rather than value.copy(): a plain Mapping has no .copy(),
            # and this still never mutates the caller's object.
            new_value = dict(value)
            for name, proc in processors.items():
                field_value = new_value.get(name)
                if field_value is not None:
                    new_value[name] = proc(field_value)
            return new_value

        return process
class Lambda(ColumnElement):
    """Column expression that wraps a Python callable.

    The expression's type is ``NULLTYPE``. How the callable is rendered is
    decided by the dialect compiler's ``lambda`` visitor, which is not visible
    here.

    :raises sqlalchemy.exc.ArgumentError: if ``func`` is not callable.
    """

    __visit_name__ = "lambda"

    def __init__(self, func):
        if callable(func):
            self.type = type_api.NULLTYPE
            self.func = func
            return
        raise exc.ArgumentError("func must be callable")
class Binary(types.LargeBinary):
    """Binary type whose bound values are handed to the driver unchanged."""

    __visit_name__ = "BINARY"

    def bind_processor(self, dialect):
        # Returning no processor disables whatever bind conversion
        # LargeBinary would otherwise apply.
        processor = None
        return processor