import abc
import asyncio
import enum
import functools
from collections import defaultdict
import typing
from typing import (
Optional,
)
from .._grpc.grpcwrapper import ydb_query
from .._grpc.grpcwrapper.ydb_query_public_types import (
BaseQueryTxMode,
)
from ..connection import _RpcState as RpcState
from .. import convert
from .. import issues
from .. import _utilities
from .. import _apis
from ydb._topic_common.common import CallFromSyncToAsync, _get_shared_event_loop
from ydb._grpc.grpcwrapper.common_utils import to_thread
if typing.TYPE_CHECKING:
from .transaction import BaseQueryTxContext
class QuerySyntax(enum.IntEnum):
UNSPECIFIED = 0
YQL_V1 = 1
PG = 2
class QueryExecMode(enum.IntEnum):
UNSPECIFIED = 0
PARSE = 10
VALIDATE = 20
EXPLAIN = 30
EXECUTE = 50
class StatsMode(enum.IntEnum):
UNSPECIFIED = 0
NONE = 10
BASIC = 20
FULL = 30
PROFILE = 40
class SyncResponseContextIterator(_utilities.SyncResponseIterator):
def __enter__(self) -> "SyncResponseContextIterator":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# To close stream on YDB it is necessary to scroll through it to the end
for _ in self:
pass
[docs]
class QueryClientSettings:
def __init__(self):
self._native_datetime_in_result_sets = True
self._native_date_in_result_sets = True
self._native_json_in_result_sets = True
self._native_interval_in_result_sets = True
self._native_timestamp_in_result_sets = True
[docs]
def with_native_timestamp_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
self._native_timestamp_in_result_sets = enabled
return self
[docs]
def with_native_interval_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
self._native_interval_in_result_sets = enabled
return self
[docs]
def with_native_json_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
self._native_json_in_result_sets = enabled
return self
[docs]
def with_native_date_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
self._native_date_in_result_sets = enabled
return self
[docs]
def with_native_datetime_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
self._native_datetime_in_result_sets = enabled
return self
class IQuerySessionState(abc.ABC):
def __init__(self, settings: Optional[QueryClientSettings] = None):
pass
@abc.abstractmethod
def reset(self) -> None:
pass
@property
@abc.abstractmethod
def session_id(self) -> Optional[str]:
pass
@abc.abstractmethod
def set_session_id(self, session_id: str) -> "IQuerySessionState":
pass
@property
@abc.abstractmethod
def node_id(self) -> Optional[int]:
pass
@abc.abstractmethod
def set_node_id(self, node_id: int) -> "IQuerySessionState":
pass
@property
@abc.abstractmethod
def attached(self) -> bool:
pass
@abc.abstractmethod
def set_attached(self, attached: bool) -> "IQuerySessionState":
pass
def create_execute_query_request(
query: str,
session_id: str,
tx_id: Optional[str],
commit_tx: Optional[bool],
tx_mode: Optional[BaseQueryTxMode],
syntax: Optional[QuerySyntax],
exec_mode: Optional[QueryExecMode],
parameters: Optional[dict],
concurrent_result_sets: Optional[bool],
) -> ydb_query.ExecuteQueryRequest:
syntax = QuerySyntax.YQL_V1 if not syntax else syntax
exec_mode = QueryExecMode.EXECUTE if not exec_mode else exec_mode
stats_mode = StatsMode.NONE # TODO: choise is not supported yet
tx_control = None
if not tx_id and not tx_mode:
tx_control = None
elif tx_id:
tx_control = ydb_query.TransactionControl(
tx_id=tx_id,
commit_tx=commit_tx,
begin_tx=None,
)
else:
tx_control = ydb_query.TransactionControl(
begin_tx=ydb_query.TransactionSettings(
tx_mode=tx_mode,
),
commit_tx=commit_tx,
tx_id=None,
)
return ydb_query.ExecuteQueryRequest(
session_id=session_id,
query_content=ydb_query.QueryContent.from_public(
query=query,
syntax=syntax,
),
tx_control=tx_control,
exec_mode=exec_mode,
parameters=parameters,
concurrent_result_sets=concurrent_result_sets,
stats_mode=stats_mode,
)
def bad_session_handler(func):
@functools.wraps(func)
def decorator(rpc_state, response_pb, session_state: IQuerySessionState, *args, **kwargs):
try:
return func(rpc_state, response_pb, session_state, *args, **kwargs)
except issues.BadSession:
session_state.reset()
raise
return decorator
@bad_session_handler
def wrap_execute_query_response(
rpc_state: RpcState,
response_pb: _apis.ydb_query.ExecuteQueryResponsePart,
session_state: IQuerySessionState,
tx: Optional["BaseQueryTxContext"] = None,
commit_tx: Optional[bool] = False,
settings: Optional[QueryClientSettings] = None,
) -> convert.ResultSet:
issues._process_response(response_pb)
if tx and commit_tx:
tx._move_to_commited()
elif tx and response_pb.tx_meta and not tx.tx_id:
tx._move_to_beginned(response_pb.tx_meta.id)
if response_pb.HasField("result_set"):
return convert.ResultSet.from_message(response_pb.result_set, settings)
return None
class TxEvent(enum.Enum):
BEFORE_COMMIT = "BEFORE_COMMIT"
AFTER_COMMIT = "AFTER_COMMIT"
BEFORE_ROLLBACK = "BEFORE_ROLLBACK"
AFTER_ROLLBACK = "AFTER_ROLLBACK"
class CallbackHandlerMode(enum.Enum):
SYNC = "SYNC"
ASYNC = "ASYNC"
def _get_sync_callback(method: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]):
if asyncio.iscoroutinefunction(method):
if loop is None:
loop = _get_shared_event_loop()
def async_to_sync_callback(*args, **kwargs):
caller = CallFromSyncToAsync(loop)
return caller.safe_call_with_result(method(*args, **kwargs), 10)
return async_to_sync_callback
return method
def _get_async_callback(method: typing.Callable):
if asyncio.iscoroutinefunction(method):
return method
async def sync_to_async_callback(*args, **kwargs):
return await to_thread(method, *args, **kwargs, executor=None)
return sync_to_async_callback
class CallbackHandler:
def _init_callback_handler(self, mode: CallbackHandlerMode) -> None:
self._callbacks = defaultdict(list)
self._callback_mode = mode
def _execute_callbacks_sync(self, event_name: str, *args, **kwargs) -> None:
for callback in self._callbacks[event_name]:
callback(self, *args, **kwargs)
async def _execute_callbacks_async(self, event_name: str, *args, **kwargs) -> None:
tasks = [asyncio.create_task(callback(self, *args, **kwargs)) for callback in self._callbacks[event_name]]
if not tasks:
return
await asyncio.gather(*tasks)
def _prepare_callback(
self, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]
) -> typing.Callable:
if self._callback_mode == CallbackHandlerMode.SYNC:
return _get_sync_callback(callback, loop)
return _get_async_callback(callback)
def _add_callback(self, event_name: str, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]):
self._callbacks[event_name].append(self._prepare_callback(callback, loop))