# -*- coding: utf-8 -*-
import grpc
import logging
import os
from typing import Any # noqa
from . import credentials as credentials_impl, table, scheme, pool
from . import tracing
from . import iam
from . import _utilities
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class RPCCompression:
    """Compression methods that can be selected for an RPC.

    Every attribute is an alias of the same-named ``grpc.Compression`` member.
    """

    NoCompression, Deflate, Gzip = (
        grpc.Compression.NoCompression,
        grpc.Compression.Deflate,
        grpc.Compression.Gzip,
    )
def default_credentials(credentials=None, tracer=None):
    """Return ``credentials`` if given, otherwise anonymous credentials.

    :param credentials: Credentials object to pass through unchanged, or None.
    :param tracer: Tracer used to record which branch was taken; when None a
        ``tracing.Tracer(None)`` is created for the call.
    :return: ``credentials`` itself when it is not None, else a new
        ``AnonymousCredentials`` instance.
    """
    if tracer is None:
        tracer = tracing.Tracer(None)

    with tracer.trace("Driver.default_credentials") as span:
        # Caller-supplied credentials always win.
        if credentials is not None:
            span.trace({"credentials.prepared": True})
            return credentials

        span.trace({"credentials.anonymous": True})
        return credentials_impl.AnonymousCredentials()
def credentials_from_env_variables(tracer=None):
    """Build credentials from the process environment.

    The variables are checked in this order and the first match wins:

    1. ``YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS`` is set (any value):
       service-account credentials loaded from that file.
    2. ``YDB_ANONYMOUS_CREDENTIALS`` equals ``"1"``: anonymous credentials.
    3. ``YDB_METADATA_CREDENTIALS`` equals ``"1"``: metadata-URL credentials.
    4. ``YDB_ACCESS_TOKEN_CREDENTIALS`` is set (any value): token credentials.
    5. ``YDB_OAUTH2_KEY_FILE`` is set and non-empty: OAuth2 token-exchange
       credentials loaded from that file.

    If nothing matches, metadata-URL credentials are returned.

    :param tracer: Tracer used to record the chosen source; when None a
        ``tracing.Tracer(None)`` is created for the call.
    :return: A credentials object of the kind selected above.
    """
    if tracer is None:
        tracer = tracing.Tracer(None)

    with tracer.trace("Driver.credentials_from_env_variables") as span:
        sa_key_path = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS")
        if sa_key_path is not None:
            span.trace({"credentials.service_account_key_file": True})
            # Imported only on this branch, exactly where it is needed.
            import ydb.iam

            return ydb.iam.ServiceAccountCredentials.from_file(sa_key_path)

        if os.getenv("YDB_ANONYMOUS_CREDENTIALS", "0") == "1":
            span.trace({"credentials.anonymous": True})
            return credentials_impl.AnonymousCredentials()

        if os.getenv("YDB_METADATA_CREDENTIALS", "0") == "1":
            span.trace({"credentials.metadata": True})
            return iam.MetadataUrlCredentials(tracer=tracer)

        token = os.getenv("YDB_ACCESS_TOKEN_CREDENTIALS")
        if token is not None:
            span.trace({"credentials.access_token": True})
            return credentials_impl.AuthTokenCredentials(token)

        oauth2_path = os.getenv("YDB_OAUTH2_KEY_FILE")
        if oauth2_path:
            span.trace({"credentials.oauth2_key_file": True})
            # Imported only on this branch, exactly where it is needed.
            import ydb.oauth2_token_exchange

            return ydb.oauth2_token_exchange.Oauth2TokenExchangeCredentials.from_file(oauth2_path)

        # Nothing was configured explicitly: fall back to metadata credentials.
        fallback_tags = {
            "credentials.env_default": True,
            "credentials.metadata": True,
        }
        span.trace(fallback_tags)
        return iam.MetadataUrlCredentials(tracer=tracer)
[docs]
class DriverConfig(object):
    """Settings used to initialize a driver instance.

    ``__slots__`` fixes the attribute set, so reading or assigning any other
    attribute name raises ``AttributeError``.
    """

    __slots__ = (
        "endpoint",
        "database",
        "ca_cert",
        "channel_options",
        "credentials",
        "use_all_nodes",
        "root_certificates",
        "certificate_chain",
        "private_key",
        "grpc_keep_alive_timeout",
        "secure_channel",
        "table_client_settings",
        "topic_client_settings",
        "query_client_settings",
        "endpoints",
        "primary_user_agent",
        "tracer",
        "grpc_lb_policy_name",
        "discovery_request_timeout",
        "compression",
    )

    def __init__(
        self,
        endpoint,
        database=None,
        ca_cert=None,
        auth_token=None,
        channel_options=None,
        credentials=None,
        use_all_nodes=False,
        root_certificates=None,
        certificate_chain=None,
        private_key=None,
        grpc_keep_alive_timeout=None,
        table_client_settings=None,
        topic_client_settings=None,
        query_client_settings=None,
        endpoints=None,
        primary_user_agent="python-library",
        tracer=None,
        grpc_lb_policy_name="round_robin",
        discovery_request_timeout=10,
        compression=None,
    ):
        """
        A driver config to initialize a driver instance

        :param endpoint: An endpoint specified in pattern host:port to be used for initial channel
            initialization and for YDB endpoint discovery mechanism
        :param database: A name of the database
        :param ca_cert: A CA certificate when SSL should be used
        :param auth_token: An authentication token. When given, it is wrapped in
            ``AuthTokenCredentials`` and replaces ``credentials``.
        :param channel_options: Stored unchanged on the config (presumably extra gRPC channel
            options -- confirm against the connection code).
        :param credentials: An instance of AbstractCredentials
        :param use_all_nodes: A balancing policy that forces to use all available nodes.
        :param root_certificates: The PEM-encoded root certificates as a byte string.
        :param certificate_chain: The PEM-encoded certificate chain as a byte string
            to use or None if no certificate chain should be used.
        :param private_key: The PEM-encoded private key as a byte string, or None if no
            private key should be used.
        :param grpc_keep_alive_timeout: GRpc KeepAlive timeout, ms
        :param table_client_settings: Stored unchanged; ``Driver`` hands it to its table client.
        :param topic_client_settings: Stored unchanged; ``Driver`` hands it to its topic client.
        :param query_client_settings: Stored unchanged; not consumed anywhere in this module.
        :param endpoints: Optional iterable of additional endpoints; each one is normalized the
            same way as ``endpoint``. Defaults to an empty list.
        :param primary_user_agent: Stored unchanged on the config.
        :param ydb.Tracer tracer: ydb.Tracer instance to trace requests in driver.
            If tracing aio ScopeManager must be ContextVarsScopeManager.
            When None, a ``tracing.Tracer(None)`` is stored instead.
        :param grpc_lb_policy_name: A load balancing policy to be used for discovery channel
            construction. Default value is `round_robin`
        :param discovery_request_timeout: A default timeout to complete the discovery.
            The default value is 10 seconds.
        :param compression: Stored unchanged (presumably one of the ``RPCCompression``
            values -- confirm against the connection code).
        """
        self.database = database
        self.ca_cert = ca_cert
        self.channel_options = channel_options
        # The secure flag is derived from the endpoint exactly as the caller
        # wrote it; only afterwards is the endpoint stored in wrap_endpoint form.
        self.secure_channel = _utilities.is_secure_protocol(endpoint)
        self.endpoint = _utilities.wrap_endpoint(endpoint)
        self.endpoints = []
        if endpoints is not None:
            self.endpoints = [_utilities.wrap_endpoint(endp) for endp in endpoints]
        # An explicit token takes precedence over a credentials object.
        if auth_token is not None:
            credentials = credentials_impl.AuthTokenCredentials(auth_token)
        self.credentials = credentials
        self.use_all_nodes = use_all_nodes
        self.root_certificates = root_certificates
        self.certificate_chain = certificate_chain
        self.private_key = private_key
        self.grpc_keep_alive_timeout = grpc_keep_alive_timeout
        self.table_client_settings = table_client_settings
        self.topic_client_settings = topic_client_settings
        self.query_client_settings = query_client_settings
        self.primary_user_agent = primary_user_agent
        self.tracer = tracer if tracer is not None else tracing.Tracer(None)
        self.grpc_lb_policy_name = grpc_lb_policy_name
        self.discovery_request_timeout = discovery_request_timeout
        self.compression = compression

    def set_database(self, database):
        """Replace the database name and return ``self`` for chaining."""
        self.database = database
        return self

    @classmethod
    def default_from_endpoint_and_database(cls, endpoint, database, root_certificates=None, credentials=None, **kwargs):
        """Build a config from an endpoint and a database name.

        ``credentials`` is passed through ``default_credentials``, so None
        becomes anonymous credentials. Extra keyword arguments are forwarded
        to the constructor.
        """
        return cls(
            endpoint,
            database,
            credentials=default_credentials(credentials),
            root_certificates=root_certificates,
            **kwargs,
        )

    @classmethod
    def default_from_connection_string(cls, connection_string, root_certificates=None, credentials=None, **kwargs):
        """Build a config from a connection string.

        The string is split into endpoint and database by
        ``_utilities.parse_connection_string``. ``credentials`` is passed
        through ``default_credentials``, so None becomes anonymous credentials.
        Extra keyword arguments are forwarded to the constructor.
        """
        endpoint, database = _utilities.parse_connection_string(connection_string)
        return cls(
            endpoint,
            database,
            credentials=default_credentials(credentials),
            root_certificates=root_certificates,
            **kwargs,
        )

    def set_grpc_keep_alive_timeout(self, timeout):
        """Replace the gRPC keep-alive timeout and return ``self`` for chaining."""
        self.grpc_keep_alive_timeout = timeout
        return self

    def _update_attrs_by_kwargs(self, **kwargs):
        """Override attributes of this config with the non-None keyword values.

        A warning is logged whenever an override replaces a value that was
        already set. ``endpoint`` and ``endpoints`` overrides are normalized
        the same way ``__init__`` normalizes them, so the stored form does not
        depend on which path supplied the value.

        :raises AttributeError: if a key with a non-None value is not one of
            the slots, or names a slot that has never been assigned.
        """
        for key, value in kwargs.items():
            if value is None:
                continue
            if getattr(self, key) is not None:
                logger.warning(
                    f"Arg {key} was used in both DriverConfig and Driver. Value from Driver will be used."
                )
            if key == "endpoint":
                # Recompute the secure flag only when the override names a
                # scheme explicitly (<protocol>://host:port); a bare host:port
                # override keeps whatever the config already decided.
                if "://" in value:
                    self.secure_channel = _utilities.is_secure_protocol(value)
                value = _utilities.wrap_endpoint(value)
            elif key == "endpoints":
                value = [_utilities.wrap_endpoint(endp) for endp in value]
            setattr(self, key, value)
# Alias: ``ConnectionParams`` is the very same class object as ``DriverConfig``
# (presumably kept so code using the older name keeps working -- confirm).
ConnectionParams = DriverConfig
def get_config(
    driver_config=None,
    connection_string=None,
    endpoint=None,
    database=None,
    root_certificates=None,
    credentials=None,
    config_class=DriverConfig,
    **kwargs,
):
    """Return the driver config to use, creating or updating it as needed.

    * With ``driver_config`` supplied, it is updated in place: every non-None
      value among ``endpoint``, ``database``, ``root_certificates``,
      ``credentials`` and the extra keyword arguments overrides the matching
      attribute. ``connection_string`` is not consulted in this case.
    * Without ``driver_config``, a new config is built by ``config_class``,
      from ``connection_string`` when one is given and otherwise from
      ``endpoint`` and ``database``.

    Finally, if the resulting config carries credentials, their
    ``_update_driver_config`` hook is called with the config.

    :return: The supplied (and possibly updated) or newly created config.
    """
    if driver_config is not None:
        # Same key order as the arguments arrive in: extras first, then the
        # four named settings.
        overrides = {
            **kwargs,
            "endpoint": endpoint,
            "database": database,
            "root_certificates": root_certificates,
            "credentials": credentials,
        }
        driver_config._update_attrs_by_kwargs(**overrides)
    elif connection_string is not None:
        driver_config = config_class.default_from_connection_string(
            connection_string, root_certificates, credentials, **kwargs
        )
    else:
        driver_config = config_class.default_from_endpoint_and_database(
            endpoint, database, root_certificates, credentials, **kwargs
        )

    if driver_config.credentials is not None:
        driver_config.credentials._update_driver_config(driver_config)
    return driver_config
[docs]
class Driver(pool.ConnectionPool):
    """Connection pool that also owns scheme, table and topic clients."""

    __slots__ = ("scheme_client", "table_client")

    def __init__(
        self,
        driver_config=None,
        connection_string=None,
        endpoint=None,
        database=None,
        root_certificates=None,
        credentials=None,
        **kwargs,
    ):
        """
        Constructs a driver instance to be used in table and scheme clients.
        It encapsulates endpoints discovery mechanism and provides ability to execute RPCs
        on discovered endpoints

        :param driver_config: A driver config
        :param connection_string: A string in the following format: <protocol>://<hostname>:<port>/?database=/path/to/the/database
        :param endpoint: An endpoint specified in the following format: <protocol>://<hostname>:<port>
        :param database: A database path
        :param root_certificates: Root certificates, handed to ``get_config`` together with the other settings.
        :param credentials: Credentials. If not specified, credentials are constructed by default.
        """
        # Imported here rather than at module level to avoid a circular import.
        from . import topic

        # NOTE(review): **kwargs is accepted but never forwarded to get_config,
        # so extra keyword arguments currently have no effect -- confirm intent.
        resolved_config = get_config(
            driver_config=driver_config,
            connection_string=connection_string,
            endpoint=endpoint,
            database=database,
            root_certificates=root_certificates,
            credentials=credentials,
        )

        super().__init__(resolved_config)

        self._credentials = resolved_config.credentials

        self.scheme_client = scheme.SchemeClient(self)
        self.table_client = table.TableClient(self, resolved_config.table_client_settings)
        self.topic_client = topic.TopicClient(self, resolved_config.topic_client_settings)

    def stop(self, timeout=10):
        """Stop the table client's pool (if needed), then the connection pool.

        :param timeout: Passed unchanged to both stop calls.
        """
        self.table_client._stop_pool_if_needed(timeout=timeout)
        super().stop(timeout=timeout)