# Source code for ydb.query.transaction

import abc
import logging
import enum
import functools
from typing import (
    Iterable,
    Optional,
)

from .. import (
    _apis,
    issues,
)
from .._grpc.grpcwrapper import ydb_query as _ydb_query
from ..connection import _RpcState as RpcState

from . import base
from ..settings import BaseRequestSettings

logger = logging.getLogger(__name__)


class QueryTxStateEnum(enum.Enum):
    """Lifecycle states of a Query Service transaction as tracked on the client.

    The allowed moves between states are listed in ``QueryTxStateHelper``.
    Member values (including the historical spellings ``BEGINED`` and
    ``ROLLBACKED``) are interpolated into RuntimeError messages, so they are
    kept byte-for-byte.
    """

    # Transaction object exists but no server tx_id is known yet.
    NOT_INITIALIZED = "NOT_INITIALIZED"
    # The server reported the transaction as started and a tx_id is recorded.
    BEGINED = "BEGINED"
    COMMITTED = "COMMITTED"
    ROLLBACKED = "ROLLBACKED"
    # Set e.g. by reset_tx_id_handler when a wrapped response raised issues.Error.
    DEAD = "DEAD"


class QueryTxStateHelper(abc.ABC):
    """Static description of the client-side transaction state machine.

    ``_VALID_TRANSITIONS`` maps each state to the states it may move to.
    A state with no outgoing transitions is terminal.
    """

    _VALID_TRANSITIONS = {
        QueryTxStateEnum.NOT_INITIALIZED: [
            QueryTxStateEnum.BEGINED,
            QueryTxStateEnum.DEAD,
            QueryTxStateEnum.COMMITTED,
            QueryTxStateEnum.ROLLBACKED,
        ],
        QueryTxStateEnum.BEGINED: [
            QueryTxStateEnum.COMMITTED,
            QueryTxStateEnum.ROLLBACKED,
            QueryTxStateEnum.DEAD,
        ],
        # Final states: nothing may follow them.
        QueryTxStateEnum.COMMITTED: [],
        QueryTxStateEnum.ROLLBACKED: [],
        QueryTxStateEnum.DEAD: [],
    }

    @classmethod
    def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool:
        """Return True when moving from ``before`` to ``after`` is permitted."""
        allowed_targets = cls._VALID_TRANSITIONS[before]
        return after in allowed_targets

    @classmethod
    def terminal(cls, state: QueryTxStateEnum) -> bool:
        """Return True when ``state`` has no outgoing transitions."""
        return not cls._VALID_TRANSITIONS[state]


def reset_tx_id_handler(func):
    """Decorate a response wrapper (used here for commit/rollback) so errors kill the tx.

    If the wrapped function raises ``issues.Error``, the transaction state is
    moved to ``DEAD`` and its ``tx_id`` is cleared. The original error is then
    re-raised unchanged.
    """

    @functools.wraps(func)
    def wrapper(
        rpc_state, response_pb, session_state: base.IQuerySessionState, tx_state: QueryTxState, *args, **kwargs
    ):
        try:
            outcome = func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs)
        except issues.Error:
            # DEAD is terminal, so the transaction cannot be reused after this.
            tx_state._change_state(QueryTxStateEnum.DEAD)
            tx_state.tx_id = None
            raise
        return outcome

    return wrapper


class QueryTxState:
    """Mutable client-side record of one transaction: its id, mode and lifecycle state."""

    def __init__(self, tx_mode: base.BaseQueryTxMode):
        """
        Create a fresh, not yet started transaction record.

        :param tx_mode: A mode of transaction
        """
        self._state = QueryTxStateEnum.NOT_INITIALIZED
        self.tx_mode = tx_mode
        # Filled in once a server response reports the transaction id.
        self.tx_id = None

    def _check_invalid_transition(self, target: QueryTxStateEnum) -> None:
        """Raise RuntimeError if the state machine forbids moving to ``target``."""
        if QueryTxStateHelper.valid_transition(self._state, target):
            return
        raise RuntimeError(f"Transaction could not be moved from {self._state.value} to {target.value}")

    def _change_state(self, target: QueryTxStateEnum) -> None:
        """Move to ``target``; raises RuntimeError if the move is not allowed."""
        self._check_invalid_transition(target)
        self._state = target

    def _check_tx_ready_to_use(self) -> None:
        """Raise RuntimeError if the transaction already reached a terminal state."""
        if not QueryTxStateHelper.terminal(self._state):
            return
        raise RuntimeError(f"Transaction is in terminal state: {self._state.value}")

    def _already_in(self, target: QueryTxStateEnum) -> bool:
        """Return True if the current state equals ``target``."""
        return target == self._state


def _construct_tx_settings(tx_state: QueryTxState) -> _ydb_query.TransactionSettings:
    """Convert the public tx mode stored in ``tx_state`` into a TransactionSettings wrapper."""
    return _ydb_query.TransactionSettings.from_public(tx_state.tx_mode)


def _create_begin_transaction_request(
    session_state: base.IQuerySessionState, tx_state: QueryTxState
) -> _apis.ydb_query.BeginTransactionRequest:
    """Build the BeginTransaction protobuf request for this session and tx mode."""
    # Evaluate in the same order as the keyword arguments below were originally written.
    session_id = session_state.session_id
    tx_settings = _construct_tx_settings(tx_state)
    wrapped_request = _ydb_query.BeginTransactionRequest(session_id=session_id, tx_settings=tx_settings)
    return wrapped_request.to_proto()


def _create_commit_transaction_request(
    session_state: base.IQuerySessionState, tx_state: QueryTxState
) -> _apis.ydb_query.CommitTransactionRequest:
    """Build the CommitTransaction protobuf request for the transaction's current tx_id."""
    commit_request = _apis.ydb_query.CommitTransactionRequest()
    commit_request.tx_id = tx_state.tx_id
    commit_request.session_id = session_state.session_id
    return commit_request


def _create_rollback_transaction_request(
    session_state: base.IQuerySessionState, tx_state: QueryTxState
) -> _apis.ydb_query.RollbackTransactionRequest:
    """Build the RollbackTransaction protobuf request for the transaction's current tx_id."""
    rollback_request = _apis.ydb_query.RollbackTransactionRequest()
    rollback_request.tx_id = tx_state.tx_id
    rollback_request.session_id = session_state.session_id
    return rollback_request


@base.bad_session_handler
def wrap_tx_begin_response(
    rpc_state: RpcState,
    response_pb: _apis.ydb_query.BeginTransactionResponse,
    session_state: base.IQuerySessionState,
    tx_state: QueryTxState,
    tx: "BaseQueryTxContext",
) -> "BaseQueryTxContext":
    """Handle a BeginTransaction response.

    The status is checked with ``issues._process_response`` (presumably raising
    ``issues.Error`` on failure). On success, the transaction is moved to BEGINED
    and the server-assigned tx_id is recorded.

    :return: The ``tx`` context object that issued the request.
    """
    parsed = _ydb_query.BeginTransactionResponse.from_proto(response_pb)
    issues._process_response(parsed.status)
    tx_state._change_state(QueryTxStateEnum.BEGINED)
    tx_state.tx_id = parsed.tx_meta.tx_id
    return tx


@base.bad_session_handler
@reset_tx_id_handler
def wrap_tx_commit_response(
    rpc_state: RpcState,
    response_pb: _apis.ydb_query.CommitTransactionResponse,
    session_state: base.IQuerySessionState,
    tx_state: QueryTxState,
    tx: "BaseQueryTxContext",
) -> "BaseQueryTxContext":
    """Handle a CommitTransaction response.

    The status is checked with ``issues._process_response``. If that raises
    ``issues.Error``, ``reset_tx_id_handler`` marks the transaction DEAD. On
    success, the transaction is moved to COMMITTED.

    :return: The ``tx`` context object that issued the request.
    """
    commit_result = _ydb_query.CommitTransactionResponse.from_proto(response_pb)
    issues._process_response(commit_result.status)
    tx_state._change_state(QueryTxStateEnum.COMMITTED)
    return tx


@base.bad_session_handler
@reset_tx_id_handler
def wrap_tx_rollback_response(
    rpc_state: RpcState,
    response_pb: _apis.ydb_query.RollbackTransactionResponse,
    session_state: base.IQuerySessionState,
    tx_state: QueryTxState,
    tx: "BaseQueryTxContext",
) -> "BaseQueryTxContext":
    """Handle a RollbackTransaction response.

    The status is checked with ``issues._process_response``. If that raises
    ``issues.Error``, ``reset_tx_id_handler`` marks the transaction DEAD. On
    success, the transaction is moved to ROLLBACKED.

    :return: The ``tx`` context object that issued the request.
    """
    rollback_result = _ydb_query.RollbackTransactionResponse.from_proto(response_pb)
    issues._process_response(rollback_result.status)
    tx_state._change_state(QueryTxStateEnum.ROLLBACKED)
    return tx


class BaseQueryTxContext:
    def __init__(self, driver, session_state, session, tx_mode):
        """
        An object that provides a simple transaction context manager that allows statements execution
        in a transaction. You don't have to open transaction explicitly, because context manager encapsulates
        transaction control logic, and opens new transaction if:

        1) By explicit .begin() method;
        2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip

        This context manager is not thread-safe, so you should not manipulate on it concurrently.

        :param driver: A driver instance
        :param session_state: A state of session
        :param session: The session object this transaction belongs to
        :param tx_mode: Transaction mode, which is one of the following choices:
         1) QuerySerializableReadWrite() which is default mode;
         2) QueryOnlineReadOnly(allow_inconsistent_reads=False);
         3) QuerySnapshotReadOnly();
         4) QueryStaleReadOnly().
        """

        self._driver = driver
        self._tx_state = QueryTxState(tx_mode)
        self._session_state = session_state
        self.session = session
        # Most recent result stream returned by execute() (see QueryTxContext);
        # it is drained before the next transaction operation is sent.
        self._prev_stream = None

    @property
    def session_id(self) -> str:
        """
        A transaction's session id

        :return: A transaction's session id
        """
        return self._session_state.session_id

    @property
    def tx_id(self) -> Optional[str]:
        """
        Returns an id of open transaction or None otherwise

        :return: An id of open transaction or None otherwise
        """
        return self._tx_state.tx_id

    def _tx_control_call(
        self,
        target_state: QueryTxStateEnum,
        request_factory,
        rpc_method,
        response_wrapper,
        settings: Optional[BaseRequestSettings],
    ):
        """Send one transaction-control RPC (begin / commit / rollback).

        The transition to ``target_state`` is validated locally *before* the
        request is built. An illegal call therefore raises RuntimeError without
        contacting the server.

        :param target_state: State the transaction must be allowed to move to.
        :param request_factory: ``f(session_state, tx_state)`` returning the protobuf request.
        :param rpc_method: QueryService RPC to invoke.
        :param response_wrapper: Callable receiving the response plus
            ``(session_state, tx_state, self)``.
        :param settings: An additional request settings BaseRequestSettings;
        :return: Whatever the driver call returns.
        """
        self._tx_state._check_invalid_transition(target_state)

        return self._driver(
            request_factory(self._session_state, self._tx_state),
            _apis.QueryService.Stub,
            rpc_method,
            response_wrapper,
            settings,
            (self._session_state, self._tx_state, self),
        )

    def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
        """Send BeginTransaction; raises RuntimeError if the tx cannot move to BEGINED."""
        return self._tx_control_call(
            QueryTxStateEnum.BEGINED,
            _create_begin_transaction_request,
            _apis.QueryService.BeginTransaction,
            wrap_tx_begin_response,
            settings,
        )

    def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
        """Send CommitTransaction; raises RuntimeError if the tx cannot move to COMMITTED."""
        return self._tx_control_call(
            QueryTxStateEnum.COMMITTED,
            _create_commit_transaction_request,
            _apis.QueryService.CommitTransaction,
            wrap_tx_commit_response,
            settings,
        )

    def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
        """Send RollbackTransaction; raises RuntimeError if the tx cannot move to ROLLBACKED."""
        return self._tx_control_call(
            QueryTxStateEnum.ROLLBACKED,
            _create_rollback_transaction_request,
            _apis.QueryService.RollbackTransaction,
            wrap_tx_rollback_response,
            settings,
        )

    def _execute_call(
        self,
        query: str,
        commit_tx: Optional[bool],
        syntax: Optional[base.QuerySyntax],
        exec_mode: Optional[base.QueryExecMode],
        parameters: Optional[dict],
        concurrent_result_sets: Optional[bool],
        settings: Optional[BaseRequestSettings],
    ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]:
        """Send an ExecuteQuery request within this transaction.

        :return: The raw response stream returned by the driver.
        :raises RuntimeError: if the transaction is already in a terminal state.
        """
        self._tx_state._check_tx_ready_to_use()

        # tx_id is None until the transaction has begun. tx_mode is always passed;
        # presumably create_execute_query_request uses it to begin a new tx in that
        # case — TODO confirm in base.
        request = base.create_execute_query_request(
            query=query,
            session_id=self._session_state.session_id,
            commit_tx=commit_tx,
            tx_id=self._tx_state.tx_id,
            tx_mode=self._tx_state.tx_mode,
            syntax=syntax,
            exec_mode=exec_mode,
            parameters=parameters,
            concurrent_result_sets=concurrent_result_sets,
        )

        return self._driver(
            request.to_proto(),
            _apis.QueryService.Stub,
            _apis.QueryService.ExecuteQuery,
            settings=settings,
        )

    def _move_to_beginned(self, tx_id: str) -> None:
        """Record that the transaction has begun with ``tx_id``.

        No-op (tx_id is left untouched) if already BEGINED.
        """
        if self._tx_state._already_in(QueryTxStateEnum.BEGINED):
            return
        self._tx_state._change_state(QueryTxStateEnum.BEGINED)
        self._tx_state.tx_id = tx_id

    def _move_to_commited(self) -> None:
        """Mark the transaction COMMITTED; no-op if it already is."""
        if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
            return
        self._tx_state._change_state(QueryTxStateEnum.COMMITTED)


class QueryTxContext(BaseQueryTxContext):
    def __enter__(self) -> "QueryTxContext":
        """
        Enters a context manager and returns a transaction

        :return: A transaction instance
        """
        return self

    def __exit__(self, *args, **kwargs):
        """
        Closes a transaction context manager and rolls back the transaction if
        it was not finished explicitly
        """
        self._ensure_prev_stream_finished()
        if self._tx_state._state == QueryTxStateEnum.BEGINED:
            # It's strictly recommended to close transactions directly
            # by using commit_tx=True flag while executing statement or by
            # .commit() or .rollback() methods, but here we trying to do best
            # effort to avoid useless open transactions
            leaked_tx_id = self._tx_state.tx_id
            logger.warning("Potentially leaked tx: %s", leaked_tx_id)
            try:
                self.rollback()
            except issues.Error:
                # reset_tx_id_handler clears tx_state.tx_id on failure, so log the
                # id captured above; include the traceback since the error is swallowed.
                logger.warning("Failed to rollback leaked tx: %s", leaked_tx_id, exc_info=True)

    def _ensure_prev_stream_finished(self) -> None:
        """Drain the result stream of the previous execute(), if any, and forget it."""
        if self._prev_stream is not None:
            # Leaving the context drains the stream.
            with self._prev_stream:
                pass
            self._prev_stream = None

    def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryTxContext":
        """Explicitly begins a transaction

        :param settings: An additional request settings BaseRequestSettings;

        :return: Transaction object or exception if begin is failed
        """
        self._begin_call(settings)
        return self

    def commit(self, settings: Optional[BaseRequestSettings] = None) -> None:
        """Commits the transaction if it is open; otherwise this is a no-op.

        Raises RuntimeError if the transaction is already in another terminal
        state (e.g. DEAD or ROLLBACKED). Server-side failures propagate as
        ``issues.Error`` (per the original docs, PreconditionFailed when a previous
        statement of the transaction failed).

        :param settings: An additional request settings BaseRequestSettings;

        :return: None
        """
        # Drain the previous result stream before looking at the state. Its
        # response wrapper receives tx=self and commit_tx, and presumably calls
        # _move_to_beginned / _move_to_commited. Checking first could treat a
        # transaction the server already began as NOT_INITIALIZED and skip the
        # commit RPC entirely.
        self._ensure_prev_stream_finished()

        if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
            return

        if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
            # Nothing was started on the server; finish locally.
            self._tx_state._change_state(QueryTxStateEnum.COMMITTED)
            return

        self._commit_call(settings)

    def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None:
        """Rolls back the transaction if it is open; otherwise this is a no-op.

        Raises RuntimeError if the transaction is already in another terminal
        state (e.g. DEAD or COMMITTED). Server-side failures propagate as
        ``issues.Error``.

        :param settings: An additional request settings BaseRequestSettings;

        :return: None
        """
        # Same reasoning as in commit(): the previous stream may still carry the
        # response that starts (or finishes) the transaction.
        self._ensure_prev_stream_finished()

        if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
            return

        if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
            # Nothing was started on the server; finish locally.
            self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED)
            return

        self._rollback_call(settings)

    def execute(
        self,
        query: str,
        parameters: Optional[dict] = None,
        commit_tx: Optional[bool] = False,
        syntax: Optional[base.QuerySyntax] = None,
        exec_mode: Optional[base.QueryExecMode] = None,
        concurrent_result_sets: Optional[bool] = False,
        settings: Optional[BaseRequestSettings] = None,
    ) -> base.SyncResponseContextIterator:
        """Sends a query to Query Service

        :param query: (YQL or SQL text) to be executed.
        :param parameters: dict with parameters and YDB types;
        :param commit_tx: A special flag that allows transaction commit.
        :param syntax: Syntax of the query, which is one of the following choices:
         1) QuerySyntax.YQL_V1, which is default;
         2) QuerySyntax.PG.
        :param exec_mode: Exec mode of the query, which is one of the following choices:
         1) QueryExecMode.EXECUTE, which is default;
         2) QueryExecMode.EXPLAIN;
         3) QueryExecMode.VALIDATE;
         4) QueryExecMode.PARSE.
        :param concurrent_result_sets: A flag to allow YDB mix parts of different result sets. Default is False;
        :param settings: An additional request settings BaseRequestSettings;

        :return: Iterator with result sets
        """
        self._ensure_prev_stream_finished()

        stream_it = self._execute_call(
            query=query,
            commit_tx=commit_tx,
            syntax=syntax,
            exec_mode=exec_mode,
            parameters=parameters,
            concurrent_result_sets=concurrent_result_sets,
            settings=settings,
        )

        # Remember the stream so the next operation can drain it first.
        self._prev_stream = base.SyncResponseContextIterator(
            stream_it,
            lambda resp: base.wrap_execute_query_response(
                rpc_state=None,
                response_pb=resp,
                session_state=self._session_state,
                tx=self,
                commit_tx=commit_tx,
                settings=self.session._settings,
            ),
        )
        return self._prev_stream