# -*- coding: utf-8 -*-
import decimal
from google.protobuf import struct_pb2
from . import issues, types, _apis
# Decimal values travel as a 128-bit integer split into two 64-bit halves
# (``high_128`` / ``low_128``); this is the width of one half.
_SHIFT_BIT_COUNT = 64
# 2**64; kept for compatibility (not referenced in this section of the module).
_SHIFT = 1 << _SHIFT_BIT_COUNT
# Top bit of the high half; when set, the 128-bit integer is negative.
_SIGN_BIT = 1 << (_SHIFT_BIT_COUNT - 1)
# Sentinel 128-bit integers used to encode non-finite decimals.
_DecimalInfRepr = 10**35
_DecimalNanRepr = _DecimalInfRepr + 1
_DecimalSignedInfRepr = -_DecimalInfRepr
# PrimitiveType lookup keyed by proto ``type_id``; filled by ``_initialize``.
_primitive_type_by_id = dict()
# Fallback for ResultSets when no table client settings are supplied.
_default_allow_truncated_result = False
def _initialize():
    """Index every ``types.PrimitiveType`` member by its ``_idn_``.

    The resulting ``_primitive_type_by_id`` is looked up with a proto
    ``type_id`` when encoding and decoding primitive values.
    """
    _primitive_type_by_id.update(
        (primitive._idn_, primitive) for primitive in types.PrimitiveType
    )


_initialize()
class _MissingAttributeError(AttributeError, KeyError):
    """Raised by ``_DotDict`` attribute access for a key that is not present.

    It is an ``AttributeError``, so ``hasattr``, ``getattr(obj, name, default)``
    and the ``copy``/``pickle`` protocol probes work as expected. It is also a
    ``KeyError``, so code that already catches ``KeyError`` keeps working.
    """


class _DotDict(dict):
    """A ``dict`` whose keys can also be read as attributes (``d.key``)."""

    def __getattr__(self, item):
        # Python calls this only after normal attribute lookup fails. Go through
        # ``self[item]`` rather than ``dict.__getitem__`` so that subclasses
        # overriding ``__getitem__`` (e.g. lazily parsed rows) are respected.
        try:
            return self[item]
        except KeyError:
            # __getattr__ must signal "no such attribute" with AttributeError.
            raise _MissingAttributeError(item)
def _is_decimal_signed(hi_value):
    """Return True when ``hi_value`` (the high 64-bit half) has its sign bit set."""
    return bool(hi_value & _SIGN_BIT)
def _pb_to_decimal(type_pb, value_pb, table_client_settings):
    """Decode a YDB ``Decimal`` value proto into ``decimal.Decimal``.

    The 128-bit integer arrives as two 64-bit halves (``high_128`` and
    ``low_128``). The high half is reinterpreted as signed when its top bit
    is set, then the halves are recombined.

    Sentinel integers map to NaN, +Infinity and -Infinity. Any other integer
    is divided by ``10 ** scale``, where ``scale`` comes from
    ``type_pb.decimal_type``.

    The division runs in a dedicated context wide enough for any 128-bit
    integer, so the quotient is exact. The ambient context (28 digits by
    default) would otherwise round values of wide decimal columns.

    ``table_client_settings`` is accepted for signature parity with the other
    converters and is not used.
    """
    high = value_pb.high_128
    if _is_decimal_signed(high):
        high -= 1 << _SHIFT_BIT_COUNT
    int128_value = value_pb.low_128 + (high << _SHIFT_BIT_COUNT)

    if int128_value == _DecimalNanRepr:
        return decimal.Decimal("Nan")
    if int128_value == _DecimalInfRepr:
        return decimal.Decimal("Inf")
    if int128_value == _DecimalSignedInfRepr:
        return decimal.Decimal("-Inf")

    # 2**127 has 39 decimal digits, so precision 40 always holds the exact
    # quotient. Dividing by a power of ten never adds significant digits.
    exact_context = decimal.Context(prec=40)
    return exact_context.divide(
        decimal.Decimal(int128_value),
        decimal.Decimal(10**type_pb.decimal_type.scale),
    )
def _pb_to_primitive(type_pb, value_pb, table_client_settings):
    """Decode a primitive value via the ``PrimitiveType`` registered for ``type_pb.type_id``."""
    primitive_type = _primitive_type_by_id.get(type_pb.type_id)
    return primitive_type.get_value(value_pb, table_client_settings)
def _pb_to_optional(type_pb, value_pb, table_client_settings):
    """Decode an ``Optional<T>`` value.

    A NULL flag yields ``None``. A ``nested_value`` wrapper is unwrapped once.
    Otherwise the value proto itself is decoded as the item type.
    """
    populated = value_pb.WhichOneof("value")
    if populated == "null_flag_value":
        return None
    item_pb = value_pb.nested_value if populated == "nested_value" else value_pb
    return _to_native_value(type_pb.optional_type.item, item_pb, table_client_settings)
def _pb_to_list(type_pb, value_pb, table_client_settings):
    """Decode a ``List<T>`` value into a Python ``list``, converting each item as ``T``."""
    item_type = type_pb.list_type.item
    decoded = []
    for item_pb in value_pb.items:
        decoded.append(_to_native_value(item_type, item_pb, table_client_settings))
    return decoded
def _pb_to_tuple(type_pb, value_pb, table_client_settings):
    """Decode a ``Tuple<...>`` value, pairing element types with items positionally.

    ``zip`` stops at the shorter of the two sequences.
    """
    decoded = []
    for element_type, element_pb in zip(type_pb.tuple_type.elements, value_pb.items):
        decoded.append(_to_native_value(element_type, element_pb, table_client_settings))
    return tuple(decoded)
def _pb_to_dict(type_pb, value_pb, table_client_settings):
    """Decode a ``Dict<K, V>`` value into a Python ``dict``.

    Pairs are processed in order, so a later duplicate key overwrites an earlier one.
    """
    dict_type = type_pb.dict_type
    decoded = {}
    for pair_pb in value_pb.pairs:
        native_key = _to_native_value(dict_type.key, pair_pb.key, table_client_settings)
        decoded[native_key] = _to_native_value(dict_type.payload, pair_pb.payload, table_client_settings)
    return decoded
class _Struct(_DotDict):
    """Decoded ``Struct<...>`` value: a dict of member name to value, with attribute access."""
def _pb_to_struct(type_pb, value_pb, table_client_settings):
    """Decode a ``Struct<...>`` value into a ``_Struct`` keyed by member name.

    Members and items are paired positionally.
    """
    return _Struct(
        (member_pb.name, _to_native_value(member_pb.type, item_pb, table_client_settings))
        for member_pb, item_pb in zip(type_pb.struct_type.members, value_pb.items)
    )
def _pb_to_void(type_pb, value_pb, table_client_settings):
    """Decode a ``Void`` value; always ``None`` regardless of the arguments."""
    return
# Decoder dispatch table keyed by the populated ``type`` oneof of a Type proto.
# Empty list/dict types reuse the regular list/dict decoders.
_to_native_map = dict(
    type_id=_pb_to_primitive,
    decimal_type=_pb_to_decimal,
    optional_type=_pb_to_optional,
    list_type=_pb_to_list,
    tuple_type=_pb_to_tuple,
    dict_type=_pb_to_dict,
    struct_type=_pb_to_struct,
    void_type=_pb_to_void,
    empty_list_type=_pb_to_list,
    empty_dict_type=_pb_to_dict,
)
def _to_native_value(type_pb, value_pb, table_client_settings=None):
    """Convert a YDB ``Value`` proto into a Python object according to ``type_pb``.

    Dispatches on the populated ``type`` oneof of ``type_pb`` through
    ``_to_native_map``.

    Raises:
        TypeError: if no decoder is registered for the type kind.
    """
    type_kind = type_pb.WhichOneof("type")
    decoder = _to_native_map.get(type_kind)
    if decoder is None:
        # Without this check the failure would be "'NoneType' object is not callable".
        raise TypeError("Unsupported YDB type for conversion to a native value: %r" % (type_kind,))
    return decoder(type_pb, value_pb, table_client_settings)
def _decimal_to_int128(value_type, value):
    """Convert a ``decimal.Decimal`` into the scaled integer YDB stores for ``Decimal(p, s)``.

    NaN and +/-Infinity map to their sentinel integers. A finite value is
    returned as ``value * 10 ** scale``, which must be an exact integer.
    Trailing zeros beyond the column scale are accepted when they carry no
    information (e.g. ``Decimal("1.50")`` at scale 1 yields ``15``).

    Raises:
        issues.GenericError: if the value has non-zero digits beyond the
            column scale, or if the scaled value has more digits than the
            limit checked below.
    """
    if value.is_nan():
        return _DecimalNanRepr
    if value.is_infinite():
        return _DecimalSignedInfRepr if value.is_signed() else _DecimalInfRepr

    scale = value_type.decimal_type.scale
    sign, digits, exponent = value.as_tuple()

    int128_value = 0
    for digit in digits:
        int128_value = int128_value * 10 + digit

    # Number of decimal places to move the coefficient so the exponent equals -scale.
    shift = scale + exponent
    if shift >= 0:
        int128_value *= 10**shift
    else:
        # More fractional digits than the column scale: acceptable only if
        # the extra digits are all zeros.
        divisor = 10 ** (-shift)
        if int128_value % divisor:
            raise issues.GenericError("Couldn't parse decimal value, exponent is too large")
        int128_value //= divisor

    # Digit count of the scaled coefficient, including appended zeros and
    # excluding stripped ones.
    digits_count = len(digits) + shift
    if digits_count > value_type.decimal_type.precision + scale:
        raise issues.GenericError("Couldn't parse decimal value, digits count > 35")

    if sign:
        int128_value = -int128_value
    return int128_value
def _decimal_to_pb(value_type, value):
    """Encode a ``decimal.Decimal`` into a Value proto's ``high_128``/``low_128`` halves.

    A negative 128-bit integer is stored in two's complement: the floor-shifted
    high half is offset by ``2 ** 64``, and the low half is the non-negative
    remainder.
    """
    int128_value = _decimal_to_int128(value_type, value)
    high_part = int128_value >> _SHIFT_BIT_COUNT
    low_part = int128_value - (high_part << _SHIFT_BIT_COUNT)
    if int128_value < 0:
        high_part += 1 << _SHIFT_BIT_COUNT
    value_pb = _apis.ydb_value.Value()
    value_pb.high_128 = high_part
    value_pb.low_128 = low_part
    return value_pb
def _primitive_to_pb(type_pb, value):
    """Encode ``value`` with the ``PrimitiveType`` registered for ``type_pb.type_id``."""
    primitive_type = _primitive_type_by_id.get(type_pb.type_id)
    encoded = _apis.ydb_value.Value()
    primitive_type.set_value(encoded, value)
    return encoded
def _optional_to_pb(type_pb, value):
    """Encode an ``Optional<T>`` value.

    ``None`` becomes an explicit NULL flag. Anything else is encoded as ``T``.
    """
    if value is not None:
        return _from_native_value(type_pb.optional_type.item, value)
    return _apis.ydb_value.Value(null_flag_value=struct_pb2.NULL_VALUE)
def _list_to_pb(type_pb, value):
    """Encode an iterable as a ``List<T>`` value, converting each element as ``T``."""
    item_type = type_pb.list_type.item
    list_pb = _apis.ydb_value.Value()
    for element in value:
        encoded = _from_native_value(item_type, element)
        list_pb.items.add().MergeFrom(encoded)
    return list_pb
def _tuple_to_pb(type_pb, value):
    """Encode a sequence as a ``Tuple<...>`` value, pairing elements with types positionally."""
    tuple_pb = _apis.ydb_value.Value()
    typed_elements = zip(type_pb.tuple_type.elements, value)
    for element_type, element_value in typed_elements:
        encoded = _from_native_value(element_type, element_value)
        tuple_pb.items.add().MergeFrom(encoded)
    return tuple_pb
def _dict_to_pb(type_pb, value):
    """Encode a mapping as a ``Dict<K, V>`` value.

    Keys are always encoded. Payloads are handled as follows:

    * A non-``None`` payload, including falsy ones such as ``0``, ``False``,
      ``""`` and ``[]``, is encoded with the payload type.
    * A ``None`` payload with an ``Optional`` payload type is encoded
      explicitly, so it becomes a NULL flag rather than an empty value.
    * A ``None`` payload with any other payload type is left unset. This keeps
      set-like dicts such as ``dict.fromkeys(keys)`` working; note that
      ``_from_native_map`` has no ``void_type`` encoder.
    """
    value_pb = _apis.ydb_value.Value()
    key_type = type_pb.dict_type.key
    payload_type = type_pb.dict_type.payload
    payload_is_optional = payload_type.WhichOneof("type") == "optional_type"
    for key, payload in value.items():
        kv_pair = value_pb.pairs.add()
        kv_pair.key.MergeFrom(_from_native_value(key_type, key))
        if payload is None and not payload_is_optional:
            continue
        kv_pair.payload.MergeFrom(_from_native_value(payload_type, payload))
    return value_pb
def _struct_to_pb(type_pb, value):
    """Encode a ``Struct<...>`` value in the member order declared by ``type_pb``.

    Member values are read by key when ``value`` is a ``dict`` (including dict
    subclasses) and by attribute otherwise.
    """
    struct_pb = _apis.ydb_value.Value()
    is_mapping = isinstance(value, dict)
    for member in type_pb.struct_type.members:
        if is_mapping:
            member_value = value[member.name]
        else:
            member_value = getattr(value, member.name)
        struct_pb.items.add().MergeFrom(_from_native_value(member.type, member_value))
    return struct_pb
# Encoder dispatch table keyed by the populated ``type`` oneof of a Type proto.
_from_native_map = dict(
    type_id=_primitive_to_pb,
    decimal_type=_decimal_to_pb,
    optional_type=_optional_to_pb,
    list_type=_list_to_pb,
    tuple_type=_tuple_to_pb,
    dict_type=_dict_to_pb,
    struct_type=_struct_to_pb,
)
def _decimal_type_to_native(type_pb):
    """Build a ``types.DecimalType`` from the proto's precision and scale."""
    decimal_pb = type_pb.decimal_type
    return types.DecimalType(decimal_pb.precision, decimal_pb.scale)
def _optional_type_to_native(type_pb):
    """Build a ``types.OptionalType`` around the recursively converted item type."""
    item_type = type_to_native(type_pb.optional_type.item)
    return types.OptionalType(item_type)
def _list_type_to_native(type_pb):
    """Build a ``types.ListType`` around the recursively converted item type."""
    item_type = type_to_native(type_pb.list_type.item)
    return types.ListType(item_type)
def _primitive_type_to_native(type_pb):
    """Return the ``PrimitiveType`` for ``type_pb.type_id``, or ``None`` if unregistered."""
    type_id = type_pb.type_id
    return _primitive_type_by_id.get(type_id)
def _null_type_factory(type_pb):
    """Return a new ``types.NullType``; the proto argument is not inspected."""
    null_type = types.NullType()
    return null_type
# Type-proto -> SDK type builder dispatch, keyed by the populated ``type`` oneof.
_type_to_native_map = dict(
    optional_type=_optional_type_to_native,
    type_id=_primitive_type_to_native,
    decimal_type=_decimal_type_to_native,
    null_type=_null_type_factory,
    list_type=_list_type_to_native,
)
def type_to_native(type_pb):
    """Convert a Type proto into the SDK's native type object.

    Only the kinds registered in ``_type_to_native_map`` are supported.

    Raises:
        TypeError: if the populated type kind has no registered converter.
    """
    type_kind = type_pb.WhichOneof("type")
    converter = _type_to_native_map.get(type_kind)
    if converter is None:
        # Without this check the failure would be "'NoneType' object is not callable".
        raise TypeError("Unsupported YDB type for conversion to a native type: %r" % (type_kind,))
    return converter(type_pb)
def _from_native_value(type_pb, value):
    """Encode a Python ``value`` into a YDB ``Value`` proto according to ``type_pb``.

    Dispatches on the populated ``type`` oneof through ``_from_native_map``.

    Raises:
        TypeError: if no encoder is registered for the type kind.
    """
    type_kind = type_pb.WhichOneof("type")
    encoder = _from_native_map.get(type_kind)
    if encoder is None:
        # Without this check the failure would be "'NoneType' object is not callable".
        raise TypeError("Unsupported YDB type for conversion from a native value: %r" % (type_kind,))
    return encoder(type_pb, value)
def to_typed_value_from_native(type_pb, value):
    """Return a ``TypedValue`` proto holding ``type_pb`` and ``value`` encoded against it."""
    result = _apis.ydb_value.TypedValue()
    result.type.MergeFrom(type_pb)
    encoded_value = from_native_value(type_pb, value)
    result.value.MergeFrom(encoded_value)
    return result
def parameters_to_pb(parameters_types, parameters_values):
    """Build ``TypedValue`` protos for every declared parameter.

    ``parameters_types`` maps names to either a Type proto or an SDK type
    object (``AbstractTypeBuilder`` / ``PrimitiveType``) exposing ``.proto``.
    Every declared name must be present in ``parameters_values``.

    Returns ``{}`` when ``parameters_values`` is ``None`` or empty.
    """
    if not parameters_values:
        return {}
    encoded_params = {}
    for name, declared_type in parameters_types.items():
        if isinstance(declared_type, (types.AbstractTypeBuilder, types.PrimitiveType)):
            type_proto = declared_type.proto
        else:
            type_proto = declared_type
        typed_value = _apis.ydb_value.TypedValue()
        typed_value.type.MergeFrom(type_proto)
        typed_value.value.MergeFrom(_from_native_value(type_proto, parameters_values[name]))
        encoded_params[name] = typed_value
    return encoded_params
def query_parameters_to_pb(parameters):
    """Convert user-supplied query parameters into ``TypedValue`` protos.

    Each value in ``parameters`` may be one of:

    * ``types.TypedValue`` -- used as is. If its ``value_type`` is ``None``,
      the type is inferred from ``.value`` and stored back on that object.
    * a tuple -- unpacked as ``types.TypedValue(*value)``. If the resulting
      ``value_type`` is ``None``, it is inferred the same way.
    * anything else -- wrapped with a type inferred by ``_type_from_python_native``.

    Returns ``{}`` when ``parameters`` is ``None`` or empty.
    """
    if not parameters:
        return {}
    parameters_types = {}
    parameters_values = {}
    for name, value in parameters.items():
        if isinstance(value, types.TypedValue):
            typed_value = value
        elif isinstance(value, tuple):
            typed_value = types.TypedValue(*value)
        else:
            typed_value = types.TypedValue(value, _type_from_python_native(value))
        if typed_value.value_type is None:
            # Tuples such as (value, None) previously reached parameters_to_pb
            # with a None type and failed inside MergeFrom.
            typed_value.value_type = _type_from_python_native(typed_value.value)
        parameters_values[name] = typed_value.value
        parameters_types[name] = typed_value.value_type
    return parameters_to_pb(parameters_types, parameters_values)
# Default YDB primitive type for each supported Python scalar class.
_from_python_type_map = dict(
    [
        (int, types.PrimitiveType.Int64),
        (float, types.PrimitiveType.Double),
        (bool, types.PrimitiveType.Bool),
        (str, types.PrimitiveType.Utf8),
        (bytes, types.PrimitiveType.String),
    ]
)
def _type_from_python_native(value):
    """Infer a YDB type for a plain Python value.

    Scalars are resolved by walking the value's class MRO against
    ``_from_python_type_map``. Subclasses of supported builtins (``IntEnum``
    members, ``str`` subclasses, ...) therefore map like their base.
    ``bool`` comes before ``int`` in its own MRO, so booleans still map to
    ``Bool``. Exact ``list`` and ``dict`` values are typed from their first
    element / first entry.

    Raises:
        ValueError: for empty lists/dicts, or values with no known mapping.
    """
    value_class = type(value)
    for candidate in value_class.__mro__:
        if candidate in _from_python_type_map:
            return _from_python_type_map[candidate]

    if value_class is list:
        if not value:
            raise ValueError(
                "Could not map empty list to any type, please specify "
                "it manually by tuple(value, type) or ydb.TypedValue"
            )
        return types.ListType(_type_from_python_native(value[0]))

    if value_class is dict:
        if not value:
            raise ValueError(
                "Could not map empty dict to any type, please specify "
                "it manually by tuple(value, type) or ydb.TypedValue"
            )
        # Take only the first entry; avoid copying the whole dict into a list.
        first_key, first_payload = next(iter(value.items()))
        key_type = _type_from_python_native(first_key)
        value_type = _type_from_python_native(first_payload)
        return types.DictType(key_type, value_type)

    raise ValueError(
        "Could not map value to any type, please specify it manually by tuple(value, type) or ydb.TypedValue"
    )
def _unwrap_optionality(column):
    """Strip every ``Optional`` layer from ``column.type``.

    Returns a ``(decoder, innermost_type)`` pair, where ``decoder`` comes from
    ``_to_native_map`` (``None`` if the innermost kind is unregistered).
    """
    inner_type = column.type
    while inner_type.WhichOneof("type") == "optional_type":
        inner_type = inner_type.optional_type.item
    return _to_native_map.get(inner_type.WhichOneof("type")), inner_type
class _ResultSet(object):
    """Decoded result set: column metadata, rows, truncation flag and optional snapshot."""

    __slots__ = ("columns", "rows", "truncated", "snapshot")

    def __init__(self, columns, rows, truncated, snapshot=None):
        self.columns = columns
        self.rows = rows
        self.truncated = truncated
        self.snapshot = snapshot

    @classmethod
    def from_message(cls, message, table_client_settings=None, snapshot=None):
        """Eagerly decode every row of ``message`` into ``_Row`` dicts keyed by column name.

        A NULL (at any Optional nesting depth) becomes ``None``. Other values are
        decoded with the parser for the column's innermost non-Optional type.
        """
        rows = []
        # Prepare column parsers before the actual parsing; skip this when there are no rows.
        column_parsers = []
        if len(message.rows) > 0:
            for column in message.columns:
                column_parsers.append(_unwrap_optionality(column))
        for row_proto in message.rows:
            row = _Row(message.columns)
            for column, value, column_info in zip(message.columns, row_proto.items, column_parsers):
                v_type = value.WhichOneof("value")
                while v_type == "nested_value":
                    value = value.nested_value
                    v_type = value.WhichOneof("value")
                # Check NULL only after unwrapping. An inner NULL of
                # Optional<Optional<T>> arrives as nested_value{null_flag_value}
                # and must not be handed to the T parser.
                if v_type == "null_flag_value":
                    row[column.name] = None
                    continue
                column_parser, unwrapped_type = column_info
                row[column.name] = column_parser(unwrapped_type, value, table_client_settings)
            rows.append(row)
        return cls(message.columns, rows, message.truncated, snapshot)

    @classmethod
    def lazy_from_message(cls, message, table_client_settings=None, snapshot=None):
        """Wrap ``message.rows`` in ``_LazyRows`` so values are decoded on first access."""
        rows = _LazyRows(message.rows, table_client_settings, message.columns)
        return cls(message.columns, rows, message.truncated, snapshot)


# Public name for the result-set class.
ResultSet = _ResultSet
class _Row(_DotDict):
    """Eagerly decoded row.

    A row can be indexed by column name, by column position (``int``), or by
    a slice of positions (which returns a tuple).
    """

    def __init__(self, columns):
        super(_Row, self).__init__()
        self._columns = columns

    def __getitem__(self, key):
        if isinstance(key, slice):
            return tuple(self[column.name] for column in self._columns[key])
        if isinstance(key, int):
            return self[self._columns[key].name]
        return super(_Row, self).__getitem__(key)
class _LazyRowItem:
    """Holds a raw proto item and decodes it with ``parser`` on the first ``get()``.

    The parsed result replaces the raw item and is returned on later calls.
    """

    __slots__ = ["_item", "_type", "_table_client_settings", "_processed", "_parser"]

    def __init__(self, proto_item, proto_type, table_client_settings, parser):
        self._parser = parser
        self._type = proto_type
        self._item = proto_item
        self._table_client_settings = table_client_settings
        self._processed = False

    def get(self):
        if self._processed:
            return self._item
        self._item = self._parser(self._type, self._item, self._table_client_settings)
        self._processed = True
        return self._item
class _LazyRow(_DotDict):
    """Read-only row whose column values are decoded on first access.

    Values are stored internally as ``_LazyRowItem`` wrappers. Every public
    read path (``row[...]``, ``row.name``, ``get``, ``items``, ``values``)
    returns decoded values, matching the eager ``_Row``.
    """

    def __init__(self, columns, proto_row, table_client_settings, parsers):
        super(_LazyRow, self).__init__()
        self._columns = columns
        self._table_client_settings = table_client_settings
        for i, (column, row_item) in enumerate(zip(self._columns, proto_row.items)):
            # Bypass our own __setitem__, which rejects writes.
            super(_LazyRow, self).__setitem__(
                column.name,
                _LazyRowItem(row_item, column.type, table_client_settings, parsers[i]),
            )

    def __setitem__(self, key, value):
        raise NotImplementedError("Cannot insert values into lazy row")

    def __getitem__(self, key):
        if isinstance(key, int):
            return self[self._columns[key].name]
        elif isinstance(key, slice):
            return tuple(map(lambda x: self[x.name], self._columns[key]))
        else:
            return super(_LazyRow, self).__getitem__(key).get()

    def get(self, key, default=None):
        """Return the decoded value for column ``key``, or ``default`` if absent.

        Like ``dict.get``, this looks up stored keys only; integer positions are
        not interpreted.
        """
        if key not in self:
            return default
        return super(_LazyRow, self).__getitem__(key).get()

    def items(self):
        """Return ``(column_name, decoded_value)`` pairs in column order."""
        return [(key, lazy_item.get()) for key, lazy_item in super(_LazyRow, self).items()]

    def values(self):
        """Return decoded values in column order."""
        return [lazy_item.get() for lazy_item in super(_LazyRow, self).values()]

    def __iter__(self):
        # Overriding __iter__ makes dict(row) / {**row} fall back to keys() plus
        # __getitem__ (yielding decoded values) instead of the raw dict fast path.
        return super(_LazyRow, self).__iter__()

    # NOTE(review): dict defines no __next__, so the two methods below raise
    # AttributeError from the super() lookup. Kept for interface compatibility.
    def __next__(self):
        return super(_LazyRow, self).__next__().get()

    def next(self):
        return self.__next__()
def from_native_value(type_pb, value):
    """Public entry point: encode ``value`` as a YDB ``Value`` proto of type ``type_pb``."""
    encoded = _from_native_value(type_pb, value)
    return encoded
def to_native_value(typed_value):
    """Public entry point: decode a ``TypedValue`` proto (its ``type`` and ``value``)."""
    value_type, value_pb = typed_value.type, typed_value.value
    return _to_native_value(value_type, value_pb)
class _LazyRows:
    """Row collection that builds ``_LazyRow`` objects from proto rows on demand.

    ``fetchone``, ``fetchmany`` and iteration always start from the first row;
    there is no cursor state.
    """

    def __init__(self, rows, table_client_settings, columns):
        self._rows = rows
        # One parser per column, shared by all rows; each resolves its decoder on first use.
        self._parsers = [_LazyParser(columns, index) for index, _ in enumerate(columns)]
        self._table_client_settings = table_client_settings
        self._columns = columns

    def __len__(self):
        return len(self._rows)

    def _row_at(self, index):
        return _LazyRow(self._columns, self._rows[index], self._table_client_settings, self._parsers)

    def fetchone(self):
        # Indexes position 0, so this fails when there are no rows.
        return self._row_at(0)

    def fetchmany(self, number):
        limit = min(len(self), number)
        for index in range(limit):
            yield self._row_at(index)

    def __iter__(self):
        for lazy_row in self.fetchmany(len(self)):
            yield lazy_row

    def fetchall(self):
        for lazy_row in self:
            yield lazy_row
class _LazyParser:
    """Column decoder that looks up its ``_to_native_map`` entry on first call and caches it."""

    __slots__ = ["_columns", "_column_index", "_prepared"]

    def __init__(self, columns, column_index):
        self._columns = columns
        self._column_index = column_index
        self._prepared = None

    def __call__(self, *args, **kwargs):
        decoder = self._prepared
        if decoder is None:
            column_type = self._columns[self._column_index].type
            decoder = self._prepared = _to_native_map.get(column_type.WhichOneof("type"))
        return decoder(*args, **kwargs)
class ResultSets(list):
    """List of decoded result sets built from result-set protos.

    Decoding is lazy when ``table_client_settings._make_result_sets_lazy`` is
    set. A truncated result set raises ``issues.TruncatedResponseError`` unless
    truncation is allowed by the settings (or by
    ``_default_allow_truncated_result`` when no settings are given).
    """

    def __init__(self, result_sets_pb, table_client_settings=None):
        if table_client_settings is None:
            make_lazy = False
        else:
            make_lazy = table_client_settings._make_result_sets_lazy
        if table_client_settings:
            allow_truncated_result = table_client_settings._allow_truncated_result
        else:
            allow_truncated_result = _default_allow_truncated_result
        build = _ResultSet.lazy_from_message if make_lazy else _ResultSet.from_message
        decoded_sets = []
        for result_set_pb in result_sets_pb:
            decoded = build(result_set_pb, table_client_settings)
            if decoded.truncated and not allow_truncated_result:
                raise issues.TruncatedResponseError("Response for the request was truncated by server")
            decoded_sets.append(decoded)
        super(ResultSets, self).__init__(decoded_sets)