Source code for ydb_sqlalchemy.sqlalchemy

"""
Experimental
Work in progress, breaking changes are possible.
"""

import collections
import collections.abc
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import sqlalchemy as sa
import ydb
from sqlalchemy import util
from sqlalchemy.engine import characteristics, reflection
from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql import functions

from sqlalchemy.sql.elements import ClauseList

import ydb_dbapi
from ydb_sqlalchemy.sqlalchemy.dbapi_adapter import AdaptedAsyncConnection
from ydb_sqlalchemy.sqlalchemy.dml import Upsert

from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler, YqlDDLCompiler, YqlIdentifierPreparer, YqlTypeCompiler

from . import types


# True when running under SQLAlchemy 1.x. The numeric major version is compared rather
# than the raw version string, because lexicographic string comparison misorders
# multi-digit versions (e.g. "10.0" < "2." is True).
OLD_SA = int(sa.__version__.split(".", 1)[0]) < 2


class ParametrizedFunction(functions.Function):
    """SQL function that carries an extra, separately rendered parameter list.

    The function name and regular arguments are handled by ``functions.Function``.
    ``params`` is stored as-is and also wrapped into a grouped, comma-separated
    ``ClauseList`` exposed as ``params_expr``. Rendering is dispatched through
    ``__visit_name__ = "parametrized_function"``.
    """

    __visit_name__ = "parametrized_function"

    def __init__(self, name, params, *args, **kwargs):
        super().__init__(name, *args, **kwargs)
        self._func_name = name
        self._func_params = params
        # Comma-joined, grouped list of the extra parameters.
        grouped_params = ClauseList(
            *params,
            operator=functions.operators.comma_op,
            group_contents=True,
        )
        self.params_expr = grouped_params.self_group()
def upsert(table):
    """Build an ``Upsert`` DML construct targeting *table*."""
    statement = Upsert(table)
    return statement
# Mapping from YDB types to the SQLAlchemy types reported during reflection.
COLUMN_TYPES = {
    ydb.PrimitiveType.Int8: sa.INTEGER,
    ydb.PrimitiveType.Int16: sa.INTEGER,
    ydb.PrimitiveType.Int32: sa.INTEGER,
    ydb.PrimitiveType.Int64: sa.INTEGER,
    ydb.PrimitiveType.Uint8: sa.INTEGER,
    ydb.PrimitiveType.Uint16: sa.INTEGER,
    ydb.PrimitiveType.Uint32: types.UInt32,
    ydb.PrimitiveType.Uint64: types.UInt64,
    ydb.PrimitiveType.Float: sa.FLOAT,
    ydb.PrimitiveType.Double: sa.FLOAT,
    ydb.PrimitiveType.String: sa.BINARY,
    ydb.PrimitiveType.Utf8: sa.TEXT,
    ydb.PrimitiveType.Json: sa.JSON,
    ydb.PrimitiveType.JsonDocument: sa.JSON,
    ydb.DecimalType: sa.DECIMAL,
    ydb.PrimitiveType.Yson: sa.TEXT,
    ydb.PrimitiveType.Date: sa.DATE,
    ydb.PrimitiveType.Date32: sa.DATE,
    ydb.PrimitiveType.Timestamp64: sa.TIMESTAMP,
    ydb.PrimitiveType.Datetime64: sa.DATETIME,
    ydb.PrimitiveType.Datetime: sa.DATETIME,
    ydb.PrimitiveType.Timestamp: sa.TIMESTAMP,
    ydb.PrimitiveType.Interval: sa.INTEGER,
    ydb.PrimitiveType.Bool: sa.BOOLEAN,
    ydb.PrimitiveType.DyNumber: sa.TEXT,
}


def _get_column_info(t):
    """Translate a YDB column type into ``(sqlalchemy_type, nullable)``.

    ``ydb.OptionalType`` is unwrapped and reported as nullable. ``ydb.DecimalType``
    keeps its precision and scale. Any type missing from ``COLUMN_TYPES`` is
    reported as ``NullType``, and a SQLAlchemy warning is emitted instead of a
    ``KeyError``. One unmapped column therefore no longer aborts reflection of
    the whole table.
    """
    nullable = False
    if isinstance(t, ydb.OptionalType):
        nullable = True
        t = t.item

    if isinstance(t, ydb.DecimalType):
        return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable

    try:
        return COLUMN_TYPES[t], nullable
    except KeyError:
        util.warn(f"Did not recognize YDB type '{t}'; reflecting it as NullType")
        return sa.types.NullType, nullable
class YdbRequestSettingsCharacteristic(characteristics.ConnectionCharacteristic):
    """Connection characteristic for YDB request settings.

    Every hook forwards to the corresponding ``*_ydb_request_settings`` method of
    the dialect, passing the raw DB-API connection through unchanged.
    """

    def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection) -> None:
        """Restore the connection's request settings via the dialect."""
        dialect.reset_ydb_request_settings(dbapi_connection)

    def set_characteristic(
        self,
        dialect: "YqlDialect",
        dbapi_connection: ydb_dbapi.Connection,
        value: ydb.BaseRequestSettings,
    ) -> None:
        """Apply *value* as the connection's request settings via the dialect."""
        dialect.set_ydb_request_settings(dbapi_connection, value)

    def get_characteristic(
        self,
        dialect: "YqlDialect",
        dbapi_connection: ydb_dbapi.Connection,
    ) -> ydb.BaseRequestSettings:
        """Read the connection's current request settings via the dialect."""
        settings = dialect.get_ydb_request_settings(dbapi_connection)
        return settings
class YqlDialect(StrCompileDialect):
    """SQLAlchemy dialect for YDB (YQL) on top of the ``ydb_dbapi`` driver."""

    name = "yql"
    driver = "ydb"

    supports_alter = False
    max_identifier_length = 63
    supports_sane_rowcount = False
    supports_statement_cache = True

    supports_native_enum = False
    supports_native_boolean = True
    supports_native_decimal = True
    supports_smallserial = False
    supports_schemas = False
    supports_constraint_comments = False
    supports_json_type = True

    insert_returning = False
    update_returning = False
    delete_returning = False

    supports_sequences = False
    sequences_optional = False
    preexecute_autoincrement_sequences = True
    postfetch_lastrowid = False

    supports_default_values = False
    supports_empty_insert = False
    supports_multivalues_insert = True
    default_paramstyle = "qmark"

    isolation_level = None

    preparer = YqlIdentifierPreparer
    statement_compiler = YqlCompiler
    ddl_compiler = YqlDDLCompiler
    type_compiler = YqlTypeCompiler
    colspecs = {
        sa.types.JSON: types.YqlJSON,
        sa.types.JSON.JSONPathType: types.YqlJSON.YqlJSONPathType,
        sa.types.Date: types.YqlDate,
        sa.types.DateTime: types.YqlTimestamp,  # Because YDB's DateTime doesn't store microseconds
        sa.types.DATETIME: types.YqlDateTime,
        sa.types.TIMESTAMP: types.YqlTimestamp,
        sa.types.DECIMAL: types.Decimal,
    }

    connection_characteristics = util.immutabledict(
        {
            "isolation_level": characteristics.IsolationLevelCharacteristic(),
            "ydb_request_settings": YdbRequestSettingsCharacteristic(),
        }
    )

    construct_arguments = [
        (
            sa.schema.Table,
            {
                "auto_partitioning_by_size": None,
                "auto_partitioning_by_load": None,
                "auto_partitioning_partition_size_mb": None,
                "auto_partitioning_min_partitions_count": None,
                "auto_partitioning_max_partitions_count": None,
                "uniform_partitions": None,
                "partition_at_keys": None,
            },
        ),
        (
            sa.schema.Index,
            {
                "async": False,
                "cover": [],
            },
        ),
    ]

    @classmethod
    def import_dbapi(cls: Any):
        """Return the DB-API module used by this dialect (``ydb_dbapi``)."""
        return ydb_dbapi

    @classmethod
    def dbapi(cls):
        """Return the DB-API module; delegates to :meth:`import_dbapi`."""
        return cls.import_dbapi()

    def __init__(
        self,
        json_serializer=None,
        json_deserializer=None,
        _add_declare_for_yql_stmt_vars=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self._json_deserializer = json_deserializer
        self._json_serializer = json_serializer
        # NOTE: _add_declare_for_yql_stmt_vars is temporary and is soon to be removed.
        # no need in declare in yql statement here since ydb 24-1
        self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars

    def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescription:
        """Describe a table through the raw DB-API connection.

        :param table_name: a table name string, or an object with a ``.name`` attribute.
        :raises ydb_dbapi.NotSupportedError: if *schema* is not ``None``.
        :raises NoSuchTableError: if the driver raises ``ydb_dbapi.DatabaseError``.
        """
        if schema is not None:
            raise ydb_dbapi.NotSupportedError("unsupported on non empty schema")

        qt = table_name if isinstance(table_name, str) else table_name.name
        raw_conn = connection.connection
        try:
            return raw_conn.describe(qt)
        except ydb_dbapi.DatabaseError as e:
            raise NoSuchTableError(qt) from e

    def get_view_names(self, connection, schema=None, **kw: Any):
        """Views are not reflected; always returns an empty list."""
        return []

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect column name, type and nullability; defaults are always ``None``."""
        table = self._describe_table(connection, table_name, schema)
        as_compatible = []
        for column in table.columns:
            col_type, nullable = _get_column_info(column.type)
            as_compatible.append(
                {
                    "name": column.name,
                    "type": col_type,
                    "nullable": nullable,
                    "default": None,
                }
            )
        return as_compatible

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        """List table names via the raw DB-API connection.

        :raises ydb_dbapi.NotSupportedError: if a non-empty *schema* is given.
        """
        if schema:
            raise ydb_dbapi.NotSupportedError("unsupported on non empty schema")

        raw_conn = connection.connection
        return raw_conn.get_table_names()

    @reflection.cache
    def has_table(self, connection, table_name, schema=None, **kwargs):
        """Return True if the table can be described, False on ``NoSuchTableError``."""
        try:
            self._describe_table(connection, table_name, schema)
            return True
        except NoSuchTableError:
            return False

    @reflection.cache
    def get_pk_constraint(self, connection, table_name, schema=None, **kwargs):
        """Reflect the primary key columns; the constraint name is always ``None``."""
        table = self._describe_table(connection, table_name, schema)
        return {"constrained_columns": table.primary_key, "name": None}

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
        # foreign keys unsupported
        return []

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kwargs):
        """Reflect secondary indexes as non-unique indexes.

        Plain dicts are returned under SQLAlchemy 1.x, ``ReflectedIndex`` otherwise.
        """
        table = self._describe_table(connection, table_name, schema)
        indexes: list[ydb.TableIndex] = table.indexes
        if OLD_SA:
            sa_indexes: list[dict] = []
            for index in indexes:
                sa_indexes.append(
                    {
                        "name": index.name,
                        "column_names": index.index_columns,
                        "unique": False,
                        "dialect_options": {
                            "ydb_async": False,  # TODO After https://github.com/ydb-platform/ydb-python-sdk/issues/351
                            "ydb_cover": [],  # TODO After https://github.com/ydb-platform/ydb-python-sdk/issues/409
                        },
                    }
                )
            return sa_indexes

        sa_indexes: list[sa.engine.interfaces.ReflectedIndex] = []
        for index in indexes:
            sa_indexes.append(
                sa.engine.interfaces.ReflectedIndex(
                    name=index.name,
                    column_names=index.index_columns,
                    unique=False,
                    dialect_options={
                        "ydb_async": False,  # TODO After https://github.com/ydb-platform/ydb-python-sdk/issues/351
                        "ydb_cover": [],  # TODO After https://github.com/ydb-platform/ydb-python-sdk/issues/409
                    },
                )
            )
        return sa_indexes

    def set_isolation_level(self, dbapi_connection: ydb_dbapi.Connection, level: str) -> None:
        dbapi_connection.set_isolation_level(level)

    def get_default_isolation_level(self, dbapi_conn: ydb_dbapi.Connection) -> str:
        """The default isolation level is ``ydb_dbapi.IsolationLevel.AUTOCOMMIT``."""
        return ydb_dbapi.IsolationLevel.AUTOCOMMIT

    def get_isolation_level(self, dbapi_connection: ydb_dbapi.Connection) -> str:
        return dbapi_connection.get_isolation_level()

    def set_ydb_request_settings(
        self,
        dbapi_connection: ydb_dbapi.Connection,
        value: ydb.BaseRequestSettings,
    ) -> None:
        dbapi_connection.set_ydb_request_settings(value)

    def reset_ydb_request_settings(self, dbapi_connection: ydb_dbapi.Connection):
        """Reset request settings to a fresh ``ydb.BaseRequestSettings()``."""
        self.set_ydb_request_settings(dbapi_connection, ydb.BaseRequestSettings())

    def get_ydb_request_settings(self, dbapi_connection: ydb_dbapi.Connection) -> ydb.BaseRequestSettings:
        return dbapi_connection.get_ydb_request_settings()

    def create_connect_args(self, url):
        """Build connect args, prefixing the database name with ``/`` if missing."""
        args, kwargs = super().create_connect_args(url)
        # YDB database name should start with '/'
        if "database" in kwargs:
            if not kwargs["database"].startswith("/"):
                kwargs["database"] = "/" + kwargs["database"]

        return [args, kwargs]

    def connect(self, *cargs, **cparams):
        return self.dbapi.connect(*cargs, **cparams)

    def do_begin(self, dbapi_connection: ydb_dbapi.Connection) -> None:
        dbapi_connection.begin()

    def do_rollback(self, dbapi_connection: ydb_dbapi.Connection) -> None:
        dbapi_connection.rollback()

    def do_commit(self, dbapi_connection: ydb_dbapi.Connection) -> None:
        dbapi_connection.commit()

    def _handle_column_name(self, variable):
        """Quote a parameter name in backticks."""
        return "`" + variable + "`"

    def _format_variables(
        self,
        statement: str,
        parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]],
        execute_many: bool,
    ) -> Tuple[str, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]:
        """Rewrite ``%(name)s`` placeholders into YQL ``$`name``` variables.

        Parameter keys are prefixed with ``$``. Falsy *parameters* yield ``None``.
        ``%%`` is always collapsed back to ``%``.
        """
        formatted_statement = statement
        formatted_parameters = None

        if parameters:
            if execute_many:
                parameters_sequence: Sequence[Mapping[str, Any]] = parameters
                variable_names = set()
                formatted_parameters = []
                for param_map in parameters_sequence:
                    variable_names.update(param_map.keys())
                    formatted_parameters.append({f"${k}": v for k, v in param_map.items()})
            else:
                variable_names = set(parameters.keys())
                formatted_parameters = {f"${k}": v for k, v in parameters.items()}

            formatted_variable_names = {
                variable_name: f"${self._handle_column_name(variable_name)}" for variable_name in variable_names
            }
            formatted_statement = formatted_statement % formatted_variable_names

        formatted_statement = formatted_statement.replace("%%", "%")
        return formatted_statement, formatted_parameters

    def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types):
        """Prepend a ``DECLARE $`name` as <type>;`` line per typed parameter."""
        declarations = "\n".join(
            [
                f"DECLARE $`{param_name[1:] if param_name.startswith('$') else param_name}` as {str(param_type)};"
                for param_name, param_type in parameters_types.items()
            ]
        )
        return f"{declarations}\n{statement}"

    def __merge_parameters_values_and_types(
        self,
        values: Union[Mapping[str, Any], Sequence[Mapping[str, Any]]],
        types: Mapping[str, Any],
        execute_many: bool,
    ) -> Union[Sequence[Mapping[str, ydb.TypedValue]], Mapping[str, ydb.TypedValue]]:
        """Wrap values whose key has a known type into ``ydb.TypedValue``.

        Returns a list of mappings when *execute_many* is true, otherwise a single mapping.
        """
        if isinstance(values, collections.abc.Mapping):
            values = [values]

        result_list = []
        for value_map in values:
            result = {}
            for key in value_map.keys():
                if key in types:
                    result[key] = ydb.TypedValue(value_map[key], types[key])
                else:
                    result[key] = value_map[key]
            result_list.append(result)
        return result_list if execute_many else result_list[0]

    def _prepare_ydb_query(
        self,
        statement: str,
        context: Optional[DefaultExecutionContext] = None,
        parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None,
        execute_many: bool = False,
    ) -> Tuple[str, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]:
        """Turn a compiled statement and its parameters into YQL text plus YDB parameters.

        Bind types come from the compiled statement. When there is no context, or
        the context has no compiled object (e.g. raw driver SQL), parameters are
        formatted without type information instead of raising ``AttributeError``.
        """
        is_ddl = context.isddl if context is not None else False
        # getattr(None, ...) also yields None, covering the context-less case (e.g. do_ping).
        compiled = getattr(context, "compiled", None)

        if not is_ddl and parameters and compiled is not None:
            parameters_types = compiled.get_bind_types(parameters)
            if parameters_types != {}:
                parameters = self.__merge_parameters_values_and_types(parameters, parameters_types, execute_many)
            statement, parameters = self._format_variables(statement, parameters, execute_many)
            if self._add_declare_for_yql_stmt_vars:
                statement = self._add_declare_for_yql_stmt_vars_impl(statement, parameters_types)
            return statement, parameters

        statement, parameters = self._format_variables(statement, parameters, execute_many)
        return statement, parameters

    def do_ping(self, dbapi_connection: ydb_dbapi.Connection) -> bool:
        """Execute the dialect's SELECT-one probe; the cursor is always closed."""
        cursor = dbapi_connection.cursor()
        statement, _ = self._prepare_ydb_query(self._dialect_specific_select_one)
        try:
            cursor.execute(statement)
        finally:
            cursor.close()
        return True

    def do_executemany(
        self,
        cursor: ydb_dbapi.Cursor,
        statement: str,
        parameters: Optional[Sequence[Mapping[str, Any]]],
        context: Optional[DefaultExecutionContext] = None,
    ) -> None:
        operation, parameters = self._prepare_ydb_query(statement, context, parameters, execute_many=True)
        cursor.executemany(operation, parameters)

    def do_execute(
        self,
        cursor: ydb_dbapi.Cursor,
        statement: str,
        parameters: Optional[Mapping[str, Any]] = None,
        context: Optional[DefaultExecutionContext] = None,
    ) -> None:
        """Execute one statement; DDL goes through ``cursor.execute_scheme``."""
        operation, parameters = self._prepare_ydb_query(statement, context, parameters, execute_many=False)
        is_ddl = context.isddl if context is not None else False
        if is_ddl:
            cursor.execute_scheme(operation, parameters)
        else:
            cursor.execute(operation, parameters)
class AsyncYqlDialect(YqlDialect):
    """Async variant of ``YqlDialect``.

    It connects through ``dbapi.async_connect`` and wraps the resulting connection
    in ``AdaptedAsyncConnection``.
    """

    driver = "ydb_async"
    is_async = True
    supports_statement_cache = True

    def connect(self, *cargs, **cparams):
        # Block on the async connect coroutine via util.await_only, then adapt it.
        async_connection = util.await_only(self.dbapi.async_connect(*cargs, **cparams))
        return AdaptedAsyncConnection(async_connection)