From 0bcd65daea3bc8ecda992c2e1b023a2cc5de4d1c Mon Sep 17 00:00:00 2001 From: Drew DeVault Date: Tue, 16 Nov 2021 08:23:10 +0100 Subject: [PATCH 1/2] all: remove support for nonfree Redis modules --- redis/__init__.py | 12 - redis/client.py | 11 +- redis/cluster.py | 2150 ----------------- redis/commands/__init__.py | 8 +- redis/commands/bf/__init__.py | 204 -- redis/commands/bf/commands.py | 498 ---- redis/commands/bf/info.py | 85 - redis/commands/cluster.py | 412 ---- redis/commands/graph/__init__.py | 162 -- redis/commands/graph/commands.py | 202 -- redis/commands/graph/edge.py | 87 - redis/commands/graph/exceptions.py | 3 - redis/commands/graph/node.py | 84 - redis/commands/graph/path.py | 74 - redis/commands/graph/query_result.py | 362 --- redis/commands/json/__init__.py | 118 - redis/commands/json/commands.py | 329 --- redis/commands/json/decoders.py | 60 - redis/commands/json/path.py | 16 - redis/commands/redismodules.py | 83 - redis/commands/search/__init__.py | 96 - redis/commands/search/_util.py | 7 - redis/commands/search/aggregation.py | 357 --- redis/commands/search/commands.py | 790 ------- redis/commands/search/document.py | 13 - redis/commands/search/field.py | 92 - redis/commands/search/indexDefinition.py | 79 - redis/commands/search/query.py | 322 --- redis/commands/search/querystring.py | 314 --- redis/commands/search/reducers.py | 178 -- redis/commands/search/result.py | 73 - redis/commands/search/suggestion.py | 51 - redis/commands/sentinel.py | 93 - redis/commands/timeseries/__init__.py | 80 - redis/commands/timeseries/commands.py | 768 ------- redis/commands/timeseries/info.py | 82 - redis/commands/timeseries/utils.py | 44 - redis/sentinel.py | 337 --- setup.py | 5 - tests/test_bloom.py | 383 ---- tests/test_cluster.py | 2664 ---------------------- tests/test_commands.py | 12 - tests/test_connection.py | 22 - tests/test_graph.py | 477 ---- tests/test_graph_utils/__init__.py | 0 tests/test_graph_utils/test_edge.py | 77 - tests/test_graph_utils/test_node.py | 52 - tests/test_graph_utils/test_path.py | 91 - tests/test_json.py | 1432 ------------ tests/test_pubsub.py | 11 - tests/test_search.py | 1457 ------------ tests/test_sentinel.py | 234 -- tests/test_timeseries.py | 514 ----- 53 files changed, 5 insertions(+), 16162 deletions(-) delete mode 100644 redis/cluster.py delete mode 100644 redis/commands/bf/__init__.py delete mode 100644 redis/commands/bf/commands.py delete mode 100644 redis/commands/bf/info.py delete mode 100644 redis/commands/cluster.py delete mode 100644 redis/commands/graph/__init__.py delete mode 100644 redis/commands/graph/commands.py delete mode 100644 redis/commands/graph/edge.py delete mode 100644 redis/commands/graph/exceptions.py delete mode 100644 redis/commands/graph/node.py delete mode 100644 redis/commands/graph/path.py delete mode 100644 redis/commands/graph/query_result.py delete mode 100644 redis/commands/json/__init__.py delete mode 100644 redis/commands/json/commands.py delete mode 100644 redis/commands/json/decoders.py delete mode 100644 redis/commands/json/path.py delete mode 100644 redis/commands/redismodules.py delete mode 100644 redis/commands/search/__init__.py delete mode 100644 redis/commands/search/_util.py delete mode 100644 redis/commands/search/aggregation.py delete mode 100644 redis/commands/search/commands.py delete mode 100644 redis/commands/search/document.py delete mode 100644 redis/commands/search/field.py delete mode 100644 redis/commands/search/indexDefinition.py delete mode 100644 redis/commands/search/query.py delete mode 100644 redis/commands/search/querystring.py delete mode 100644 redis/commands/search/reducers.py delete mode 100644 redis/commands/search/result.py delete mode 100644 redis/commands/search/suggestion.py delete mode 100644 redis/commands/sentinel.py delete mode 100644 redis/commands/timeseries/__init__.py delete mode 100644 redis/commands/timeseries/commands.py delete mode 100644 redis/commands/timeseries/info.py delete mode 100644 redis/commands/timeseries/utils.py delete mode 100644 redis/sentinel.py delete mode 100644 tests/test_bloom.py delete mode 100644 tests/test_cluster.py delete mode 100644 tests/test_graph.py delete mode 100644 tests/test_graph_utils/__init__.py delete mode 100644 tests/test_graph_utils/test_edge.py delete mode 100644 tests/test_graph_utils/test_node.py delete mode 100644 tests/test_graph_utils/test_path.py delete mode 100644 tests/test_json.py delete mode 100644 tests/test_search.py delete mode 100644 tests/test_sentinel.py delete mode 100644 tests/test_timeseries.py diff --git a/redis/__init__.py b/redis/__init__.py index 35044be..f0b8623 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -6,7 +6,6 @@ else: import importlib_metadata as metadata from redis.client import Redis, StrictRedis -from redis.cluster import RedisCluster from redis.connection import ( BlockingConnectionPool, Connection, @@ -29,12 +28,6 @@ from redis.exceptions import ( TimeoutError, WatchError, ) -from redis.sentinel import ( - Sentinel, - SentinelConnectionPool, - SentinelManagedConnection, - SentinelManagedSSLConnection, -) from redis.utils import from_url @@ -68,13 +61,8 @@ __all__ = [ "PubSubError", "ReadOnlyError", "Redis", - "RedisCluster", "RedisError", "ResponseError", - "Sentinel", - "SentinelConnectionPool", - "SentinelManagedConnection", - "SentinelManagedSSLConnection", "SSLConnection", "StrictRedis", "TimeoutError", diff --git a/redis/client.py b/redis/client.py index 0984a7c..bf7f596 100755 --- a/redis/client.py +++ b/redis/client.py @@ -5,14 +5,9 @@ import threading import time import warnings from itertools import chain - -from redis.commands import ( - CoreCommands, - RedisModuleCommands, - SentinelCommands, - list_or_args, -) +from redis.commands import CoreCommands, list_or_args from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection +from redis.lock import Lock from redis.exceptions import ( ConnectionError, ExecAbortError, @@ -642,7 +637,7 @@ def parse_set_result(response, **options): return response and str_if_bytes(response) == "OK" -class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): +class Redis(CoreCommands, object): """ Implementation of the Redis protocol. diff --git a/redis/cluster.py b/redis/cluster.py deleted file mode 100644 index 5707a9d..0000000 --- a/redis/cluster.py +++ /dev/null @@ -1,2150 +0,0 @@ -import copy -import logging -import random -import socket -import sys -import threading -import time -from collections import OrderedDict - -from redis.client import CaseInsensitiveDict, PubSub, Redis -from redis.commands import CommandsParser, RedisClusterCommands -from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url -from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot -from redis.exceptions import ( - AskError, - BusyLoadingError, - ClusterCrossSlotError, - ClusterDownError, - ClusterError, - ConnectionError, - DataError, - MasterDownError, - MovedError, - RedisClusterException, - RedisError, - ResponseError, - SlotNotCoveredError, - TimeoutError, - TryAgainError, -) -from redis.utils import ( - dict_merge, - list_keys_to_dict, - merge_result, - safe_str, - str_if_bytes, -) - -log = logging.getLogger(__name__) - - -def get_node_name(host, port): - return f"{host}:{port}" - - -def get_connection(redis_node, *args, **options): - return redis_node.connection or redis_node.connection_pool.get_connection( - args[0], **options - ) - - -def parse_scan_result(command, res, **options): - keys_list = [] - for primary_res in res.values(): - keys_list += primary_res[1] - return 0, keys_list - - -def parse_pubsub_numsub(command, res, **options): - numsub_d = OrderedDict() - for numsub_tups in res.values(): - for channel, numsubbed in numsub_tups: - try: - numsub_d[channel] += numsubbed - except KeyError: - numsub_d[channel] = numsubbed - - ret_numsub = [(channel, numsub) for channel, numsub in numsub_d.items()] - return ret_numsub - - -def parse_cluster_slots(resp, **options): - current_host = options.get("current_host", "") - - def fix_server(*args): - return str_if_bytes(args[0]) or current_host, args[1] - - slots = {} - for slot in resp: - start, end, primary = slot[:3] - replicas = slot[3:] - slots[start, end] = { - "primary": fix_server(*primary), - "replicas": [fix_server(*replica) for replica in replicas], - } - - return slots - - -PRIMARY = "primary" -REPLICA = "replica" -SLOT_ID = "slot-id" - -REDIS_ALLOWED_KEYS = ( - "charset", - "connection_class", - "connection_pool", - "client_name", - "db", - "decode_responses", - "encoding", - "encoding_errors", - "errors", - "host", - "max_connections", - "nodes_flag", - "redis_connect_func", - "password", - "port", - "retry", - "retry_on_timeout", - "socket_connect_timeout", - "socket_keepalive", - "socket_keepalive_options", - "socket_timeout", - "ssl", - "ssl_ca_certs", - "ssl_certfile", - "ssl_cert_reqs", - "ssl_keyfile", - "unix_socket_path", - "username", -) -KWARGS_DISABLED_KEYS = ( - "host", - "port", -) - -# Not complete, but covers the major ones -# https://redis.io/commands -READ_COMMANDS = frozenset( - [ - "BITCOUNT", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUS", - "GEORADIUSBYMEMBER", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "KEYS", - "LINDEX", - "LLEN", - "LRANGE", - "MGET", - "PTTL", - "RANDOMKEY", - "SCARD", - "SDIFF", - "SINTER", - "SISMEMBER", - "SMEMBERS", - "SRANDMEMBER", - "STRLEN", - "SUNION", - "TTL", - "ZCARD", - "ZCOUNT", - "ZRANGE", - "ZSCORE", - ] -) - - -def cleanup_kwargs(**kwargs): - """ - Remove unsupported or disabled keys from kwargs - """ - connection_kwargs = { - k: v - for k, v in kwargs.items() - if k in REDIS_ALLOWED_KEYS and k not in KWARGS_DISABLED_KEYS - } - - return connection_kwargs - - -class ClusterParser(DefaultParser): - EXCEPTION_CLASSES = dict_merge( - DefaultParser.EXCEPTION_CLASSES, - { - "ASK": AskError, - "TRYAGAIN": TryAgainError, - "MOVED": MovedError, - "CLUSTERDOWN": ClusterDownError, - "CROSSSLOT": ClusterCrossSlotError, - "MASTERDOWN": MasterDownError, - }, - ) - - -class RedisCluster(RedisClusterCommands): - RedisClusterRequestTTL = 16 - - PRIMARIES = "primaries" - REPLICAS = "replicas" - ALL_NODES = "all" - RANDOM = "random" - DEFAULT_NODE = "default-node" - - NODE_FLAGS = {PRIMARIES, REPLICAS, ALL_NODES, RANDOM, DEFAULT_NODE} - - COMMAND_FLAGS = dict_merge( - list_keys_to_dict( - [ - "ACL CAT", - "ACL DELUSER", - "ACL GENPASS", - "ACL GETUSER", - "ACL HELP", - "ACL LIST", - "ACL LOG", - "ACL LOAD", - "ACL SAVE", - "ACL SETUSER", - "ACL USERS", - "ACL WHOAMI", - "CLIENT LIST", - "CLIENT SETNAME", - "CLIENT GETNAME", - "CONFIG SET", - "CONFIG REWRITE", - "CONFIG RESETSTAT", - "TIME", - "PUBSUB CHANNELS", - "PUBSUB NUMPAT", - "PUBSUB NUMSUB", - "PING", - "INFO", - "SHUTDOWN", - "KEYS", - "SCAN", - "FLUSHALL", - "FLUSHDB", - "DBSIZE", - "BGSAVE", - "SLOWLOG GET", - "SLOWLOG LEN", - "SLOWLOG RESET", - "WAIT", - "SAVE", - "MEMORY PURGE", - "MEMORY MALLOC-STATS", - "MEMORY STATS", - "LASTSAVE", - "CLIENT TRACKINGINFO", - "CLIENT PAUSE", - "CLIENT UNPAUSE", - "CLIENT UNBLOCK", - "CLIENT ID", - "CLIENT REPLY", - "CLIENT GETREDIR", - "CLIENT INFO", - "CLIENT KILL", - "READONLY", - "READWRITE", - "CLUSTER INFO", - "CLUSTER MEET", - "CLUSTER NODES", - "CLUSTER REPLICAS", - "CLUSTER RESET", - "CLUSTER SET-CONFIG-EPOCH", - "CLUSTER SLOTS", - "CLUSTER COUNT-FAILURE-REPORTS", - "CLUSTER KEYSLOT", - "COMMAND", - "COMMAND COUNT", - "COMMAND GETKEYS", - "CONFIG GET", - "DEBUG", - "RANDOMKEY", - "READONLY", - "READWRITE", - "TIME", - ], - DEFAULT_NODE, - ), - list_keys_to_dict( - [ - "CLUSTER COUNTKEYSINSLOT", - "CLUSTER DELSLOTS", - "CLUSTER GETKEYSINSLOT", - "CLUSTER SETSLOT", - ], - SLOT_ID, - ), - ) - - CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { - "CLUSTER ADDSLOTS": bool, - "CLUSTER COUNT-FAILURE-REPORTS": int, - "CLUSTER COUNTKEYSINSLOT": int, - "CLUSTER DELSLOTS": bool, - "CLUSTER FAILOVER": bool, - "CLUSTER FORGET": bool, - "CLUSTER GETKEYSINSLOT": list, - "CLUSTER KEYSLOT": int, - "CLUSTER MEET": bool, - "CLUSTER REPLICATE": bool, - "CLUSTER RESET": bool, - "CLUSTER SAVECONFIG": bool, - "CLUSTER SET-CONFIG-EPOCH": bool, - "CLUSTER SETSLOT": bool, - "CLUSTER SLOTS": parse_cluster_slots, - "ASKING": bool, - "READONLY": bool, - "READWRITE": bool, - } - - RESULT_CALLBACKS = dict_merge( - list_keys_to_dict( - [ - "PUBSUB NUMSUB", - ], - parse_pubsub_numsub, - ), - list_keys_to_dict( - [ - "PUBSUB NUMPAT", - ], - lambda command, res: sum(list(res.values())), - ), - list_keys_to_dict( - [ - "KEYS", - "PUBSUB CHANNELS", - ], - merge_result, - ), - list_keys_to_dict( - [ - "PING", - "CONFIG SET", - "CONFIG REWRITE", - "CONFIG RESETSTAT", - "CLIENT SETNAME", - "BGSAVE", - "SLOWLOG RESET", - "SAVE", - "MEMORY PURGE", - "CLIENT PAUSE", - "CLIENT UNPAUSE", - ], - lambda command, res: all(res.values()) if isinstance(res, dict) else res, - ), - list_keys_to_dict( - [ - "DBSIZE", - "WAIT", - ], - lambda command, res: sum(res.values()) if isinstance(res, dict) else res, - ), - list_keys_to_dict( - [ - "CLIENT UNBLOCK", - ], - lambda command, res: 1 if sum(res.values()) > 0 else 0, - ), - list_keys_to_dict( - [ - "SCAN", - ], - parse_scan_result, - ), - ) - - ERRORS_ALLOW_RETRY = ( - ConnectionError, - TimeoutError, - ClusterDownError, - ) - - def __init__( - self, - host=None, - port=6379, - startup_nodes=None, - cluster_error_retry_attempts=3, - require_full_coverage=True, - skip_full_coverage_check=False, - reinitialize_steps=10, - read_from_replicas=False, - url=None, - **kwargs, - ): - """ - Initialize a new RedisCluster client. - - :startup_nodes: 'list[ClusterNode]' - List of nodes from which initial bootstrapping can be done - :host: 'str' - Can be used to point to a startup node - :port: 'int' - Can be used to point to a startup node - :require_full_coverage: 'bool' - If set to True, as it is by default, all slots must be covered. - If set to False and not all slots are covered, the instance - creation will succeed only if 'cluster-require-full-coverage' - configuration is set to 'no' in all of the cluster's nodes. - Otherwise, RedisClusterException will be thrown. - :skip_full_coverage_check: 'bool' - If require_full_coverage is set to False, a check of - cluster-require-full-coverage config will be executed against all - nodes. Set skip_full_coverage_check to True to skip this check. - Useful for clusters without the CONFIG command (like ElastiCache) - :read_from_replicas: 'bool' - Enable read from replicas in READONLY mode. You can read possibly - stale data. - When set to true, read commands will be assigned between the - primary and its replications in a Round-Robin manner. - :cluster_error_retry_attempts: 'int' - Retry command execution attempts when encountering ClusterDownError - or ConnectionError - :reinitialize_steps: 'int' - Specifies the number of MOVED errors that need to occur before - reinitializing the whole cluster topology. If a MOVED error occurs - and the cluster does not need to be reinitialized on this current - error handling, only the MOVED slot will be patched with the - redirected node. - To reinitialize the cluster on every MOVED error, set - reinitialize_steps to 1. - To avoid reinitializing the cluster on moved errors, set - reinitialize_steps to 0. - - :**kwargs: - Extra arguments that will be sent into Redis instance when created - (See Official redis-py doc for supported kwargs - [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) - Some kwargs are not supported and will raise a - RedisClusterException: - - db (Redis do not support database SELECT in cluster mode) - """ - log.info("Creating a new instance of RedisCluster client") - - if startup_nodes is None: - startup_nodes = [] - - if "db" in kwargs: - # Argument 'db' is not possible to use in cluster mode - raise RedisClusterException( - "Argument 'db' is not possible to use in cluster mode" - ) - - # Get the startup node/s - from_url = False - if url is not None: - from_url = True - url_options = parse_url(url) - if "path" in url_options: - raise RedisClusterException( - "RedisCluster does not currently support Unix Domain " - "Socket connections" - ) - if "db" in url_options and url_options["db"] != 0: - # Argument 'db' is not possible to use in cluster mode - raise RedisClusterException( - "A ``db`` querystring option can only be 0 in cluster mode" - ) - kwargs.update(url_options) - host = kwargs.get("host") - port = kwargs.get("port", port) - startup_nodes.append(ClusterNode(host, port)) - elif host is not None and port is not None: - startup_nodes.append(ClusterNode(host, port)) - elif len(startup_nodes) == 0: - # No startup node was provided - raise RedisClusterException( - "RedisCluster requires at least one node to discover the " - "cluster. Please provide one of the followings:\n" - "1. host and port, for example:\n" - " RedisCluster(host='localhost', port=6379)\n" - "2. list of startup nodes, for example:\n" - " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," - " ClusterNode('localhost', 6378)])" - ) - log.debug(f"startup_nodes : {startup_nodes}") - # Update the connection arguments - # Whenever a new connection is established, RedisCluster's on_connect - # method should be run - # If the user passed on_connect function we'll save it and run it - # inside the RedisCluster.on_connect() function - self.user_on_connect_func = kwargs.pop("redis_connect_func", None) - kwargs.update({"redis_connect_func": self.on_connect}) - kwargs = cleanup_kwargs(**kwargs) - - self.encoder = Encoder( - kwargs.get("encoding", "utf-8"), - kwargs.get("encoding_errors", "strict"), - kwargs.get("decode_responses", False), - ) - self.cluster_error_retry_attempts = cluster_error_retry_attempts - self.command_flags = self.__class__.COMMAND_FLAGS.copy() - self.node_flags = self.__class__.NODE_FLAGS.copy() - self.read_from_replicas = read_from_replicas - self.reinitialize_counter = 0 - self.reinitialize_steps = reinitialize_steps - self.nodes_manager = None - self.nodes_manager = NodesManager( - startup_nodes=startup_nodes, - from_url=from_url, - require_full_coverage=require_full_coverage, - skip_full_coverage_check=skip_full_coverage_check, - **kwargs, - ) - - self.cluster_response_callbacks = CaseInsensitiveDict( - self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS - ) - self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) - self.commands_parser = CommandsParser(self) - self._lock = threading.Lock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def __del__(self): - self.close() - - def disconnect_connection_pools(self): - for node in self.get_nodes(): - if node.redis_connection: - try: - node.redis_connection.connection_pool.disconnect() - except OSError: - # Client was already disconnected. do nothing - pass - - @classmethod - def from_url(cls, url, **kwargs): - """ - Return a Redis client object configured from the given URL - - For example:: - - redis://[[username]:[password]]@localhost:6379/0 - rediss://[[username]:[password]]@localhost:6379/0 - unix://[[username]:[password]]@/path/to/socket.sock?db=0 - - Three URL schemes are supported: - - - `redis://` creates a TCP socket connection. See more at: - - - `rediss://` creates a SSL wrapped TCP socket connection. See more at: - - - ``unix://``: creates a Unix Domain Socket connection. - - The username, password, hostname, path and all querystring values - are passed through urllib.parse.unquote in order to replace any - percent-encoded values with their corresponding characters. - - There are several ways to specify a database number. The first value - found will be used: - 1. A ``db`` querystring option, e.g. redis://localhost?db=0 - 2. If using the redis:// or rediss:// schemes, the path argument - of the url, e.g. redis://localhost/0 - 3. A ``db`` keyword argument to this function. - - If none of these options are specified, the default db=0 is used. - - All querystring options are cast to their appropriate Python types. - Boolean arguments can be specified with string values "True"/"False" - or "Yes"/"No". Values that cannot be properly cast cause a - ``ValueError`` to be raised. Once parsed, the querystring arguments - and keyword arguments are passed to the ``ConnectionPool``'s - class initializer. In the case of conflicting arguments, querystring - arguments always win. - - """ - return cls(url=url, **kwargs) - - def on_connect(self, connection): - """ - Initialize the connection, authenticate and select a database and send - READONLY if it is set during object initialization. - """ - connection.set_parser(ClusterParser) - connection.on_connect() - - if self.read_from_replicas: - # Sending READONLY command to server to configure connection as - # readonly. Since each cluster node may change its server type due - # to a failover, we should establish a READONLY connection - # regardless of the server type. If this is a primary connection, - # READONLY would not affect executing write commands. - connection.send_command("READONLY") - if str_if_bytes(connection.read_response()) != "OK": - raise ConnectionError("READONLY command failed") - - if self.user_on_connect_func is not None: - self.user_on_connect_func(connection) - - def get_redis_connection(self, node): - if not node.redis_connection: - with self._lock: - if not node.redis_connection: - self.nodes_manager.create_redis_connections([node]) - return node.redis_connection - - def get_node(self, host=None, port=None, node_name=None): - return self.nodes_manager.get_node(host, port, node_name) - - def get_primaries(self): - return self.nodes_manager.get_nodes_by_server_type(PRIMARY) - - def get_replicas(self): - return self.nodes_manager.get_nodes_by_server_type(REPLICA) - - def get_random_node(self): - return random.choice(list(self.nodes_manager.nodes_cache.values())) - - def get_nodes(self): - return list(self.nodes_manager.nodes_cache.values()) - - def get_node_from_key(self, key, replica=False): - """ - Get the node that holds the key's slot. - If replica set to True but the slot doesn't have any replicas, None is - returned. - """ - slot = self.keyslot(key) - slot_cache = self.nodes_manager.slots_cache.get(slot) - if slot_cache is None or len(slot_cache) == 0: - raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') - if replica and len(self.nodes_manager.slots_cache[slot]) < 2: - return None - elif replica: - node_idx = 1 - else: - # primary - node_idx = 0 - - return slot_cache[node_idx] - - def get_default_node(self): - """ - Get the cluster's default node - """ - return self.nodes_manager.default_node - - def set_default_node(self, node): - """ - Set the default node of the cluster. - :param node: 'ClusterNode' - :return True if the default node was set, else False - """ - if node is None or self.get_node(node_name=node.name) is None: - log.info( - "The requested node does not exist in the cluster, so " - "the default node was not changed." - ) - return False - self.nodes_manager.default_node = node - log.info(f"Changed the default cluster node to {node}") - return True - - def monitor(self, target_node=None): - """ - Returns a Monitor object for the specified target node. - The default cluster node will be selected if no target node was - specified. - Monitor is useful for handling the MONITOR command to the redis server. - next_command() method returns one command from monitor - listen() method yields commands from monitor. - """ - if target_node is None: - target_node = self.get_default_node() - if target_node.redis_connection is None: - raise RedisClusterException( - f"Cluster Node {target_node.name} has no redis_connection" - ) - return target_node.redis_connection.monitor() - - def pubsub(self, node=None, host=None, port=None, **kwargs): - """ - Allows passing a ClusterNode, or host&port, to get a pubsub instance - connected to the specified node - """ - return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) - - def pipeline(self, transaction=None, shard_hint=None): - """ - Cluster impl: - Pipelines do not work in cluster mode the same way they - do in normal mode. Create a clone of this object so - that simulating pipelines will work correctly. Each - command will be called directly when used and - when calling execute() will only return the result stack. - """ - if shard_hint: - raise RedisClusterException("shard_hint is deprecated in cluster mode") - - if transaction: - raise RedisClusterException("transaction is deprecated in cluster mode") - - return ClusterPipeline( - nodes_manager=self.nodes_manager, - startup_nodes=self.nodes_manager.startup_nodes, - result_callbacks=self.result_callbacks, - cluster_response_callbacks=self.cluster_response_callbacks, - cluster_error_retry_attempts=self.cluster_error_retry_attempts, - read_from_replicas=self.read_from_replicas, - reinitialize_steps=self.reinitialize_steps, - ) - - def _determine_nodes(self, *args, **kwargs): - command = args[0] - nodes_flag = kwargs.pop("nodes_flag", None) - if nodes_flag is not None: - # nodes flag passed by the user - command_flag = nodes_flag - else: - # get the nodes group for this command if it was predefined - command_flag = self.command_flags.get(command) - if command_flag: - log.debug(f"Target node/s for {command}: {command_flag}") - if command_flag == self.__class__.RANDOM: - # return a random node - return [self.get_random_node()] - elif command_flag == self.__class__.PRIMARIES: - # return all primaries - return self.get_primaries() - elif command_flag == self.__class__.REPLICAS: - # return all replicas - return self.get_replicas() - elif command_flag == self.__class__.ALL_NODES: - # return all nodes - return self.get_nodes() - elif command_flag == self.__class__.DEFAULT_NODE: - # return the cluster's default node - return [self.nodes_manager.default_node] - else: - # get the node that holds the key's slot - slot = self.determine_slot(*args) - node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS - ) - log.debug(f"Target for {args}: slot {slot}") - return [node] - - def _should_reinitialized(self): - # To reinitialize the cluster on every MOVED error, - # set reinitialize_steps to 1. - # To avoid reinitializing the cluster on moved errors, set - # reinitialize_steps to 0. - if self.reinitialize_steps == 0: - return False - else: - return self.reinitialize_counter % self.reinitialize_steps == 0 - - def keyslot(self, key): - """ - Calculate keyslot for a given key. - See Keys distribution model in https://redis.io/topics/cluster-spec - """ - k = self.encoder.encode(key) - return key_slot(k) - - def _get_command_keys(self, *args): - """ - Get the keys in the command. If the command has no keys in in, None is - returned. - """ - redis_conn = self.get_default_node().redis_connection - return self.commands_parser.get_keys(redis_conn, *args) - - def determine_slot(self, *args): - """ - Figure out what slot based on command and args - """ - if self.command_flags.get(args[0]) == SLOT_ID: - # The command contains the slot ID - return args[1] - - # Get the keys in the command - keys = self._get_command_keys(*args) - if keys is None or len(keys) == 0: - raise RedisClusterException( - "No way to dispatch this command to Redis Cluster. " - "Missing key.\nYou can execute the command by specifying " - f"target nodes.\nCommand: {args}" - ) - - if len(keys) > 1: - # multi-key command, we need to make sure all keys are mapped to - # the same slot - slots = {self.keyslot(key) for key in keys} - if len(slots) != 1: - raise RedisClusterException( - f"{args[0]} - all keys must map to the same key slot" - ) - return slots.pop() - else: - # single key command - return self.keyslot(keys[0]) - - def reinitialize_caches(self): - self.nodes_manager.initialize() - - def get_encoder(self): - """ - Get the connections' encoder - """ - return self.encoder - - def get_connection_kwargs(self): - """ - Get the connections' key-word arguments - """ - return self.nodes_manager.connection_kwargs - - def _is_nodes_flag(self, target_nodes): - return isinstance(target_nodes, str) and target_nodes in self.node_flags - - def _parse_target_nodes(self, target_nodes): - if isinstance(target_nodes, list): - nodes = target_nodes - elif isinstance(target_nodes, ClusterNode): - # Supports passing a single ClusterNode as a variable - nodes = [target_nodes] - elif isinstance(target_nodes, dict): - # Supports dictionaries of the format {node_name: node}. - # It enables to execute commands with multi nodes as follows: - # rc.cluster_save_config(rc.get_primaries()) - nodes = target_nodes.values() - else: - raise TypeError( - "target_nodes type can be one of the following: " - "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," - "ClusterNode, list, or dict. " - f"The passed type is {type(target_nodes)}" - ) - return nodes - - def execute_command(self, *args, **kwargs): - """ - Wrapper for ERRORS_ALLOW_RETRY error handling. - - It will try the number of times specified by the config option - "self.cluster_error_retry_attempts" which defaults to 3 unless manually - configured. - - If it reaches the number of times, the command will raise the exception - - Key argument :target_nodes: can be passed with the following types: - nodes_flag: PRIMARIES, REPLICAS, ALL_NODES, RANDOM - ClusterNode - list - dict - """ - target_nodes_specified = False - target_nodes = None - passed_targets = kwargs.pop("target_nodes", None) - if passed_targets is not None and not self._is_nodes_flag(passed_targets): - target_nodes = self._parse_target_nodes(passed_targets) - target_nodes_specified = True - # If an error that allows retrying was thrown, the nodes and slots - # cache were reinitialized. We will retry executing the command with - # the updated cluster setup only when the target nodes can be - # determined again with the new cache tables. Therefore, when target - # nodes were passed to this function, we cannot retry the command - # execution since the nodes may not be valid anymore after the tables - # were reinitialized. So in case of passed target nodes, - # retry_attempts will be set to 1. - retry_attempts = ( - 1 if target_nodes_specified else self.cluster_error_retry_attempts - ) - exception = None - for _ in range(0, retry_attempts): - try: - res = {} - if not target_nodes_specified: - # Determine the nodes to execute the command on - target_nodes = self._determine_nodes( - *args, **kwargs, nodes_flag=passed_targets - ) - if not target_nodes: - raise RedisClusterException( - f"No targets were found to execute {args} command on" - ) - for node in target_nodes: - res[node.name] = self._execute_command(node, *args, **kwargs) - # Return the processed result - return self._process_result(args[0], res, **kwargs) - except BaseException as e: - if type(e) in RedisCluster.ERRORS_ALLOW_RETRY: - # The nodes and slots cache were reinitialized. - # Try again with the new cluster setup. - exception = e - else: - # All other errors should be raised. - raise e - - # If it fails the configured number of times then raise exception back - # to caller of this method - raise exception - - def _execute_command(self, target_node, *args, **kwargs): - """ - Send a command to a node in the cluster - """ - command = args[0] - redis_node = None - connection = None - redirect_addr = None - asking = False - moved = False - ttl = int(self.RedisClusterRequestTTL) - connection_error_retry_counter = 0 - - while ttl > 0: - ttl -= 1 - try: - if asking: - target_node = self.get_node(node_name=redirect_addr) - elif moved: - # MOVED occurred and the slots cache was updated, - # refresh the target node - slot = self.determine_slot(*args) - target_node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS - ) - moved = False - - log.debug( - f"Executing command {command} on target node: " - f"{target_node.server_type} {target_node.name}" - ) - redis_node = self.get_redis_connection(target_node) - connection = get_connection(redis_node, *args, **kwargs) - if asking: - connection.send_command("ASKING") - redis_node.parse_response(connection, "ASKING", **kwargs) - asking = False - - connection.send_command(*args) - response = redis_node.parse_response(connection, command, **kwargs) - if command in self.cluster_response_callbacks: - response = self.cluster_response_callbacks[command]( - response, **kwargs - ) - return response - - except (RedisClusterException, BusyLoadingError) as e: - log.exception(type(e)) - raise - except (ConnectionError, TimeoutError) as e: - log.exception(type(e)) - # ConnectionError can also be raised if we couldn't get a - # connection from the pool before timing out, so check that - # this is an actual connection before attempting to disconnect. - if connection is not None: - connection.disconnect() - connection_error_retry_counter += 1 - - # Give the node 0.25 seconds to get back up and retry again - # with same node and configuration. After 5 attempts then try - # to reinitialize the cluster and see if the nodes - # configuration has changed or not - if connection_error_retry_counter < 5: - time.sleep(0.25) - else: - # Hard force of reinitialize of the node/slots setup - # and try again with the new setup - self.nodes_manager.initialize() - raise - except MovedError as e: - # First, we will try to patch the slots/nodes cache with the - # redirected node output and try again. If MovedError exceeds - # 'reinitialize_steps' number of times, we will force - # reinitializing the tables, and then try again. - # 'reinitialize_steps' counter will increase faster when - # the same client object is shared between multiple threads. To - # reduce the frequency you can set this variable in the - # RedisCluster constructor. - log.exception("MovedError") - self.reinitialize_counter += 1 - if self._should_reinitialized(): - self.nodes_manager.initialize() - # Reset the counter - self.reinitialize_counter = 0 - else: - self.nodes_manager.update_moved_exception(e) - moved = True - except TryAgainError: - log.exception("TryAgainError") - - if ttl < self.RedisClusterRequestTTL / 2: - time.sleep(0.05) - except AskError as e: - log.exception("AskError") - - redirect_addr = get_node_name(host=e.host, port=e.port) - asking = True - except ClusterDownError as e: - log.exception("ClusterDownError") - # ClusterDownError can occur during a failover and to get - # self-healed, we will try to reinitialize the cluster layout - # and retry executing the command - time.sleep(0.25) - self.nodes_manager.initialize() - raise e - except ResponseError as e: - message = e.__str__() - log.exception(f"ResponseError: {message}") - raise e - except BaseException as e: - log.exception("BaseException") - if connection: - connection.disconnect() - raise e - finally: - if connection is not None: - redis_node.connection_pool.release(connection) - - raise ClusterError("TTL exhausted.") - - def close(self): - try: - with self._lock: - if self.nodes_manager: - self.nodes_manager.close() - except AttributeError: - # RedisCluster's __init__ can fail before nodes_manager is set - pass - - def _process_result(self, command, res, **kwargs): - """ - Process the result of the executed command. - The function would return a dict or a single value. - - :type command: str - :type res: dict - - `res` should be in the following format: - Dict - """ - if command in self.result_callbacks: - return self.result_callbacks[command](command, res, **kwargs) - elif len(res) == 1: - # When we execute the command on a single node, we can - # remove the dictionary and return a single response - return list(res.values())[0] - else: - return res - - -class ClusterNode: - def __init__(self, host, port, server_type=None, redis_connection=None): - if host == "localhost": - host = socket.gethostbyname(host) - - self.host = host - self.port = port - self.name = get_node_name(host, port) - self.server_type = server_type - self.redis_connection = redis_connection - - def __repr__(self): - return ( - f"[host={self.host}," - f"port={self.port}," - f"name={self.name}," - f"server_type={self.server_type}," - f"redis_connection={self.redis_connection}]" - ) - - def __eq__(self, obj): - return isinstance(obj, ClusterNode) and obj.name == self.name - - def __del__(self): - if self.redis_connection is not None: - self.redis_connection.close() - - -class LoadBalancer: - """ - Round-Robin Load Balancing - """ - - def __init__(self, start_index=0): - self.primary_to_idx = {} - self.start_index = start_index - - def get_server_index(self, primary, list_size): - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - # Update the index - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index - - def reset(self): - self.primary_to_idx.clear() - - -class NodesManager: - def __init__( - self, - startup_nodes, - from_url=False, - require_full_coverage=True, - skip_full_coverage_check=False, - lock=None, - **kwargs, - ): - self.nodes_cache = {} - self.slots_cache = {} - self.startup_nodes = {} - self.default_node = None - self.populate_startup_nodes(startup_nodes) - self.from_url = from_url - self._require_full_coverage = require_full_coverage - self._skip_full_coverage_check = skip_full_coverage_check - self._moved_exception = None - self.connection_kwargs = kwargs - self.read_load_balancer = LoadBalancer() - if lock is None: - lock = threading.Lock() - self._lock = lock - self.initialize() - - def get_node(self, host=None, port=None, node_name=None): - """ - Get the requested node from the cluster's nodes. - nodes. - :return: ClusterNode if the node exists, else None - """ - if host and port: - # the user passed host and port - if host == "localhost": - host = socket.gethostbyname(host) - return self.nodes_cache.get(get_node_name(host=host, port=port)) - elif node_name: - return self.nodes_cache.get(node_name) - else: - log.error( - "get_node requires one of the following: " - "1. node name " - "2. host and port" - ) - return None - - def update_moved_exception(self, exception): - self._moved_exception = exception - - def _update_moved_slots(self): - """ - Update the slot's node with the redirected one - """ - e = self._moved_exception - redirected_node = self.get_node(host=e.host, port=e.port) - if redirected_node is not None: - # The node already exists - if redirected_node.server_type is not PRIMARY: - # Update the node's server type - redirected_node.server_type = PRIMARY - else: - # This is a new node, we will add it to the nodes cache - redirected_node = ClusterNode(e.host, e.port, PRIMARY) - self.nodes_cache[redirected_node.name] = redirected_node - if redirected_node in self.slots_cache[e.slot_id]: - # The MOVED error resulted from a failover, and the new slot owner - # had previously been a replica. - old_primary = self.slots_cache[e.slot_id][0] - # Update the old primary to be a replica and add it to the end of - # the slot's node list - old_primary.server_type = REPLICA - self.slots_cache[e.slot_id].append(old_primary) - # Remove the old replica, which is now a primary, from the slot's - # node list - self.slots_cache[e.slot_id].remove(redirected_node) - # Override the old primary with the new one - self.slots_cache[e.slot_id][0] = redirected_node - if self.default_node == old_primary: - # Update the default node with the new primary - self.default_node = redirected_node - else: - # The new slot owner is a new server, or a server from a different - # shard. We need to remove all current nodes from the slot's list - # (including replications) and add just the new node. - self.slots_cache[e.slot_id] = [redirected_node] - # Reset moved_exception - self._moved_exception = None - - def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): - """ - Gets a node that servers this hash slot - """ - if self._moved_exception: - with self._lock: - if self._moved_exception: - self._update_moved_slots() - - if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: - raise SlotNotCoveredError( - f'Slot "{slot}" not covered by the cluster. ' - f'"require_full_coverage={self._require_full_coverage}"' - ) - - if read_from_replicas is True: - # get the server index in a Round-Robin manner - primary_name = self.slots_cache[slot][0].name - node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) - ) - elif ( - server_type is None - or server_type == PRIMARY - or len(self.slots_cache[slot]) == 1 - ): - # return a primary - node_idx = 0 - else: - # return a replica - # randomly choose one of the replicas - node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) - - return self.slots_cache[slot][node_idx] - - def get_nodes_by_server_type(self, server_type): - """ - Get all nodes with the specified server type - :param server_type: 'primary' or 'replica' - :return: list of ClusterNode - """ - return [ - node - for node in self.nodes_cache.values() - if node.server_type == server_type - ] - - def populate_startup_nodes(self, nodes): - """ - Populate all startup nodes and filters out any duplicates - """ - for n in nodes: - self.startup_nodes[n.name] = n - - def cluster_require_full_coverage(self, cluster_nodes): - """ - if exists 'cluster-require-full-coverage no' config on redis servers, - then even all slots are not covered, cluster still will be able to - respond - """ - - def node_require_full_coverage(node): - try: - return ( - "yes" - in node.redis_connection.config_get( - "cluster-require-full-coverage" - ).values() - ) - except ConnectionError: - return False - except Exception as e: - raise RedisClusterException( - 'ERROR sending "config get cluster-require-full-coverage"' - f" command to redis server: {node.name}, {e}" - ) - - # at least one node should have cluster-require-full-coverage yes - return any(node_require_full_coverage(node) for node in cluster_nodes.values()) - - def check_slots_coverage(self, slots_cache): - # Validate if all slots are covered or if we should try next - # startup node - for i in range(0, REDIS_CLUSTER_HASH_SLOTS): - if i not in slots_cache: - return False - return True - - def create_redis_connections(self, nodes): - """ - This function will create a redis connection to all nodes in :nodes: - """ - for node in nodes: - if node.redis_connection is None: - node.redis_connection = self.create_redis_node( - host=node.host, - port=node.port, - **self.connection_kwargs, - ) - - def create_redis_node(self, host, port, **kwargs): - if self.from_url: - # Create a redis node with a costumed connection pool - kwargs.update({"host": host}) - kwargs.update({"port": port}) - r = Redis(connection_pool=ConnectionPool(**kwargs)) - else: - r = Redis(host=host, port=port, **kwargs) - return r - - def initialize(self): - """ - Initializes the nodes cache, slots cache and redis connections. - :startup_nodes: - Responsible for discovering other nodes in the cluster - """ - log.debug("Initializing the nodes' topology of the cluster") - self.reset() - tmp_nodes_cache = {} - tmp_slots = {} - disagreements = [] - startup_nodes_reachable = False - fully_covered = False - kwargs = self.connection_kwargs - for startup_node in self.startup_nodes.values(): - try: - if startup_node.redis_connection: - r = startup_node.redis_connection - else: - # Create a new Redis connection and let Redis decode the - # responses so we won't need to handle that - copy_kwargs = copy.deepcopy(kwargs) - copy_kwargs.update({"decode_responses": True, "encoding": "utf-8"}) - r = self.create_redis_node( - startup_node.host, startup_node.port, **copy_kwargs - ) - self.startup_nodes[startup_node.name].redis_connection = r - # Make sure cluster mode is enabled on this node - if bool(r.info().get("cluster_enabled")) is False: - raise RedisClusterException( - "Cluster mode is not enabled on this node" - ) - cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) - startup_nodes_reachable = True - except (ConnectionError, TimeoutError) as e: - msg = e.__str__ - log.exception( - "An exception occurred while trying to" - " initialize the cluster using the seed node" - f" {startup_node.name}:\n{msg}" - ) - continue - except ResponseError as e: - log.exception('ReseponseError sending "cluster slots" to redis server') - - # Isn't a cluster connection, so it won't parse these - # exceptions automatically - message = e.__str__() - if "CLUSTERDOWN" in message or "MASTERDOWN" in message: - continue - else: - raise RedisClusterException( - 'ERROR sending "cluster slots" command to redis ' - f"server: {startup_node}. error: {message}" - ) - except Exception as e: - message = e.__str__() - raise RedisClusterException( - 'ERROR sending "cluster slots" command to redis ' - f"server {startup_node.name}. error: {message}" - ) - - # CLUSTER SLOTS command results in the following output: - # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] - # where each node contains the following list: [IP, port, node_id] - # Therefore, cluster_slots[0][2][0] will be the IP address of the - # primary node of the first slot section. - # If there's only one server in the cluster, its ``host`` is '' - # Fix it to the host in startup_nodes - if ( - len(cluster_slots) == 1 - and len(cluster_slots[0][2][0]) == 0 - and len(self.startup_nodes) == 1 - ): - cluster_slots[0][2][0] = startup_node.host - - for slot in cluster_slots: - primary_node = slot[2] - host = primary_node[0] - if host == "": - host = startup_node.host - port = int(primary_node[1]) - - target_node = tmp_nodes_cache.get(get_node_name(host, port)) - if target_node is None: - target_node = ClusterNode(host, port, PRIMARY) - # add this node to the nodes cache - tmp_nodes_cache[target_node.name] = target_node - - for i in range(int(slot[0]), int(slot[1]) + 1): - if i not in tmp_slots: - tmp_slots[i] = [] - tmp_slots[i].append(target_node) - replica_nodes = [slot[j] for j in range(3, len(slot))] - - for replica_node in replica_nodes: - host = replica_node[0] - port = replica_node[1] - - target_replica_node = tmp_nodes_cache.get( - get_node_name(host, port) - ) - if target_replica_node is None: - target_replica_node = ClusterNode(host, port, REPLICA) - tmp_slots[i].append(target_replica_node) - # add this node to the nodes cache - tmp_nodes_cache[ - target_replica_node.name - ] = target_replica_node - else: - # Validate that 2 nodes want to use the same slot cache - # setup - tmp_slot = tmp_slots[i][0] - if tmp_slot.name != target_node.name: - disagreements.append( - f"{tmp_slot.name} vs {target_node.name} on slot: {i}" - ) - - if len(disagreements) > 5: - raise RedisClusterException( - f"startup_nodes could not agree on a valid " - f'slots cache: {", ".join(disagreements)}' - ) - - fully_covered = self.check_slots_coverage(tmp_slots) - if fully_covered: - # Don't need to continue to the next startup node if all - # slots are covered - break - - if not startup_nodes_reachable: - raise RedisClusterException( - "Redis Cluster cannot be connected. Please provide at least " - "one reachable node. " - ) - - # Create Redis connections to all nodes - self.create_redis_connections(list(tmp_nodes_cache.values())) - - # Check if the slots are not fully covered - if not fully_covered and self._require_full_coverage: - # Despite the requirement that the slots be covered, there - # isn't a full coverage - raise RedisClusterException( - f"All slots are not covered after query all startup_nodes. " - f"{len(self.slots_cache)} of {REDIS_CLUSTER_HASH_SLOTS} " - f"covered..." - ) - elif not fully_covered and not self._require_full_coverage: - # The user set require_full_coverage to False. - # In case of full coverage requirement in the cluster's Redis - # configurations, we will raise an exception. Otherwise, we may - # continue with partial coverage. - # see Redis Cluster configuration parameters in - # https://redis.io/topics/cluster-tutorial - if ( - not self._skip_full_coverage_check - and self.cluster_require_full_coverage(tmp_nodes_cache) - ): - raise RedisClusterException( - "Not all slots are covered but the cluster's " - "configuration requires full coverage. Set " - "cluster-require-full-coverage configuration to no on " - "all of the cluster nodes if you wish the cluster to " - "be able to serve without being fully covered." - f"{len(self.slots_cache)} of {REDIS_CLUSTER_HASH_SLOTS} " - f"covered..." - ) - - # Set the tmp variables to the real variables - self.nodes_cache = tmp_nodes_cache - self.slots_cache = tmp_slots - # Set the default node - self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] - # Populate the startup nodes with all discovered nodes - self.populate_startup_nodes(self.nodes_cache.values()) - # If initialize was called after a MovedError, clear it - self._moved_exception = None - - def close(self): - self.default_node = None - for node in self.nodes_cache.values(): - if node.redis_connection: - node.redis_connection.close() - - def reset(self): - try: - self.read_load_balancer.reset() - except TypeError: - # The read_load_balancer is None, do nothing - pass - - -class ClusterPubSub(PubSub): - """ - Wrapper for PubSub class. - - IMPORTANT: before using ClusterPubSub, read about the known limitations - with pubsub in Cluster mode and learn how to workaround them: - https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html - """ - - def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): - """ - When a pubsub instance is created without specifying a node, a single - node will be transparently chosen for the pubsub connection on the - first command execution. The node will be determined by: - 1. Hashing the channel name in the request to find its keyslot - 2. Selecting a node that handles the keyslot: If read_from_replicas is - set to true, a replica can be selected. - - :type redis_cluster: RedisCluster - :type node: ClusterNode - :type host: str - :type port: int - """ - log.info("Creating new instance of ClusterPubSub") - self.node = None - self.set_pubsub_node(redis_cluster, node, host, port) - connection_pool = ( - None - if self.node is None - else redis_cluster.get_redis_connection(self.node).connection_pool - ) - self.cluster = redis_cluster - super().__init__( - **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder - ) - - def set_pubsub_node(self, cluster, node=None, host=None, port=None): - """ - The pubsub node will be set according to the passed node, host and port - When none of the node, host, or port are specified - the node is set - to None and will be determined by the keyslot of the channel in the - first command to be executed. - RedisClusterException will be thrown if the passed node does not exist - in the cluster. - If host is passed without port, or vice versa, a DataError will be - thrown. - :type cluster: RedisCluster - :type node: ClusterNode - :type host: str - :type port: int - """ - if node is not None: - # node is passed by the user - self._raise_on_invalid_node(cluster, node, node.host, node.port) - pubsub_node = node - elif host is not None and port is not None: - # host and port passed by the user - node = cluster.get_node(host=host, port=port) - self._raise_on_invalid_node(cluster, node, host, port) - pubsub_node = node - elif any([host, port]) is True: - # only 'host' or 'port' passed - raise DataError("Passing a host requires passing a port, " "and vice versa") - else: - # nothing passed by the user. set node to None - pubsub_node = None - - self.node = pubsub_node - - def get_pubsub_node(self): - """ - Get the node that is being used as the pubsub connection - """ - return self.node - - def _raise_on_invalid_node(self, redis_cluster, node, host, port): - """ - Raise a RedisClusterException if the node is None or doesn't exist in - the cluster. - """ - if node is None or redis_cluster.get_node(node_name=node.name) is None: - raise RedisClusterException( - f"Node {host}:{port} doesn't exist in the cluster" - ) - - def execute_command(self, *args, **kwargs): - """ - Execute a publish/subscribe command. - - Taken code from redis-py and tweak to make it work within a cluster. - """ - # NOTE: don't parse the response in this function -- it could pull a - # legitimate message off the stack if the connection is already - # subscribed to one or more channels - - if self.connection is None: - if self.connection_pool is None: - if len(args) > 1: - # Hash the first channel and get one of the nodes holding - # this slot - channel = args[1] - slot = self.cluster.keyslot(channel) - node = self.cluster.nodes_manager.get_node_from_slot( - slot, self.cluster.read_from_replicas - ) - else: - # Get a random node - node = self.cluster.get_random_node() - self.node = node - redis_connection = self.cluster.get_redis_connection(node) - self.connection_pool = redis_connection.connection_pool - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) - # register a callback that re-subscribes to any channels we - # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) - connection = self.connection - self._execute(connection, connection.send_command, *args) - - def get_redis_connection(self): - """ - Get the Redis connection of the pubsub connected node. - """ - if self.node is not None: - return self.node.redis_connection - - -class ClusterPipeline(RedisCluster): - """ - Support for Redis pipeline - in cluster mode - """ - - ERRORS_ALLOW_RETRY = ( - ConnectionError, - TimeoutError, - MovedError, - AskError, - TryAgainError, - ) - - def __init__( - self, - nodes_manager, - result_callbacks=None, - cluster_response_callbacks=None, - startup_nodes=None, - read_from_replicas=False, - cluster_error_retry_attempts=5, - reinitialize_steps=10, - **kwargs, - ): - """ """ - log.info("Creating new instance of ClusterPipeline") - self.command_stack = [] - self.nodes_manager = nodes_manager - self.refresh_table_asap = False - self.result_callbacks = ( - result_callbacks or self.__class__.RESULT_CALLBACKS.copy() - ) - self.startup_nodes = startup_nodes if startup_nodes else [] - self.read_from_replicas = read_from_replicas - self.command_flags = self.__class__.COMMAND_FLAGS.copy() - self.cluster_response_callbacks = cluster_response_callbacks - self.cluster_error_retry_attempts = cluster_error_retry_attempts - self.reinitialize_counter = 0 - self.reinitialize_steps = reinitialize_steps - self.encoder = Encoder( - kwargs.get("encoding", "utf-8"), - kwargs.get("encoding_errors", "strict"), - kwargs.get("decode_responses", False), - ) - - # The commands parser refers to the parent - # so that we don't push the COMMAND command - # onto the stack - self.commands_parser = CommandsParser(super()) - - def __repr__(self): - """ """ - return f"{type(self).__name__}" - - def __enter__(self): - """ """ - return self - - def __exit__(self, exc_type, exc_value, traceback): - """ """ - self.reset() - - def __del__(self): - try: - self.reset() - except Exception: - pass - - def __len__(self): - """ """ - return len(self.command_stack) - - def __nonzero__(self): - "Pipeline instances should always evaluate to True on Python 2.7" - return True - - def __bool__(self): - "Pipeline instances should always evaluate to True on Python 3+" - return True - - def execute_command(self, *args, **kwargs): - """ - Wrapper function for pipeline_execute_command - """ - return self.pipeline_execute_command(*args, **kwargs) - - def pipeline_execute_command(self, *args, **options): - """ - Appends the executed command to the pipeline's command stack - """ - self.command_stack.append( - PipelineCommand(args, options, len(self.command_stack)) - ) - return self - - def raise_first_error(self, stack): - """ - Raise the first exception on the stack - """ - for c in stack: - r = c.result - if isinstance(r, Exception): - self.annotate_exception(r, c.position + 1, c.args) - raise r - - def annotate_exception(self, exception, number, command): - """ - Provides extra context to the exception prior to it being handled - """ - cmd = " ".join(map(safe_str, command)) - msg = ( - f"Command # {number} ({cmd}) of pipeline " - f"caused error: {exception.args[0]}" - ) - exception.args = (msg,) + exception.args[1:] - - def execute(self, raise_on_error=True): - """ - Execute all the commands in the current pipeline - """ - stack = self.command_stack - try: - return self.send_cluster_commands(stack, raise_on_error) - finally: - self.reset() - - def reset(self): - """ - Reset back to empty pipeline. - """ - self.command_stack = [] - - self.scripts = set() - - # TODO: Implement - # make sure to reset the connection state in the event that we were - # watching something - # if self.watching and self.connection: - # try: - # # call this manually since our unwatch or - # # immediate_execute_command methods can call reset() - # self.connection.send_command('UNWATCH') - # self.connection.read_response() - # except ConnectionError: - # # disconnect will also remove any previous WATCHes - # self.connection.disconnect() - - # clean up the other instance attributes - self.watching = False - self.explicit_transaction = False - - # TODO: Implement - # we can safely return the connection to the pool here since we're - # sure we're no longer WATCHing anything - # if self.connection: - # self.connection_pool.release(self.connection) - # self.connection = None - - def send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - """ - Wrapper for CLUSTERDOWN error handling. - - If the cluster reports it is down it is assumed that: - - connection_pool was disconnected - - connection_pool was reseted - - refereh_table_asap set to True - - It will try the number of times specified by - the config option "self.cluster_error_retry_attempts" - which defaults to 3 unless manually configured. - - If it reaches the number of times, the command will - raises ClusterDownException. - """ - if not stack: - return [] - - for _ in range(0, self.cluster_error_retry_attempts): - try: - return self._send_cluster_commands( - stack, - raise_on_error=raise_on_error, - allow_redirections=allow_redirections, - ) - except ClusterDownError: - # Try again with the new cluster setup. All other errors - # should be raised. - pass - - # If it fails the configured number of times then raise - # exception back to caller of this method - raise ClusterDownError("CLUSTERDOWN error. Unable to rebuild the cluster") - - def _send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - """ - Send a bunch of cluster commands to the redis cluster. - - `allow_redirections` If the pipeline should follow - `ASK` & `MOVED` responses automatically. If set - to false it will raise RedisClusterException. - """ - # the first time sending the commands we send all of - # the commands that were queued up. - # if we have to run through it again, we only retry - # the commands that failed. - attempt = sorted(stack, key=lambda x: x.position) - - # build a list of node objects based on node names we need to - nodes = {} - - # as we move through each command that still needs to be processed, - # we figure out the slot number that command maps to, then from - # the slot determine the node. - for c in attempt: - # refer to our internal node -> slot table that - # tells us where a given - # command should route to. - slot = self.determine_slot(*c.args) - node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and c.args[0] in READ_COMMANDS - ) - - # now that we know the name of the node - # ( it's just a string in the form of host:port ) - # we can build a list of commands for each node. - node_name = node.name - if node_name not in nodes: - redis_node = self.get_redis_connection(node) - connection = get_connection(redis_node, c.args) - nodes[node_name] = NodeCommands( - redis_node.parse_response, redis_node.connection_pool, connection - ) - - nodes[node_name].append(c) - - # send the commands in sequence. - # we write to all the open sockets for each node first, - # before reading anything - # this allows us to flush all the requests out across the - # network essentially in parallel - # so that we can read them all in parallel as they come back. - # we dont' multiplex on the sockets as they come available, - # but that shouldn't make too much difference. - node_commands = nodes.values() - for n in node_commands: - n.write() - - for n in node_commands: - n.read() - - # release all of the redis connections we allocated earlier - # back into the connection pool. - # we used to do this step as part of a try/finally block, - # but it is really dangerous to - # release connections back into the pool if for some - # reason the socket has data still left in it - # from a previous operation. The write and - # read operations already have try/catch around them for - # all known types of errors including connection - # and socket level errors. - # So if we hit an exception, something really bad - # happened and putting any oF - # these connections back into the pool is a very bad idea. - # the socket might have unread buffer still sitting in it, - # and then the next time we read from it we pass the - # buffered result back from a previous command and - # every single request after to that connection will always get - # a mismatched result. - for n in nodes.values(): - n.connection_pool.release(n.connection) - - # if the response isn't an exception it is a - # valid response from the node - # we're all done with that command, YAY! - # if we have more commands to attempt, we've run into problems. - # collect all the commands we are allowed to retry. - # (MOVED, ASK, or connection errors or timeout errors) - attempt = sorted( - ( - c - for c in attempt - if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY) - ), - key=lambda x: x.position, - ) - if attempt and allow_redirections: - # RETRY MAGIC HAPPENS HERE! - # send these remaing comamnds one at a time using `execute_command` - # in the main client. This keeps our retry logic - # in one place mostly, - # and allows us to be more confident in correctness of behavior. - # at this point any speed gains from pipelining have been lost - # anyway, so we might as well make the best - # attempt to get the correct behavior. - # - # The client command will handle retries for each - # individual command sequentially as we pass each - # one into `execute_command`. Any exceptions - # that bubble out should only appear once all - # retries have been exhausted. - # - # If a lot of commands have failed, we'll be setting the - # flag to rebuild the slots table from scratch. - # So MOVED errors should correct themselves fairly quickly. - log.exception( - f"An exception occurred during pipeline execution. " - f"args: {attempt[-1].args}, " - f"error: {type(attempt[-1].result).__name__} " - f"{str(attempt[-1].result)}" - ) - self.reinitialize_counter += 1 - if self._should_reinitialized(): - self.nodes_manager.initialize() - for c in attempt: - try: - # send each command individually like we - # do in the main client. - c.result = super().execute_command(*c.args, **c.options) - except RedisError as e: - c.result = e - - # turn the response back into a simple flat array that corresponds - # to the sequence of commands issued in the stack in pipeline.execute() - response = [c.result for c in sorted(stack, key=lambda x: x.position)] - - if raise_on_error: - self.raise_first_error(stack) - - return response - - def _fail_on_redirect(self, allow_redirections): - """ """ - if not allow_redirections: - raise RedisClusterException( - "ASK & MOVED redirection not allowed in this pipeline" - ) - - def eval(self): - """ """ - raise RedisClusterException("method eval() is not implemented") - - def multi(self): - """ """ - raise RedisClusterException("method multi() is not implemented") - - def immediate_execute_command(self, *args, **options): - """ """ - raise RedisClusterException( - "method immediate_execute_command() is not implemented" - ) - - def _execute_transaction(self, *args, **kwargs): - """ """ - raise RedisClusterException("method _execute_transaction() is not implemented") - - def load_scripts(self): - """ """ - raise RedisClusterException("method load_scripts() is not implemented") - - def watch(self, *names): - """ """ - raise RedisClusterException("method watch() is not implemented") - - def unwatch(self): - """ """ - raise RedisClusterException("method unwatch() is not implemented") - - def script_load_for_pipeline(self, *args, **kwargs): - """ """ - raise RedisClusterException( - "method script_load_for_pipeline() is not implemented" - ) - - def delete(self, *names): - """ - "Delete a key specified by ``names``" - """ - if len(names) != 1: - raise RedisClusterException( - "deleting multiple keys is not " "implemented in pipeline command" - ) - - return self.execute_command("DEL", names[0]) - - -def block_pipeline_command(func): - """ - Prints error because some pipelined commands should - be blocked when running in cluster-mode - """ - - def inner(*args, **kwargs): - raise RedisClusterException( - f"ERROR: Calling pipelined function {func.__name__} is blocked " - f"when running redis in cluster mode..." - ) - - return inner - - -# Blocked pipeline commands -ClusterPipeline.bitop = block_pipeline_command(RedisCluster.bitop) -ClusterPipeline.brpoplpush = block_pipeline_command(RedisCluster.brpoplpush) -ClusterPipeline.client_getname = block_pipeline_command(RedisCluster.client_getname) -ClusterPipeline.client_list = block_pipeline_command(RedisCluster.client_list) -ClusterPipeline.client_setname = block_pipeline_command(RedisCluster.client_setname) -ClusterPipeline.config_set = block_pipeline_command(RedisCluster.config_set) -ClusterPipeline.dbsize = block_pipeline_command(RedisCluster.dbsize) -ClusterPipeline.flushall = block_pipeline_command(RedisCluster.flushall) -ClusterPipeline.flushdb = block_pipeline_command(RedisCluster.flushdb) -ClusterPipeline.keys = block_pipeline_command(RedisCluster.keys) -ClusterPipeline.mget = block_pipeline_command(RedisCluster.mget) -ClusterPipeline.move = block_pipeline_command(RedisCluster.move) -ClusterPipeline.mset = block_pipeline_command(RedisCluster.mset) -ClusterPipeline.msetnx = block_pipeline_command(RedisCluster.msetnx) -ClusterPipeline.pfmerge = block_pipeline_command(RedisCluster.pfmerge) -ClusterPipeline.pfcount = block_pipeline_command(RedisCluster.pfcount) -ClusterPipeline.ping = block_pipeline_command(RedisCluster.ping) -ClusterPipeline.publish = block_pipeline_command(RedisCluster.publish) -ClusterPipeline.randomkey = block_pipeline_command(RedisCluster.randomkey) -ClusterPipeline.rename = block_pipeline_command(RedisCluster.rename) -ClusterPipeline.renamenx = block_pipeline_command(RedisCluster.renamenx) -ClusterPipeline.rpoplpush = block_pipeline_command(RedisCluster.rpoplpush) -ClusterPipeline.scan = block_pipeline_command(RedisCluster.scan) -ClusterPipeline.sdiff = block_pipeline_command(RedisCluster.sdiff) -ClusterPipeline.sdiffstore = block_pipeline_command(RedisCluster.sdiffstore) -ClusterPipeline.sinter = block_pipeline_command(RedisCluster.sinter) -ClusterPipeline.sinterstore = block_pipeline_command(RedisCluster.sinterstore) -ClusterPipeline.smove = block_pipeline_command(RedisCluster.smove) -ClusterPipeline.sort = block_pipeline_command(RedisCluster.sort) -ClusterPipeline.sunion = block_pipeline_command(RedisCluster.sunion) -ClusterPipeline.sunionstore = block_pipeline_command(RedisCluster.sunionstore) -ClusterPipeline.readwrite = block_pipeline_command(RedisCluster.readwrite) -ClusterPipeline.readonly = block_pipeline_command(RedisCluster.readonly) - - -class PipelineCommand: - """ """ - - def __init__(self, args, options=None, position=None): - self.args = args - if options is None: - options = {} - self.options = options - self.position = position - self.result = None - self.node = None - self.asking = False - - -class NodeCommands: - """ """ - - def __init__(self, parse_response, connection_pool, connection): - """ """ - self.parse_response = parse_response - self.connection_pool = connection_pool - self.connection = connection - self.commands = [] - - def append(self, c): - """ """ - self.commands.append(c) - - def write(self): - """ - Code borrowed from Redis so it can be fixed - """ - connection = self.connection - commands = self.commands - - # We are going to clobber the commands with the write, so go ahead - # and ensure that nothing is sitting there from a previous run. - for c in commands: - c.result = None - - # build up all commands into a single request to increase network perf - # send all the commands and catch connection and timeout errors. - try: - connection.send_packed_command( - connection.pack_commands([c.args for c in commands]) - ) - except (ConnectionError, TimeoutError) as e: - for c in commands: - c.result = e - - def read(self): - """ """ - connection = self.connection - for c in self.commands: - - # if there is a result on this command, - # it means we ran into an exception - # like a connection error. Trying to parse - # a response on a connection that - # is no longer open will result in a - # connection error raised by redis-py. - # but redis-py doesn't check in parse_response - # that the sock object is - # still set and if you try to - # read from a closed connection, it will - # result in an AttributeError because - # it will do a readline() call on None. - # This can have all kinds of nasty side-effects. - # Treating this case as a connection error - # is fine because it will dump - # the connection object back into the - # pool and on the next write, it will - # explicitly open the connection and all will be well. - if c.result is None: - try: - c.result = self.parse_response(connection, c.args[0], **c.options) - except (ConnectionError, TimeoutError) as e: - for c in self.commands: - c.result = e - return - except RedisError: - c.result = sys.exc_info()[1] diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index 07fa7f1..d6fea59 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,15 +1,11 @@ -from .cluster import RedisClusterCommands from .core import CoreCommands from .helpers import list_or_args from .parser import CommandsParser -from .redismodules import RedisModuleCommands -from .sentinel import SentinelCommands __all__ = [ - "RedisClusterCommands", "CommandsParser", "CoreCommands", "list_or_args", - "RedisModuleCommands", - "SentinelCommands", + 'CoreCommands', + 'list_or_args' ] diff --git a/redis/commands/bf/__init__.py b/redis/commands/bf/__init__.py deleted file mode 100644 index f34e11d..0000000 --- a/redis/commands/bf/__init__.py +++ /dev/null @@ -1,204 +0,0 @@ -from redis.client import bool_ok - -from ..helpers import parse_to_list -from .commands import * # noqa -from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo - - -class AbstractBloom(object): - """ - The client allows to interact with RedisBloom and use all of - it's functionality. - - - BF for Bloom Filter - - CF for Cuckoo Filter - - CMS for Count-Min Sketch - - TOPK for TopK Data Structure - - TDIGEST for estimate rank statistics - """ - - @staticmethod - def appendItems(params, items): - """Append ITEMS to params.""" - params.extend(["ITEMS"]) - params += items - - @staticmethod - def appendError(params, error): - """Append ERROR to params.""" - if error is not None: - params.extend(["ERROR", error]) - - @staticmethod - def appendCapacity(params, capacity): - """Append CAPACITY to params.""" - if capacity is not None: - params.extend(["CAPACITY", capacity]) - - @staticmethod - def appendExpansion(params, expansion): - """Append EXPANSION to params.""" - if expansion is not None: - params.extend(["EXPANSION", expansion]) - - @staticmethod - def appendNoScale(params, noScale): - """Append NONSCALING tag to params.""" - if noScale is not None: - params.extend(["NONSCALING"]) - - @staticmethod - def appendWeights(params, weights): - """Append WEIGHTS to params.""" - if len(weights) > 0: - params.append("WEIGHTS") - params += weights - - @staticmethod - def appendNoCreate(params, noCreate): - """Append NOCREATE tag to params.""" - if noCreate is not None: - params.extend(["NOCREATE"]) - - @staticmethod - def appendItemsAndIncrements(params, items, increments): - """Append pairs of items and increments to params.""" - for i in range(len(items)): - params.append(items[i]) - params.append(increments[i]) - - @staticmethod - def appendValuesAndWeights(params, items, weights): - """Append pairs of items and weights to params.""" - for i in range(len(items)): - params.append(items[i]) - params.append(weights[i]) - - @staticmethod - def appendMaxIterations(params, max_iterations): - """Append MAXITERATIONS to params.""" - if max_iterations is not None: - params.extend(["MAXITERATIONS", max_iterations]) - - @staticmethod - def appendBucketSize(params, bucket_size): - """Append BUCKETSIZE to params.""" - if bucket_size is not None: - params.extend(["BUCKETSIZE", bucket_size]) - - -class CMSBloom(CMSCommands, AbstractBloom): - def __init__(self, client, **kwargs): - """Create a new RedisBloom client.""" - # Set the module commands' callbacks - MODULE_CALLBACKS = { - CMS_INITBYDIM: bool_ok, - CMS_INITBYPROB: bool_ok, - # CMS_INCRBY: spaceHolder, - # CMS_QUERY: spaceHolder, - CMS_MERGE: bool_ok, - CMS_INFO: CMSInfo, - } - - self.client = client - self.commandmixin = CMSCommands - self.execute_command = client.execute_command - - for k, v in MODULE_CALLBACKS.items(): - self.client.set_response_callback(k, v) - - -class TOPKBloom(TOPKCommands, AbstractBloom): - def __init__(self, client, **kwargs): - """Create a new RedisBloom client.""" - # Set the module commands' callbacks - MODULE_CALLBACKS = { - TOPK_RESERVE: bool_ok, - TOPK_ADD: parse_to_list, - TOPK_INCRBY: parse_to_list, - # TOPK_QUERY: spaceHolder, - # TOPK_COUNT: spaceHolder, - TOPK_LIST: parse_to_list, - TOPK_INFO: TopKInfo, - } - - self.client = client - self.commandmixin = TOPKCommands - self.execute_command = client.execute_command - - for k, v in MODULE_CALLBACKS.items(): - self.client.set_response_callback(k, v) - - -class CFBloom(CFCommands, AbstractBloom): - def __init__(self, client, **kwargs): - """Create a new RedisBloom client.""" - # Set the module commands' callbacks - MODULE_CALLBACKS = { - CF_RESERVE: bool_ok, - # CF_ADD: spaceHolder, - # CF_ADDNX: spaceHolder, - # CF_INSERT: spaceHolder, - # CF_INSERTNX: spaceHolder, - # CF_EXISTS: spaceHolder, - # CF_DEL: spaceHolder, - # CF_COUNT: spaceHolder, - # CF_SCANDUMP: spaceHolder, - # CF_LOADCHUNK: spaceHolder, - CF_INFO: CFInfo, - } - - self.client = client - self.commandmixin = CFCommands - self.execute_command = client.execute_command - - for k, v in MODULE_CALLBACKS.items(): - self.client.set_response_callback(k, v) - - -class TDigestBloom(TDigestCommands, AbstractBloom): - def __init__(self, client, **kwargs): - """Create a new RedisBloom client.""" - # Set the module commands' callbacks - MODULE_CALLBACKS = { - TDIGEST_CREATE: bool_ok, - # TDIGEST_RESET: bool_ok, - # TDIGEST_ADD: spaceHolder, - # TDIGEST_MERGE: spaceHolder, - TDIGEST_CDF: float, - TDIGEST_QUANTILE: float, - TDIGEST_MIN: float, - TDIGEST_MAX: float, - TDIGEST_INFO: TDigestInfo, - } - - self.client = client - self.commandmixin = TDigestCommands - self.execute_command = client.execute_command - - for k, v in MODULE_CALLBACKS.items(): - self.client.set_response_callback(k, v) - - -class BFBloom(BFCommands, AbstractBloom): - def __init__(self, client, **kwargs): - """Create a new RedisBloom client.""" - # Set the module commands' callbacks - MODULE_CALLBACKS = { - BF_RESERVE: bool_ok, - # BF_ADD: spaceHolder, - # BF_MADD: spaceHolder, - # BF_INSERT: spaceHolder, - # BF_EXISTS: spaceHolder, - # BF_MEXISTS: spaceHolder, - # BF_SCANDUMP: spaceHolder, - # BF_LOADCHUNK: spaceHolder, - BF_INFO: BFInfo, - } - - self.client = client - self.commandmixin = BFCommands - self.execute_command = client.execute_command - - for k, v in MODULE_CALLBACKS.items(): - self.client.set_response_callback(k, v) diff --git a/redis/commands/bf/commands.py b/redis/commands/bf/commands.py deleted file mode 100644 index 7fc507d..0000000 --- a/redis/commands/bf/commands.py +++ /dev/null @@ -1,498 +0,0 @@ -from redis.client import NEVER_DECODE -from redis.exceptions import ModuleError -from redis.utils import HIREDIS_AVAILABLE - -BF_RESERVE = "BF.RESERVE" -BF_ADD = "BF.ADD" -BF_MADD = "BF.MADD" -BF_INSERT = "BF.INSERT" -BF_EXISTS = "BF.EXISTS" -BF_MEXISTS = "BF.MEXISTS" -BF_SCANDUMP = "BF.SCANDUMP" -BF_LOADCHUNK = "BF.LOADCHUNK" -BF_INFO = "BF.INFO" - -CF_RESERVE = "CF.RESERVE" -CF_ADD = "CF.ADD" -CF_ADDNX = "CF.ADDNX" -CF_INSERT = "CF.INSERT" -CF_INSERTNX = "CF.INSERTNX" -CF_EXISTS = "CF.EXISTS" -CF_DEL = "CF.DEL" -CF_COUNT = "CF.COUNT" -CF_SCANDUMP = "CF.SCANDUMP" -CF_LOADCHUNK = "CF.LOADCHUNK" -CF_INFO = "CF.INFO" - -CMS_INITBYDIM = "CMS.INITBYDIM" -CMS_INITBYPROB = "CMS.INITBYPROB" -CMS_INCRBY = "CMS.INCRBY" -CMS_QUERY = "CMS.QUERY" -CMS_MERGE = "CMS.MERGE" -CMS_INFO = "CMS.INFO" - -TOPK_RESERVE = "TOPK.RESERVE" -TOPK_ADD = "TOPK.ADD" -TOPK_INCRBY = "TOPK.INCRBY" -TOPK_QUERY = "TOPK.QUERY" -TOPK_COUNT = "TOPK.COUNT" -TOPK_LIST = "TOPK.LIST" -TOPK_INFO = "TOPK.INFO" - -TDIGEST_CREATE = "TDIGEST.CREATE" -TDIGEST_RESET = "TDIGEST.RESET" -TDIGEST_ADD = "TDIGEST.ADD" -TDIGEST_MERGE = "TDIGEST.MERGE" -TDIGEST_CDF = "TDIGEST.CDF" -TDIGEST_QUANTILE = "TDIGEST.QUANTILE" -TDIGEST_MIN = "TDIGEST.MIN" -TDIGEST_MAX = "TDIGEST.MAX" -TDIGEST_INFO = "TDIGEST.INFO" - - -class BFCommands: - """Bloom Filter commands.""" - - # region Bloom Filter Functions - def create(self, key, errorRate, capacity, expansion=None, noScale=None): - """ - Create a new Bloom Filter `key` with desired probability of false positives - `errorRate` expected entries to be inserted as `capacity`. - Default expansion value is 2. By default, filter is auto-scaling. - For more information see `BF.RESERVE `_. - """ # noqa - params = [key, errorRate, capacity] - self.appendExpansion(params, expansion) - self.appendNoScale(params, noScale) - return self.execute_command(BF_RESERVE, *params) - - def add(self, key, item): - """ - Add to a Bloom Filter `key` an `item`. - For more information see `BF.ADD `_. - """ # noqa - params = [key, item] - return self.execute_command(BF_ADD, *params) - - def madd(self, key, *items): - """ - Add to a Bloom Filter `key` multiple `items`. - For more information see `BF.MADD `_. - """ # noqa - params = [key] - params += items - return self.execute_command(BF_MADD, *params) - - def insert( - self, - key, - items, - capacity=None, - error=None, - noCreate=None, - expansion=None, - noScale=None, - ): - """ - Add to a Bloom Filter `key` multiple `items`. - - If `nocreate` remain `None` and `key` does not exist, a new Bloom Filter - `key` will be created with desired probability of false positives `errorRate` - and expected entries to be inserted as `size`. - For more information see `BF.INSERT `_. - """ # noqa - params = [key] - self.appendCapacity(params, capacity) - self.appendError(params, error) - self.appendExpansion(params, expansion) - self.appendNoCreate(params, noCreate) - self.appendNoScale(params, noScale) - self.appendItems(params, items) - - return self.execute_command(BF_INSERT, *params) - - def exists(self, key, item): - """ - Check whether an `item` exists in Bloom Filter `key`. - For more information see `BF.EXISTS `_. - """ # noqa - params = [key, item] - return self.execute_command(BF_EXISTS, *params) - - def mexists(self, key, *items): - """ - Check whether `items` exist in Bloom Filter `key`. - For more information see `BF.MEXISTS `_. - """ # noqa - params = [key] - params += items - return self.execute_command(BF_MEXISTS, *params) - - def scandump(self, key, iter): - """ - Begin an incremental save of the bloom filter `key`. - - This is useful for large bloom filters which cannot fit into the normal SAVE and RESTORE model. - The first time this command is called, the value of `iter` should be 0. - This command will return successive (iter, data) pairs until (0, NULL) to indicate completion. - For more information see `BF.SCANDUMP `_. - """ # noqa - if HIREDIS_AVAILABLE: - raise ModuleError("This command cannot be used when hiredis is available.") - - params = [key, iter] - options = {} - options[NEVER_DECODE] = [] - return self.execute_command(BF_SCANDUMP, *params, **options) - - def loadchunk(self, key, iter, data): - """ - Restore a filter previously saved using SCANDUMP. - - See the SCANDUMP command for example usage. - This command will overwrite any bloom filter stored under key. - Ensure that the bloom filter will not be modified between invocations. - For more information see `BF.LOADCHUNK `_. - """ # noqa - params = [key, iter, data] - return self.execute_command(BF_LOADCHUNK, *params) - - def info(self, key): - """ - Return capacity, size, number of filters, number of items inserted, and expansion rate. - For more information see `BF.INFO `_. - """ # noqa - return self.execute_command(BF_INFO, key) - - -class CFCommands: - """Cuckoo Filter commands.""" - - # region Cuckoo Filter Functions - def create( - self, key, capacity, expansion=None, bucket_size=None, max_iterations=None - ): - """ - Create a new Cuckoo Filter `key` an initial `capacity` items. - For more information see `CF.RESERVE `_. - """ # noqa - params = [key, capacity] - self.appendExpansion(params, expansion) - self.appendBucketSize(params, bucket_size) - self.appendMaxIterations(params, max_iterations) - return self.execute_command(CF_RESERVE, *params) - - def add(self, key, item): - """ - Add an `item` to a Cuckoo Filter `key`. - For more information see `CF.ADD `_. - """ # noqa - params = [key, item] - return self.execute_command(CF_ADD, *params) - - def addnx(self, key, item): - """ - Add an `item` to a Cuckoo Filter `key` only if item does not yet exist. - Command might be slower that `add`. - For more information see `CF.ADDNX `_. - """ # noqa - params = [key, item] - return self.execute_command(CF_ADDNX, *params) - - def insert(self, key, items, capacity=None, nocreate=None): - """ - Add multiple `items` to a Cuckoo Filter `key`, allowing the filter - to be created with a custom `capacity` if it does not yet exist. - `items` must be provided as a list. - For more information see `CF.INSERT `_. - """ # noqa - params = [key] - self.appendCapacity(params, capacity) - self.appendNoCreate(params, nocreate) - self.appendItems(params, items) - return self.execute_command(CF_INSERT, *params) - - def insertnx(self, key, items, capacity=None, nocreate=None): - """ - Add multiple `items` to a Cuckoo Filter `key` only if they do not exist yet, - allowing the filter to be created with a custom `capacity` if it does not yet exist. - `items` must be provided as a list. - For more information see `CF.INSERTNX `_. - """ # noqa - params = [key] - self.appendCapacity(params, capacity) - self.appendNoCreate(params, nocreate) - self.appendItems(params, items) - return self.execute_command(CF_INSERTNX, *params) - - def exists(self, key, item): - """ - Check whether an `item` exists in Cuckoo Filter `key`. - For more information see `CF.EXISTS `_. - """ # noqa - params = [key, item] - return self.execute_command(CF_EXISTS, *params) - - def delete(self, key, item): - """ - Delete `item` from `key`. - For more information see `CF.DEL `_. - """ # noqa - params = [key, item] - return self.execute_command(CF_DEL, *params) - - def count(self, key, item): - """ - Return the number of times an `item` may be in the `key`. - For more information see `CF.COUNT `_. - """ # noqa - params = [key, item] - return self.execute_command(CF_COUNT, *params) - - def scandump(self, key, iter): - """ - Begin an incremental save of the Cuckoo filter `key`. - - This is useful for large Cuckoo filters which cannot fit into the normal - SAVE and RESTORE model. - The first time this command is called, the value of `iter` should be 0. - This command will return successive (iter, data) pairs until - (0, NULL) to indicate completion. - For more information see `CF.SCANDUMP `_. - """ # noqa - params = [key, iter] - return self.execute_command(CF_SCANDUMP, *params) - - def loadchunk(self, key, iter, data): - """ - Restore a filter previously saved using SCANDUMP. See the SCANDUMP command for example usage. - - This command will overwrite any Cuckoo filter stored under key. - Ensure that the Cuckoo filter will not be modified between invocations. - For more information see `CF.LOADCHUNK `_. - """ # noqa - params = [key, iter, data] - return self.execute_command(CF_LOADCHUNK, *params) - - def info(self, key): - """ - Return size, number of buckets, number of filter, number of items inserted, - number of items deleted, bucket size, expansion rate, and max iteration. - For more information see `CF.INFO `_. - """ # noqa - return self.execute_command(CF_INFO, key) - - -class TOPKCommands: - """TOP-k Filter commands.""" - - def reserve(self, key, k, width, depth, decay): - """ - Create a new Top-K Filter `key` with desired probability of false - positives `errorRate` expected entries to be inserted as `size`. - For more information see `TOPK.RESERVE `_. - """ # noqa - params = [key, k, width, depth, decay] - return self.execute_command(TOPK_RESERVE, *params) - - def add(self, key, *items): - """ - Add one `item` or more to a Top-K Filter `key`. - For more information see `TOPK.ADD `_. - """ # noqa - params = [key] - params += items - return self.execute_command(TOPK_ADD, *params) - - def incrby(self, key, items, increments): - """ - Add/increase `items` to a Top-K Sketch `key` by ''increments''. - Both `items` and `increments` are lists. - For more information see `TOPK.INCRBY `_. - - Example: - - >>> topkincrby('A', ['foo'], [1]) - """ # noqa - params = [key] - self.appendItemsAndIncrements(params, items, increments) - return self.execute_command(TOPK_INCRBY, *params) - - def query(self, key, *items): - """ - Check whether one `item` or more is a Top-K item at `key`. - For more information see `TOPK.QUERY `_. - """ # noqa - params = [key] - params += items - return self.execute_command(TOPK_QUERY, *params) - - def count(self, key, *items): - """ - Return count for one `item` or more from `key`. - For more information see `TOPK.COUNT `_. - """ # noqa - params = [key] - params += items - return self.execute_command(TOPK_COUNT, *params) - - def list(self, key, withcount=False): - """ - Return full list of items in Top-K list of `key`. - If `withcount` set to True, return full list of items - with probabilistic count in Top-K list of `key`. - For more information see `TOPK.LIST `_. - """ # noqa - params = [key] - if withcount: - params.append("WITHCOUNT") - return self.execute_command(TOPK_LIST, *params) - - def info(self, key): - """ - Return k, width, depth and decay values of `key`. - For more information see `TOPK.INFO `_. - """ # noqa - return self.execute_command(TOPK_INFO, key) - - -class TDigestCommands: - def create(self, key, compression): - """ - Allocate the memory and initialize the t-digest. - For more information see `TDIGEST.CREATE `_. - """ # noqa - params = [key, compression] - return self.execute_command(TDIGEST_CREATE, *params) - - def reset(self, key): - """ - Reset the sketch `key` to zero - empty out the sketch and re-initialize it. - For more information see `TDIGEST.RESET `_. - """ # noqa - return self.execute_command(TDIGEST_RESET, key) - - def add(self, key, values, weights): - """ - Add one or more samples (value with weight) to a sketch `key`. - Both `values` and `weights` are lists. - For more information see `TDIGEST.ADD `_. - - Example: - - >>> tdigestadd('A', [1500.0], [1.0]) - """ # noqa - params = [key] - self.appendValuesAndWeights(params, values, weights) - return self.execute_command(TDIGEST_ADD, *params) - - def merge(self, toKey, fromKey): - """ - Merge all of the values from 'fromKey' to 'toKey' sketch. - For more information see `TDIGEST.MERGE `_. - """ # noqa - params = [toKey, fromKey] - return self.execute_command(TDIGEST_MERGE, *params) - - def min(self, key): - """ - Return minimum value from the sketch `key`. Will return DBL_MAX if the sketch is empty. - For more information see `TDIGEST.MIN `_. - """ # noqa - return self.execute_command(TDIGEST_MIN, key) - - def max(self, key): - """ - Return maximum value from the sketch `key`. Will return DBL_MIN if the sketch is empty. - For more information see `TDIGEST.MAX `_. - """ # noqa - return self.execute_command(TDIGEST_MAX, key) - - def quantile(self, key, quantile): - """ - Return double value estimate of the cutoff such that a specified fraction of the data - added to this TDigest would be less than or equal to the cutoff. - For more information see `TDIGEST.QUANTILE `_. - """ # noqa - params = [key, quantile] - return self.execute_command(TDIGEST_QUANTILE, *params) - - def cdf(self, key, value): - """ - Return double fraction of all points added which are <= value. - For more information see `TDIGEST.CDF `_. - """ # noqa - params = [key, value] - return self.execute_command(TDIGEST_CDF, *params) - - def info(self, key): - """ - Return Compression, Capacity, Merged Nodes, Unmerged Nodes, Merged Weight, Unmerged Weight - and Total Compressions. - For more information see `TDIGEST.INFO `_. - """ # noqa - return self.execute_command(TDIGEST_INFO, key) - - -class CMSCommands: - """Count-Min Sketch Commands""" - - # region Count-Min Sketch Functions - def initbydim(self, key, width, depth): - """ - Initialize a Count-Min Sketch `key` to dimensions (`width`, `depth`) specified by user. - For more information see `CMS.INITBYDIM `_. - """ # noqa - params = [key, width, depth] - return self.execute_command(CMS_INITBYDIM, *params) - - def initbyprob(self, key, error, probability): - """ - Initialize a Count-Min Sketch `key` to characteristics (`error`, `probability`) specified by user. - For more information see `CMS.INITBYPROB `_. - """ # noqa - params = [key, error, probability] - return self.execute_command(CMS_INITBYPROB, *params) - - def incrby(self, key, items, increments): - """ - Add/increase `items` to a Count-Min Sketch `key` by ''increments''. - Both `items` and `increments` are lists. - For more information see `CMS.INCRBY `_. - - Example: - - >>> cmsincrby('A', ['foo'], [1]) - """ # noqa - params = [key] - self.appendItemsAndIncrements(params, items, increments) - return self.execute_command(CMS_INCRBY, *params) - - def query(self, key, *items): - """ - Return count for an `item` from `key`. Multiple items can be queried with one call. - For more information see `CMS.QUERY `_. - """ # noqa - params = [key] - params += items - return self.execute_command(CMS_QUERY, *params) - - def merge(self, destKey, numKeys, srcKeys, weights=[]): - """ - Merge `numKeys` of sketches into `destKey`. Sketches specified in `srcKeys`. - All sketches must have identical width and depth. - `Weights` can be used to multiply certain sketches. Default weight is 1. - Both `srcKeys` and `weights` are lists. - For more information see `CMS.MERGE `_. - """ # noqa - params = [destKey, numKeys] - params += srcKeys - self.appendWeights(params, weights) - return self.execute_command(CMS_MERGE, *params) - - def info(self, key): - """ - Return width, depth and total count of the sketch. - For more information see `CMS.INFO `_. - """ # noqa - return self.execute_command(CMS_INFO, key) diff --git a/redis/commands/bf/info.py b/redis/commands/bf/info.py deleted file mode 100644 index 24c5419..0000000 --- a/redis/commands/bf/info.py +++ /dev/null @@ -1,85 +0,0 @@ -from ..helpers import nativestr - - -class BFInfo(object): - capacity = None - size = None - filterNum = None - insertedNum = None - expansionRate = None - - def __init__(self, args): - response = dict(zip(map(nativestr, args[::2]), args[1::2])) - self.capacity = response["Capacity"] - self.size = response["Size"] - self.filterNum = response["Number of filters"] - self.insertedNum = response["Number of items inserted"] - self.expansionRate = response["Expansion rate"] - - -class CFInfo(object): - size = None - bucketNum = None - filterNum = None - insertedNum = None - deletedNum = None - bucketSize = None - expansionRate = None - maxIteration = None - - def __init__(self, args): - response = dict(zip(map(nativestr, args[::2]), args[1::2])) - self.size = response["Size"] - self.bucketNum = response["Number of buckets"] - self.filterNum = response["Number of filters"] - self.insertedNum = response["Number of items inserted"] - self.deletedNum = response["Number of items deleted"] - self.bucketSize = response["Bucket size"] - self.expansionRate = response["Expansion rate"] - self.maxIteration = response["Max iterations"] - - -class CMSInfo(object): - width = None - depth = None - count = None - - def __init__(self, args): - response = dict(zip(map(nativestr, args[::2]), args[1::2])) - self.width = response["width"] - self.depth = response["depth"] - self.count = response["count"] - - -class TopKInfo(object): - k = None - width = None - depth = None - decay = None - - def __init__(self, args): - response = dict(zip(map(nativestr, args[::2]), args[1::2])) - self.k = response["k"] - self.width = response["width"] - self.depth = response["depth"] - self.decay = response["decay"] - - -class TDigestInfo(object): - compression = None - capacity = None - mergedNodes = None - unmergedNodes = None - mergedWeight = None - unmergedWeight = None - totalCompressions = None - - def __init__(self, args): - response = dict(zip(map(nativestr, args[::2]), args[1::2])) - self.compression = response["Compression"] - self.capacity = response["Capacity"] - self.mergedNodes = response["Merged nodes"] - self.unmergedNodes = response["Unmerged nodes"] - self.mergedWeight = response["Merged weight"] - self.unmergedWeight = response["Unmerged weight"] - self.totalCompressions = response["Total compressions"] diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py deleted file mode 100644 index 5d0e804..0000000 --- a/redis/commands/cluster.py +++ /dev/null @@ -1,412 +0,0 @@ -from redis.crc import key_slot -from redis.exceptions import RedisClusterException, RedisError - -from .core import ACLCommands, DataAccessCommands, ManagementCommands, PubSubCommands -from .helpers import list_or_args - - -class ClusterMultiKeyCommands: - """ - A class containing commands that handle more than one key - """ - - def _partition_keys_by_slot(self, keys): - """ - Split keys into a dictionary that maps a slot to - a list of keys. - """ - slots_to_keys = {} - for key in keys: - k = self.encoder.encode(key) - slot = key_slot(k) - slots_to_keys.setdefault(slot, []).append(key) - - return slots_to_keys - - def mget_nonatomic(self, keys, *args): - """ - Splits the keys into different slots and then calls MGET - for the keys of every slot. This operation will not be atomic - if keys belong to more than one slot. - - Returns a list of values ordered identically to ``keys`` - """ - - from redis.client import EMPTY_RESPONSE - - options = {} - if not args: - options[EMPTY_RESPONSE] = [] - - # Concatenate all keys into a list - keys = list_or_args(keys, args) - # Split keys into slots - slots_to_keys = self._partition_keys_by_slot(keys) - - # Call MGET for every slot and concatenate - # the results - # We must make sure that the keys are returned in order - all_results = {} - for slot_keys in slots_to_keys.values(): - slot_values = self.execute_command("MGET", *slot_keys, **options) - - slot_results = dict(zip(slot_keys, slot_values)) - all_results.update(slot_results) - - # Sort the results - vals_in_order = [all_results[key] for key in keys] - return vals_in_order - - def mset_nonatomic(self, mapping): - """ - Sets key/values based on a mapping. Mapping is a dictionary of - key/value pairs. Both keys and values should be strings or types that - can be cast to a string via str(). - - Splits the keys into different slots and then calls MSET - for the keys of every slot. This operation will not be atomic - if keys belong to more than one slot. - """ - - # Partition the keys by slot - slots_to_pairs = {} - for pair in mapping.items(): - # encode the key - k = self.encoder.encode(pair[0]) - slot = key_slot(k) - slots_to_pairs.setdefault(slot, []).extend(pair) - - # Call MSET for every slot and concatenate - # the results (one result per slot) - res = [] - for pairs in slots_to_pairs.values(): - res.append(self.execute_command("MSET", *pairs)) - - return res - - def _split_command_across_slots(self, command, *keys): - """ - Runs the given command once for the keys - of each slot. Returns the sum of the return values. - """ - # Partition the keys by slot - slots_to_keys = self._partition_keys_by_slot(keys) - - # Sum up the reply from each command - total = 0 - for slot_keys in slots_to_keys.values(): - total += self.execute_command(command, *slot_keys) - - return total - - def exists(self, *keys): - """ - Returns the number of ``names`` that exist in the - whole cluster. The keys are first split up into slots - and then an EXISTS command is sent for every slot - """ - return self._split_command_across_slots("EXISTS", *keys) - - def delete(self, *keys): - """ - Deletes the given keys in the cluster. - The keys are first split up into slots - and then an DEL command is sent for every slot - - Non-existant keys are ignored. - Returns the number of keys that were deleted. - """ - return self._split_command_across_slots("DEL", *keys) - - def touch(self, *keys): - """ - Updates the last access time of given keys across the - cluster. - - The keys are first split up into slots - and then an TOUCH command is sent for every slot - - Non-existant keys are ignored. - Returns the number of keys that were touched. - """ - return self._split_command_across_slots("TOUCH", *keys) - - def unlink(self, *keys): - """ - Remove the specified keys in a different thread. - - The keys are first split up into slots - and then an TOUCH command is sent for every slot - - Non-existant keys are ignored. - Returns the number of keys that were unlinked. - """ - return self._split_command_across_slots("UNLINK", *keys) - - -class ClusterManagementCommands(ManagementCommands): - """ - A class for Redis Cluster management commands - - The class inherits from Redis's core ManagementCommands class and do the - required adjustments to work with cluster mode - """ - - def slaveof(self, *args, **kwargs): - raise RedisClusterException("SLAVEOF is not supported in cluster mode") - - def replicaof(self, *args, **kwargs): - raise RedisClusterException("REPLICAOF is not supported in cluster" " mode") - - def swapdb(self, *args, **kwargs): - raise RedisClusterException("SWAPDB is not supported in cluster" " mode") - - -class ClusterDataAccessCommands(DataAccessCommands): - """ - A class for Redis Cluster Data Access Commands - - The class inherits from Redis's core DataAccessCommand class and do the - required adjustments to work with cluster mode - """ - - def stralgo( - self, - algo, - value1, - value2, - specific_argument="strings", - len=False, - idx=False, - minmatchlen=None, - withmatchlen=False, - **kwargs, - ): - target_nodes = kwargs.pop("target_nodes", None) - if specific_argument == "strings" and target_nodes is None: - target_nodes = "default-node" - kwargs.update({"target_nodes": target_nodes}) - return super().stralgo( - algo, - value1, - value2, - specific_argument, - len, - idx, - minmatchlen, - withmatchlen, - **kwargs, - ) - - -class RedisClusterCommands( - ClusterMultiKeyCommands, - ClusterManagementCommands, - ACLCommands, - PubSubCommands, - ClusterDataAccessCommands, -): - """ - A class for all Redis Cluster commands - - For key-based commands, the target node(s) will be internally determined - by the keys' hash slot. - Non-key-based commands can be executed with the 'target_nodes' argument to - target specific nodes. By default, if target_nodes is not specified, the - command will be executed on the default cluster node. - - - :param :target_nodes: type can be one of the followings: - - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM - - 'ClusterNode' - - 'list(ClusterNodes)' - - 'dict(any:clusterNodes)' - - for example: - r.cluster_info(target_nodes=RedisCluster.ALL_NODES) - """ - - def cluster_addslots(self, target_node, *slots): - """ - Assign new hash slots to receiving node. Sends to specified node. - - :target_node: 'ClusterNode' - The node to execute the command on - """ - return self.execute_command( - "CLUSTER ADDSLOTS", *slots, target_nodes=target_node - ) - - def cluster_countkeysinslot(self, slot_id): - """ - Return the number of local keys in the specified hash slot - Send to node based on specified slot_id - """ - return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) - - def cluster_count_failure_report(self, node_id): - """ - Return the number of failure reports active for a given node - Sends to a random node - """ - return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) - - def cluster_delslots(self, *slots): - """ - Set hash slots as unbound in the cluster. - It determines by it self what node the slot is in and sends it there - - Returns a list of the results for each processed slot. - """ - return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] - - def cluster_failover(self, target_node, option=None): - """ - Forces a slave to perform a manual failover of its master - Sends to specified node - - :target_node: 'ClusterNode' - The node to execute the command on - """ - if option: - if option.upper() not in ["FORCE", "TAKEOVER"]: - raise RedisError( - f"Invalid option for CLUSTER FAILOVER command: {option}" - ) - else: - return self.execute_command( - "CLUSTER FAILOVER", option, target_nodes=target_node - ) - else: - return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) - - def cluster_info(self, target_nodes=None): - """ - Provides info about Redis Cluster node state. - The command will be sent to a random node in the cluster if no target - node is specified. - """ - return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) - - def cluster_keyslot(self, key): - """ - Returns the hash slot of the specified key - Sends to random node in the cluster - """ - return self.execute_command("CLUSTER KEYSLOT", key) - - def cluster_meet(self, host, port, target_nodes=None): - """ - Force a node cluster to handshake with another node. - Sends to specified node. - """ - return self.execute_command( - "CLUSTER MEET", host, port, target_nodes=target_nodes - ) - - def cluster_nodes(self): - """ - Force a node cluster to handshake with another node - - Sends to random node in the cluster - """ - return self.execute_command("CLUSTER NODES") - - def cluster_replicate(self, target_nodes, node_id): - """ - Reconfigure a node as a slave of the specified master node - """ - return self.execute_command( - "CLUSTER REPLICATE", node_id, target_nodes=target_nodes - ) - - def cluster_reset(self, soft=True, target_nodes=None): - """ - Reset a Redis Cluster node - - If 'soft' is True then it will send 'SOFT' argument - If 'soft' is False then it will send 'HARD' argument - """ - return self.execute_command( - "CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes - ) - - def cluster_save_config(self, target_nodes=None): - """ - Forces the node to save cluster state on disk - """ - return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes) - - def cluster_get_keys_in_slot(self, slot, num_keys): - """ - Returns the number of keys in the specified cluster slot - """ - return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys) - - def cluster_set_config_epoch(self, epoch, target_nodes=None): - """ - Set the configuration epoch in a new node - """ - return self.execute_command( - "CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes - ) - - def cluster_setslot(self, target_node, node_id, slot_id, state): - """ - Bind an hash slot to a specific node - - :target_node: 'ClusterNode' - The node to execute the command on - """ - if state.upper() in ("IMPORTING", "NODE", "MIGRATING"): - return self.execute_command( - "CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node - ) - elif state.upper() == "STABLE": - raise RedisError('For "stable" state please use ' "cluster_setslot_stable") - else: - raise RedisError(f"Invalid slot state: {state}") - - def cluster_setslot_stable(self, slot_id): - """ - Clears migrating / importing state from the slot. - It determines by it self what node the slot is in and sends it there. - """ - return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE") - - def cluster_replicas(self, node_id, target_nodes=None): - """ - Provides a list of replica nodes replicating from the specified primary - target node. - """ - return self.execute_command( - "CLUSTER REPLICAS", node_id, target_nodes=target_nodes - ) - - def cluster_slots(self, target_nodes=None): - """ - Get array of Cluster slot to node mappings - """ - return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) - - def readonly(self, target_nodes=None): - """ - Enables read queries. - The command will be sent to the default cluster node if target_nodes is - not specified. - """ - if target_nodes == "replicas" or target_nodes == "all": - # read_from_replicas will only be enabled if the READONLY command - # is sent to all replicas - self.read_from_replicas = True - return self.execute_command("READONLY", target_nodes=target_nodes) - - def readwrite(self, target_nodes=None): - """ - Disables read queries. - The command will be sent to the default cluster node if target_nodes is - not specified. - """ - # Reset read from replicas flag - self.read_from_replicas = False - return self.execute_command("READWRITE", target_nodes=target_nodes) diff --git a/redis/commands/graph/__init__.py b/redis/commands/graph/__init__.py deleted file mode 100644 index 7b9972a..0000000 --- a/redis/commands/graph/__init__.py +++ /dev/null @@ -1,162 +0,0 @@ -from ..helpers import quote_string, random_string, stringify_param_value -from .commands import GraphCommands -from .edge import Edge # noqa -from .node import Node # noqa -from .path import Path # noqa - - -class Graph(GraphCommands): - """ - Graph, collection of nodes and edges. - """ - - def __init__(self, client, name=random_string()): - """ - Create a new graph. - """ - self.NAME = name # Graph key - self.client = client - self.execute_command = client.execute_command - - self.nodes = {} - self.edges = [] - self._labels = [] # List of node labels. - self._properties = [] # List of properties. - self._relationshipTypes = [] # List of relation types. - self.version = 0 # Graph version - - @property - def name(self): - return self.NAME - - def _clear_schema(self): - self._labels = [] - self._properties = [] - self._relationshipTypes = [] - - def _refresh_schema(self): - self._clear_schema() - self._refresh_labels() - self._refresh_relations() - self._refresh_attributes() - - def _refresh_labels(self): - lbls = self.labels() - - # Unpack data. - self._labels = [None] * len(lbls) - for i, l in enumerate(lbls): - self._labels[i] = l[0] - - def _refresh_relations(self): - rels = self.relationshipTypes() - - # Unpack data. - self._relationshipTypes = [None] * len(rels) - for i, r in enumerate(rels): - self._relationshipTypes[i] = r[0] - - def _refresh_attributes(self): - props = self.propertyKeys() - - # Unpack data. - self._properties = [None] * len(props) - for i, p in enumerate(props): - self._properties[i] = p[0] - - def get_label(self, idx): - """ - Returns a label by it's index - - Args: - - idx: - The index of the label - """ - try: - label = self._labels[idx] - except IndexError: - # Refresh labels. - self._refresh_labels() - label = self._labels[idx] - return label - - def get_relation(self, idx): - """ - Returns a relationship type by it's index - - Args: - - idx: - The index of the relation - """ - try: - relationship_type = self._relationshipTypes[idx] - except IndexError: - # Refresh relationship types. - self._refresh_relations() - relationship_type = self._relationshipTypes[idx] - return relationship_type - - def get_property(self, idx): - """ - Returns a property by it's index - - Args: - - idx: - The index of the property - """ - try: - propertie = self._properties[idx] - except IndexError: - # Refresh properties. - self._refresh_attributes() - propertie = self._properties[idx] - return propertie - - def add_node(self, node): - """ - Adds a node to the graph. - """ - if node.alias is None: - node.alias = random_string() - self.nodes[node.alias] = node - - def add_edge(self, edge): - """ - Adds an edge to the graph. - """ - if not (self.nodes[edge.src_node.alias] and self.nodes[edge.dest_node.alias]): - raise AssertionError("Both edge's end must be in the graph") - - self.edges.append(edge) - - def _build_params_header(self, params): - if not isinstance(params, dict): - raise TypeError("'params' must be a dict") - # Header starts with "CYPHER" - params_header = "CYPHER " - for key, value in params.items(): - params_header += str(key) + "=" + stringify_param_value(value) + " " - return params_header - - # Procedures. - def call_procedure(self, procedure, *args, read_only=False, **kwagrs): - args = [quote_string(arg) for arg in args] - q = f"CALL {procedure}({','.join(args)})" - - y = kwagrs.get("y", None) - if y: - q += f" YIELD {','.join(y)}" - - return self.query(q, read_only=read_only) - - def labels(self): - return self.call_procedure("db.labels", read_only=True).result_set - - def relationshipTypes(self): - return self.call_procedure("db.relationshipTypes", read_only=True).result_set - - def propertyKeys(self): - return self.call_procedure("db.propertyKeys", read_only=True).result_set diff --git a/redis/commands/graph/commands.py b/redis/commands/graph/commands.py deleted file mode 100644 index e097936..0000000 --- a/redis/commands/graph/commands.py +++ /dev/null @@ -1,202 +0,0 @@ -from redis import DataError -from redis.exceptions import ResponseError - -from .exceptions import VersionMismatchException -from .query_result import QueryResult - - -class GraphCommands: - """RedisGraph Commands""" - - def commit(self): - """ - Create entire graph. - For more information see `CREATE `_. # noqa - """ - if len(self.nodes) == 0 and len(self.edges) == 0: - return None - - query = "CREATE " - for _, node in self.nodes.items(): - query += str(node) + "," - - query += ",".join([str(edge) for edge in self.edges]) - - # Discard leading comma. - if query[-1] == ",": - query = query[:-1] - - return self.query(query) - - def query(self, q, params=None, timeout=None, read_only=False, profile=False): - """ - Executes a query against the graph. - For more information see `GRAPH.QUERY `_. # noqa - - Args: - - ------- - q : - The query. - params : dict - Query parameters. - timeout : int - Maximum runtime for read queries in milliseconds. - read_only : bool - Executes a readonly query if set to True. - profile : bool - Return details on results produced by and time - spent in each operation. - """ - - # maintain original 'q' - query = q - - # handle query parameters - if params is not None: - query = self._build_params_header(params) + query - - # construct query command - # ask for compact result-set format - # specify known graph version - if profile: - cmd = "GRAPH.PROFILE" - else: - cmd = "GRAPH.RO_QUERY" if read_only else "GRAPH.QUERY" - command = [cmd, self.name, query, "--compact"] - - # include timeout is specified - if timeout: - if not isinstance(timeout, int): - raise Exception("Timeout argument must be a positive integer") - command += ["timeout", timeout] - - # issue query - try: - response = self.execute_command(*command) - return QueryResult(self, response, profile) - except ResponseError as e: - if "wrong number of arguments" in str(e): - print( - "Note: RedisGraph Python requires server version 2.2.8 or above" - ) # noqa - if "unknown command" in str(e) and read_only: - # `GRAPH.RO_QUERY` is unavailable in older versions. - return self.query(q, params, timeout, read_only=False) - raise e - except VersionMismatchException as e: - # client view over the graph schema is out of sync - # set client version and refresh local schema - self.version = e.version - self._refresh_schema() - # re-issue query - return self.query(q, params, timeout, read_only) - - def merge(self, pattern): - """ - Merge pattern. - For more information see `MERGE `_. # noqa - """ - query = "MERGE " - query += str(pattern) - - return self.query(query) - - def delete(self): - """ - Deletes graph. - For more information see `DELETE `_. # noqa - """ - self._clear_schema() - return self.execute_command("GRAPH.DELETE", self.name) - - # declared here, to override the built in redis.db.flush() - def flush(self): - """ - Commit the graph and reset the edges and the nodes to zero length. - """ - self.commit() - self.nodes = {} - self.edges = [] - - def explain(self, query, params=None): - """ - Get the execution plan for given query, - Returns an array of operations. - For more information see `GRAPH.EXPLAIN `_. # noqa - - Args: - - ------- - query: - The query that will be executed. - params: dict - Query parameters. - """ - if params is not None: - query = self._build_params_header(params) + query - - plan = self.execute_command("GRAPH.EXPLAIN", self.name, query) - return "\n".join(plan) - - def bulk(self, **kwargs): - """Internal only. Not supported.""" - raise NotImplementedError( - "GRAPH.BULK is internal only. " - "Use https://github.com/redisgraph/redisgraph-bulk-loader." - ) - - def profile(self, query): - """ - Execute a query and produce an execution plan augmented with metrics - for each operation's execution. Return a string representation of a - query execution plan, with details on results produced by and time - spent in each operation. - For more information see `GRAPH.PROFILE `_. # noqa - """ - return self.query(query, profile=True) - - def slowlog(self): - """ - Get a list containing up to 10 of the slowest queries issued - against the given graph ID. - For more information see `GRAPH.SLOWLOG `_. # noqa - - Each item in the list has the following structure: - 1. A unix timestamp at which the log entry was processed. - 2. The issued command. - 3. The issued query. - 4. The amount of time needed for its execution, in milliseconds. - """ - return self.execute_command("GRAPH.SLOWLOG", self.name) - - def config(self, name, value=None, set=False): - """ - Retrieve or update a RedisGraph configuration. - For more information see `GRAPH.CONFIG `_. # noqa - - Args: - - name : str - The name of the configuration - value : - The value we want to ser (can be used only when `set` is on) - set : bool - Turn on to set a configuration. Default behavior is get. - """ - params = ["SET" if set else "GET", name] - if value is not None: - if set: - params.append(value) - else: - raise DataError( - "``value`` can be provided only when ``set`` is True" - ) # noqa - return self.execute_command("GRAPH.CONFIG", *params) - - def list_keys(self): - """ - Lists all graph keys in the keyspace. - For more information see `GRAPH.LIST `_. # noqa - """ - return self.execute_command("GRAPH.LIST") diff --git a/redis/commands/graph/edge.py b/redis/commands/graph/edge.py deleted file mode 100644 index b334293..0000000 --- a/redis/commands/graph/edge.py +++ /dev/null @@ -1,87 +0,0 @@ -from ..helpers import quote_string -from .node import Node - - -class Edge: - """ - An edge connecting two nodes. - """ - - def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None): - """ - Create a new edge. - """ - if src_node is None or dest_node is None: - # NOTE(bors-42): It makes sense to change AssertionError to - # ValueError here - raise AssertionError("Both src_node & dest_node must be provided") - - self.id = edge_id - self.relation = relation or "" - self.properties = properties or {} - self.src_node = src_node - self.dest_node = dest_node - - def toString(self): - res = "" - if self.properties: - props = ",".join( - key + ":" + str(quote_string(val)) - for key, val in sorted(self.properties.items()) - ) - res += "{" + props + "}" - - return res - - def __str__(self): - # Source node. - if isinstance(self.src_node, Node): - res = str(self.src_node) - else: - res = "()" - - # Edge - res += "-[" - if self.relation: - res += ":" + self.relation - if self.properties: - props = ",".join( - key + ":" + str(quote_string(val)) - for key, val in sorted(self.properties.items()) - ) - res += "{" + props + "}" - res += "]->" - - # Dest node. - if isinstance(self.dest_node, Node): - res += str(self.dest_node) - else: - res += "()" - - return res - - def __eq__(self, rhs): - # Quick positive check, if both IDs are set. - if self.id is not None and rhs.id is not None and self.id == rhs.id: - return True - - # Source and destination nodes should match. - if self.src_node != rhs.src_node: - return False - - if self.dest_node != rhs.dest_node: - return False - - # Relation should match. - if self.relation != rhs.relation: - return False - - # Quick check for number of properties. - if len(self.properties) != len(rhs.properties): - return False - - # Compare properties. - if self.properties != rhs.properties: - return False - - return True diff --git a/redis/commands/graph/exceptions.py b/redis/commands/graph/exceptions.py deleted file mode 100644 index 4bbac10..0000000 --- a/redis/commands/graph/exceptions.py +++ /dev/null @@ -1,3 +0,0 @@ -class VersionMismatchException(Exception): - def __init__(self, version): - self.version = version diff --git a/redis/commands/graph/node.py b/redis/commands/graph/node.py deleted file mode 100644 index 47e4eeb..0000000 --- a/redis/commands/graph/node.py +++ /dev/null @@ -1,84 +0,0 @@ -from ..helpers import quote_string - - -class Node: - """ - A node within the graph. - """ - - def __init__(self, node_id=None, alias=None, label=None, properties=None): - """ - Create a new node. - """ - self.id = node_id - self.alias = alias - if isinstance(label, list): - label = [inner_label for inner_label in label if inner_label != ""] - - if ( - label is None - or label == "" - or (isinstance(label, list) and len(label) == 0) - ): - self.label = None - self.labels = None - elif isinstance(label, str): - self.label = label - self.labels = [label] - elif isinstance(label, list) and all( - [isinstance(inner_label, str) for inner_label in label] - ): - self.label = label[0] - self.labels = label - else: - raise AssertionError( - "label should be either None, " "string or a list of strings" - ) - - self.properties = properties or {} - - def toString(self): - res = "" - if self.properties: - props = ",".join( - key + ":" + str(quote_string(val)) - for key, val in sorted(self.properties.items()) - ) - res += "{" + props + "}" - - return res - - def __str__(self): - res = "(" - if self.alias: - res += self.alias - if self.labels: - res += ":" + ":".join(self.labels) - if self.properties: - props = ",".join( - key + ":" + str(quote_string(val)) - for key, val in sorted(self.properties.items()) - ) - res += "{" + props + "}" - res += ")" - - return res - - def __eq__(self, rhs): - # Quick positive check, if both IDs are set. - if self.id is not None and rhs.id is not None and self.id != rhs.id: - return False - - # Label should match. - if self.label != rhs.label: - return False - - # Quick check for number of properties. - if len(self.properties) != len(rhs.properties): - return False - - # Compare properties. - if self.properties != rhs.properties: - return False - - return True diff --git a/redis/commands/graph/path.py b/redis/commands/graph/path.py deleted file mode 100644 index 6f2214a..0000000 --- a/redis/commands/graph/path.py +++ /dev/null @@ -1,74 +0,0 @@ -from .edge import Edge -from .node import Node - - -class Path: - def __init__(self, nodes, edges): - if not (isinstance(nodes, list) and isinstance(edges, list)): - raise TypeError("nodes and edges must be list") - - self._nodes = nodes - self._edges = edges - self.append_type = Node - - @classmethod - def new_empty_path(cls): - return cls([], []) - - def nodes(self): - return self._nodes - - def edges(self): - return self._edges - - def get_node(self, index): - return self._nodes[index] - - def get_relationship(self, index): - return self._edges[index] - - def first_node(self): - return self._nodes[0] - - def last_node(self): - return self._nodes[-1] - - def edge_count(self): - return len(self._edges) - - def nodes_count(self): - return len(self._nodes) - - def add_node(self, node): - if not isinstance(node, self.append_type): - raise AssertionError("Add Edge before adding Node") - self._nodes.append(node) - self.append_type = Edge - return self - - def add_edge(self, edge): - if not isinstance(edge, self.append_type): - raise AssertionError("Add Node before adding Edge") - self._edges.append(edge) - self.append_type = Node - return self - - def __eq__(self, other): - return self.nodes() == other.nodes() and self.edges() == other.edges() - - def __str__(self): - res = "<" - edge_count = self.edge_count() - for i in range(0, edge_count): - node_id = self.get_node(i).id - res += "(" + str(node_id) + ")" - edge = self.get_relationship(i) - res += ( - "-[" + str(int(edge.id)) + "]->" - if edge.src_node == node_id - else "<-[" + str(int(edge.id)) + "]-" - ) - node_id = self.get_node(edge_count).id - res += "(" + str(node_id) + ")" - res += ">" - return res diff --git a/redis/commands/graph/query_result.py b/redis/commands/graph/query_result.py deleted file mode 100644 index e9d9f4d..0000000 --- a/redis/commands/graph/query_result.py +++ /dev/null @@ -1,362 +0,0 @@ -from collections import OrderedDict - -# from prettytable import PrettyTable -from redis import ResponseError - -from .edge import Edge -from .exceptions import VersionMismatchException -from .node import Node -from .path import Path - -LABELS_ADDED = "Labels added" -NODES_CREATED = "Nodes created" -NODES_DELETED = "Nodes deleted" -RELATIONSHIPS_DELETED = "Relationships deleted" -PROPERTIES_SET = "Properties set" -RELATIONSHIPS_CREATED = "Relationships created" -INDICES_CREATED = "Indices created" -INDICES_DELETED = "Indices deleted" -CACHED_EXECUTION = "Cached execution" -INTERNAL_EXECUTION_TIME = "internal execution time" - -STATS = [ - LABELS_ADDED, - NODES_CREATED, - PROPERTIES_SET, - RELATIONSHIPS_CREATED, - NODES_DELETED, - RELATIONSHIPS_DELETED, - INDICES_CREATED, - INDICES_DELETED, - CACHED_EXECUTION, - INTERNAL_EXECUTION_TIME, -] - - -class ResultSetColumnTypes: - COLUMN_UNKNOWN = 0 - COLUMN_SCALAR = 1 - COLUMN_NODE = 2 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa - COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa - - -class ResultSetScalarTypes: - VALUE_UNKNOWN = 0 - VALUE_NULL = 1 - VALUE_STRING = 2 - VALUE_INTEGER = 3 - VALUE_BOOLEAN = 4 - VALUE_DOUBLE = 5 - VALUE_ARRAY = 6 - VALUE_EDGE = 7 - VALUE_NODE = 8 - VALUE_PATH = 9 - VALUE_MAP = 10 - VALUE_POINT = 11 - - -class QueryResult: - def __init__(self, graph, response, profile=False): - """ - A class that represents a result of the query operation. - - Args: - - graph: - The graph on which the query was executed. - response: - The response from the server. - profile: - A boolean indicating if the query command was "GRAPH.PROFILE" - """ - self.graph = graph - self.header = [] - self.result_set = [] - - # in case of an error an exception will be raised - self._check_for_errors(response) - - if len(response) == 1: - self.parse_statistics(response[0]) - elif profile: - self.parse_profile(response) - else: - # start by parsing statistics, matches the one we have - self.parse_statistics(response[-1]) # Last element. - self.parse_results(response) - - def _check_for_errors(self, response): - if isinstance(response[0], ResponseError): - error = response[0] - if str(error) == "version mismatch": - version = response[1] - error = VersionMismatchException(version) - raise error - - # If we encountered a run-time error, the last response - # element will be an exception - if isinstance(response[-1], ResponseError): - raise response[-1] - - def parse_results(self, raw_result_set): - self.header = self.parse_header(raw_result_set) - - # Empty header. - if len(self.header) == 0: - return - - self.result_set = self.parse_records(raw_result_set) - - def parse_statistics(self, raw_statistics): - self.statistics = {} - - # decode statistics - for idx, stat in enumerate(raw_statistics): - if isinstance(stat, bytes): - raw_statistics[idx] = stat.decode() - - for s in STATS: - v = self._get_value(s, raw_statistics) - if v is not None: - self.statistics[s] = v - - def parse_header(self, raw_result_set): - # An array of column name/column type pairs. - header = raw_result_set[0] - return header - - def parse_records(self, raw_result_set): - records = [] - result_set = raw_result_set[1] - for row in result_set: - record = [] - for idx, cell in enumerate(row): - if self.header[idx][0] == ResultSetColumnTypes.COLUMN_SCALAR: # noqa - record.append(self.parse_scalar(cell)) - elif self.header[idx][0] == ResultSetColumnTypes.COLUMN_NODE: # noqa - record.append(self.parse_node(cell)) - elif ( - self.header[idx][0] == ResultSetColumnTypes.COLUMN_RELATION - ): # noqa - record.append(self.parse_edge(cell)) - else: - print("Unknown column type.\n") - records.append(record) - - return records - - def parse_entity_properties(self, props): - # [[name, value type, value] X N] - properties = {} - for prop in props: - prop_name = self.graph.get_property(prop[0]) - prop_value = self.parse_scalar(prop[1:]) - properties[prop_name] = prop_value - - return properties - - def parse_string(self, cell): - if isinstance(cell, bytes): - return cell.decode() - elif not isinstance(cell, str): - return str(cell) - else: - return cell - - def parse_node(self, cell): - # Node ID (integer), - # [label string offset (integer)], - # [[name, value type, value] X N] - - node_id = int(cell[0]) - labels = None - if len(cell[1]) > 0: - labels = [] - for inner_label in cell[1]: - labels.append(self.graph.get_label(inner_label)) - properties = self.parse_entity_properties(cell[2]) - return Node(node_id=node_id, label=labels, properties=properties) - - def parse_edge(self, cell): - # Edge ID (integer), - # reltype string offset (integer), - # src node ID offset (integer), - # dest node ID offset (integer), - # [[name, value, value type] X N] - - edge_id = int(cell[0]) - relation = self.graph.get_relation(cell[1]) - src_node_id = int(cell[2]) - dest_node_id = int(cell[3]) - properties = self.parse_entity_properties(cell[4]) - return Edge( - src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties - ) - - def parse_path(self, cell): - nodes = self.parse_scalar(cell[0]) - edges = self.parse_scalar(cell[1]) - return Path(nodes, edges) - - def parse_map(self, cell): - m = OrderedDict() - n_entries = len(cell) - - # A map is an array of key value pairs. - # 1. key (string) - # 2. array: (value type, value) - for i in range(0, n_entries, 2): - key = self.parse_string(cell[i]) - m[key] = self.parse_scalar(cell[i + 1]) - - return m - - def parse_point(self, cell): - p = {} - # A point is received an array of the form: [latitude, longitude] - # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa - p["latitude"] = float(cell[0]) - p["longitude"] = float(cell[1]) - return p - - def parse_scalar(self, cell): - scalar_type = int(cell[0]) - value = cell[1] - scalar = None - - if scalar_type == ResultSetScalarTypes.VALUE_NULL: - scalar = None - - elif scalar_type == ResultSetScalarTypes.VALUE_STRING: - scalar = self.parse_string(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_INTEGER: - scalar = int(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_BOOLEAN: - value = value.decode() if isinstance(value, bytes) else value - if value == "true": - scalar = True - elif value == "false": - scalar = False - else: - print("Unknown boolean type\n") - - elif scalar_type == ResultSetScalarTypes.VALUE_DOUBLE: - scalar = float(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_ARRAY: - # array variable is introduced only for readability - scalar = array = value - for i in range(len(array)): - scalar[i] = self.parse_scalar(array[i]) - - elif scalar_type == ResultSetScalarTypes.VALUE_NODE: - scalar = self.parse_node(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_EDGE: - scalar = self.parse_edge(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_PATH: - scalar = self.parse_path(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_MAP: - scalar = self.parse_map(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_POINT: - scalar = self.parse_point(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_UNKNOWN: - print("Unknown scalar type\n") - - return scalar - - def parse_profile(self, response): - self.result_set = [x[0 : x.index(",")].strip() for x in response] - - # """Prints the data from the query response: - # 1. First row result_set contains the columns names. - # Thus the first row in PrettyTable will contain the - # columns. - # 2. The row after that will contain the data returned, - # or 'No Data returned' if there is none. - # 3. Prints the statistics of the query. - # """ - - # def pretty_print(self): - # if not self.is_empty(): - # header = [col[1] for col in self.header] - # tbl = PrettyTable(header) - - # for row in self.result_set: - # record = [] - # for idx, cell in enumerate(row): - # if type(cell) is Node: - # record.append(cell.toString()) - # elif type(cell) is Edge: - # record.append(cell.toString()) - # else: - # record.append(cell) - # tbl.add_row(record) - - # if len(self.result_set) == 0: - # tbl.add_row(['No data returned.']) - - # print(str(tbl) + '\n') - - # for stat in self.statistics: - # print("%s %s" % (stat, self.statistics[stat])) - - def is_empty(self): - return len(self.result_set) == 0 - - @staticmethod - def _get_value(prop, statistics): - for stat in statistics: - if prop in stat: - return float(stat.split(": ")[1].split(" ")[0]) - - return None - - def _get_stat(self, stat): - return self.statistics[stat] if stat in self.statistics else 0 - - @property - def labels_added(self): - return self._get_stat(LABELS_ADDED) - - @property - def nodes_created(self): - return self._get_stat(NODES_CREATED) - - @property - def nodes_deleted(self): - return self._get_stat(NODES_DELETED) - - @property - def properties_set(self): - return self._get_stat(PROPERTIES_SET) - - @property - def relationships_created(self): - return self._get_stat(RELATIONSHIPS_CREATED) - - @property - def relationships_deleted(self): - return self._get_stat(RELATIONSHIPS_DELETED) - - @property - def indices_created(self): - return self._get_stat(INDICES_CREATED) - - @property - def indices_deleted(self): - return self._get_stat(INDICES_DELETED) - - @property - def cached_execution(self): - return self._get_stat(CACHED_EXECUTION) == 1 - - @property - def run_time_ms(self): - return self._get_stat(INTERNAL_EXECUTION_TIME) diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py deleted file mode 100644 index 12c0648..0000000 --- a/redis/commands/json/__init__.py +++ /dev/null @@ -1,118 +0,0 @@ -from json import JSONDecodeError, JSONDecoder, JSONEncoder - -import redis - -from ..helpers import nativestr -from .commands import JSONCommands -from .decoders import bulk_of_jsons, decode_list - - -class JSON(JSONCommands): - """ - Create a client for talking to json. - - :param decoder: - :type json.JSONDecoder: An instance of json.JSONDecoder - - :param encoder: - :type json.JSONEncoder: An instance of json.JSONEncoder - """ - - def __init__( - self, - client, - version=None, - decoder=JSONDecoder(), - encoder=JSONEncoder(), - ): - """ - Create a client for talking to json. - - :param decoder: - :type json.JSONDecoder: An instance of json.JSONDecoder - - :param encoder: - :type json.JSONEncoder: An instance of json.JSONEncoder - """ - # Set the module commands' callbacks - self.MODULE_CALLBACKS = { - "JSON.CLEAR": int, - "JSON.DEL": int, - "JSON.FORGET": int, - "JSON.GET": self._decode, - "JSON.MGET": bulk_of_jsons(self._decode), - "JSON.SET": lambda r: r and nativestr(r) == "OK", - "JSON.NUMINCRBY": self._decode, - "JSON.NUMMULTBY": self._decode, - "JSON.TOGGLE": self._decode, - "JSON.STRAPPEND": self._decode, - "JSON.STRLEN": self._decode, - "JSON.ARRAPPEND": self._decode, - "JSON.ARRINDEX": self._decode, - "JSON.ARRINSERT": self._decode, - "JSON.ARRLEN": self._decode, - "JSON.ARRPOP": self._decode, - "JSON.ARRTRIM": self._decode, - "JSON.OBJLEN": self._decode, - "JSON.OBJKEYS": self._decode, - "JSON.RESP": self._decode, - "JSON.DEBUG": self._decode, - } - - self.client = client - self.execute_command = client.execute_command - self.MODULE_VERSION = version - - for key, value in self.MODULE_CALLBACKS.items(): - self.client.set_response_callback(key, value) - - self.__encoder__ = encoder - self.__decoder__ = decoder - - def _decode(self, obj): - """Get the decoder.""" - if obj is None: - return obj - - try: - x = self.__decoder__.decode(obj) - if x is None: - raise TypeError - return x - except TypeError: - try: - return self.__decoder__.decode(obj.decode()) - except AttributeError: - return decode_list(obj) - except (AttributeError, JSONDecodeError): - return decode_list(obj) - - def _encode(self, obj): - """Get the encoder.""" - return self.__encoder__.encode(obj) - - def pipeline(self, transaction=True, shard_hint=None): - """Creates a pipeline for the JSON module, that can be used for executing - JSON commands, as well as classic core commands. - - Usage example: - - r = redis.Redis() - pipe = r.json().pipeline() - pipe.jsonset('foo', '.', {'hello!': 'world'}) - pipe.jsonget('foo') - pipe.jsonget('notakey') - """ - p = Pipeline( - connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, - transaction=transaction, - shard_hint=shard_hint, - ) - p._encode = self._encode - p._decode = self._decode - return p - - -class Pipeline(JSONCommands, redis.client.Pipeline): - """Pipeline for the module.""" diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py deleted file mode 100644 index a132b8e..0000000 --- a/redis/commands/json/commands.py +++ /dev/null @@ -1,329 +0,0 @@ -import os -from json import JSONDecodeError, loads - -from deprecated import deprecated - -from redis.exceptions import DataError - -from .decoders import decode_dict_keys -from .path import Path - - -class JSONCommands: - """json commands.""" - - def arrappend(self, name, path=Path.rootPath(), *args): - """Append the objects ``args`` to the array under the - ``path` in key ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonarrappend - """ # noqa - pieces = [name, str(path)] - for o in args: - pieces.append(self._encode(o)) - return self.execute_command("JSON.ARRAPPEND", *pieces) - - def arrindex(self, name, path, scalar, start=0, stop=-1): - """ - Return the index of ``scalar`` in the JSON array under ``path`` at key - ``name``. - - The search can be limited using the optional inclusive ``start`` - and exclusive ``stop`` indices. - - For more information: https://oss.redis.com/redisjson/commands/#jsonarrindex - """ # noqa - return self.execute_command( - "JSON.ARRINDEX", name, str(path), self._encode(scalar), start, stop - ) - - def arrinsert(self, name, path, index, *args): - """Insert the objects ``args`` to the array at index ``index`` - under the ``path` in key ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonarrinsert - """ # noqa - pieces = [name, str(path), index] - for o in args: - pieces.append(self._encode(o)) - return self.execute_command("JSON.ARRINSERT", *pieces) - - def arrlen(self, name, path=Path.rootPath()): - """Return the length of the array JSON value under ``path`` - at key``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonarrlen - """ # noqa - return self.execute_command("JSON.ARRLEN", name, str(path)) - - def arrpop(self, name, path=Path.rootPath(), index=-1): - """Pop the element at ``index`` in the array JSON value under - ``path`` at key ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonarrpop - """ # noqa - return self.execute_command("JSON.ARRPOP", name, str(path), index) - - def arrtrim(self, name, path, start, stop): - """Trim the array JSON value under ``path`` at key ``name`` to the - inclusive range given by ``start`` and ``stop``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonarrtrim - """ # noqa - return self.execute_command("JSON.ARRTRIM", name, str(path), start, stop) - - def type(self, name, path=Path.rootPath()): - """Get the type of the JSON value under ``path`` from key ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsontype - """ # noqa - return self.execute_command("JSON.TYPE", name, str(path)) - - def resp(self, name, path=Path.rootPath()): - """Return the JSON value under ``path`` at key ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonresp - """ # noqa - return self.execute_command("JSON.RESP", name, str(path)) - - def objkeys(self, name, path=Path.rootPath()): - """Return the key names in the dictionary JSON value under ``path`` at - key ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonobjkeys - """ # noqa - return self.execute_command("JSON.OBJKEYS", name, str(path)) - - def objlen(self, name, path=Path.rootPath()): - """Return the length of the dictionary JSON value under ``path`` at key - ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonobjlen - """ # noqa - return self.execute_command("JSON.OBJLEN", name, str(path)) - - def numincrby(self, name, path, number): - """Increment the numeric (integer or floating point) JSON value under - ``path`` at key ``name`` by the provided ``number``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonnumincrby - """ # noqa - return self.execute_command( - "JSON.NUMINCRBY", name, str(path), self._encode(number) - ) - - @deprecated(version="4.0.0", reason="deprecated since redisjson 1.0.0") - def nummultby(self, name, path, number): - """Multiply the numeric (integer or floating point) JSON value under - ``path`` at key ``name`` with the provided ``number``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonnummultby - """ # noqa - return self.execute_command( - "JSON.NUMMULTBY", name, str(path), self._encode(number) - ) - - def clear(self, name, path=Path.rootPath()): - """ - Empty arrays and objects (to have zero slots/keys without deleting the - array/object). - - Return the count of cleared paths (ignoring non-array and non-objects - paths). - - For more information: https://oss.redis.com/redisjson/commands/#jsonclear - """ # noqa - return self.execute_command("JSON.CLEAR", name, str(path)) - - def delete(self, key, path=Path.rootPath()): - """Delete the JSON value stored at key ``key`` under ``path``. - - For more information: https://oss.redis.com/redisjson/commands/#jsondel - """ - return self.execute_command("JSON.DEL", key, str(path)) - - # forget is an alias for delete - forget = delete - - def get(self, name, *args, no_escape=False): - """ - Get the object stored as a JSON value at key ``name``. - - ``args`` is zero or more paths, and defaults to root path - ```no_escape`` is a boolean flag to add no_escape option to get - non-ascii characters - - For more information: https://oss.redis.com/redisjson/commands/#jsonget - """ # noqa - pieces = [name] - if no_escape: - pieces.append("noescape") - - if len(args) == 0: - pieces.append(Path.rootPath()) - - else: - for p in args: - pieces.append(str(p)) - - # Handle case where key doesn't exist. The JSONDecoder would raise a - # TypeError exception since it can't decode None - try: - return self.execute_command("JSON.GET", *pieces) - except TypeError: - return None - - def mget(self, keys, path): - """ - Get the objects stored as a JSON values under ``path``. ``keys`` - is a list of one or more keys. - - For more information: https://oss.redis.com/redisjson/commands/#jsonmget - """ # noqa - pieces = [] - pieces += keys - pieces.append(str(path)) - return self.execute_command("JSON.MGET", *pieces) - - def set(self, name, path, obj, nx=False, xx=False, decode_keys=False): - """ - Set the JSON value at key ``name`` under the ``path`` to ``obj``. - - ``nx`` if set to True, set ``value`` only if it does not exist. - ``xx`` if set to True, set ``value`` only if it exists. - ``decode_keys`` If set to True, the keys of ``obj`` will be decoded - with utf-8. - - For the purpose of using this within a pipeline, this command is also - aliased to jsonset. - - For more information: https://oss.redis.com/redisjson/commands/#jsonset - """ - if decode_keys: - obj = decode_dict_keys(obj) - - pieces = [name, str(path), self._encode(obj)] - - # Handle existential modifiers - if nx and xx: - raise Exception( - "nx and xx are mutually exclusive: use one, the " - "other or neither - but not both" - ) - elif nx: - pieces.append("NX") - elif xx: - pieces.append("XX") - return self.execute_command("JSON.SET", *pieces) - - def set_file(self, name, path, file_name, nx=False, xx=False, decode_keys=False): - """ - Set the JSON value at key ``name`` under the ``path`` to the content - of the json file ``file_name``. - - ``nx`` if set to True, set ``value`` only if it does not exist. - ``xx`` if set to True, set ``value`` only if it exists. - ``decode_keys`` If set to True, the keys of ``obj`` will be decoded - with utf-8. - - """ - - with open(file_name, "r") as fp: - file_content = loads(fp.read()) - - return self.set(name, path, file_content, nx=nx, xx=xx, decode_keys=decode_keys) - - def set_path(self, json_path, root_folder, nx=False, xx=False, decode_keys=False): - """ - Iterate over ``root_folder`` and set each JSON file to a value - under ``json_path`` with the file name as the key. - - ``nx`` if set to True, set ``value`` only if it does not exist. - ``xx`` if set to True, set ``value`` only if it exists. - ``decode_keys`` If set to True, the keys of ``obj`` will be decoded - with utf-8. - - """ - set_files_result = {} - for root, dirs, files in os.walk(root_folder): - for file in files: - file_path = os.path.join(root, file) - try: - file_name = file_path.rsplit(".")[0] - self.set_file( - file_name, - json_path, - file_path, - nx=nx, - xx=xx, - decode_keys=decode_keys, - ) - set_files_result[file_path] = True - except JSONDecodeError: - set_files_result[file_path] = False - - return set_files_result - - def strlen(self, name, path=None): - """Return the length of the string JSON value under ``path`` at key - ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsonstrlen - """ # noqa - pieces = [name] - if path is not None: - pieces.append(str(path)) - return self.execute_command("JSON.STRLEN", *pieces) - - def toggle(self, name, path=Path.rootPath()): - """Toggle boolean value under ``path`` at key ``name``. - returning the new value. - - For more information: https://oss.redis.com/redisjson/commands/#jsontoggle - """ # noqa - return self.execute_command("JSON.TOGGLE", name, str(path)) - - def strappend(self, name, value, path=Path.rootPath()): - """Append to the string JSON value. If two options are specified after - the key name, the path is determined to be the first. If a single - option is passed, then the rootpath (i.e Path.rootPath()) is used. - - For more information: https://oss.redis.com/redisjson/commands/#jsonstrappend - """ # noqa - pieces = [name, str(path), self._encode(value)] - return self.execute_command("JSON.STRAPPEND", *pieces) - - def debug(self, subcommand, key=None, path=Path.rootPath()): - """Return the memory usage in bytes of a value under ``path`` from - key ``name``. - - For more information: https://oss.redis.com/redisjson/commands/#jsondebg - """ # noqa - valid_subcommands = ["MEMORY", "HELP"] - if subcommand not in valid_subcommands: - raise DataError("The only valid subcommands are ", str(valid_subcommands)) - pieces = [subcommand] - if subcommand == "MEMORY": - if key is None: - raise DataError("No key specified") - pieces.append(key) - pieces.append(str(path)) - return self.execute_command("JSON.DEBUG", *pieces) - - @deprecated( - version="4.0.0", reason="redisjson-py supported this, call get directly." - ) - def jsonget(self, *args, **kwargs): - return self.get(*args, **kwargs) - - @deprecated( - version="4.0.0", reason="redisjson-py supported this, call get directly." - ) - def jsonmget(self, *args, **kwargs): - return self.mget(*args, **kwargs) - - @deprecated( - version="4.0.0", reason="redisjson-py supported this, call get directly." - ) - def jsonset(self, *args, **kwargs): - return self.set(*args, **kwargs) diff --git a/redis/commands/json/decoders.py b/redis/commands/json/decoders.py deleted file mode 100644 index b938471..0000000 --- a/redis/commands/json/decoders.py +++ /dev/null @@ -1,60 +0,0 @@ -import copy -import re - -from ..helpers import nativestr - - -def bulk_of_jsons(d): - """Replace serialized JSON values with objects in a - bulk array response (list). - """ - - def _f(b): - for index, item in enumerate(b): - if item is not None: - b[index] = d(item) - return b - - return _f - - -def decode_dict_keys(obj): - """Decode the keys of the given dictionary with utf-8.""" - newobj = copy.copy(obj) - for k in obj.keys(): - if isinstance(k, bytes): - newobj[k.decode("utf-8")] = newobj[k] - newobj.pop(k) - return newobj - - -def unstring(obj): - """ - Attempt to parse string to native integer formats. - One can't simply call int/float in a try/catch because there is a - semantic difference between (for example) 15.0 and 15. - """ - floatreg = "^\\d+.\\d+$" - match = re.findall(floatreg, obj) - if match != []: - return float(match[0]) - - intreg = "^\\d+$" - match = re.findall(intreg, obj) - if match != []: - return int(match[0]) - return obj - - -def decode_list(b): - """ - Given a non-deserializable object, make a best effort to - return a useful set of results. - """ - if isinstance(b, list): - return [nativestr(obj) for obj in b] - elif isinstance(b, bytes): - return unstring(nativestr(b)) - elif isinstance(b, str): - return unstring(b) - return b diff --git a/redis/commands/json/path.py b/redis/commands/json/path.py deleted file mode 100644 index f0a413a..0000000 --- a/redis/commands/json/path.py +++ /dev/null @@ -1,16 +0,0 @@ -class Path: - """This class represents a path in a JSON value.""" - - strPath = "" - - @staticmethod - def rootPath(): - """Return the root path's string representation.""" - return "." - - def __init__(self, path): - """Make a new path based on the string representation in `path`.""" - self.strPath = path - - def __repr__(self): - return self.strPath diff --git a/redis/commands/redismodules.py b/redis/commands/redismodules.py deleted file mode 100644 index eafd650..0000000 --- a/redis/commands/redismodules.py +++ /dev/null @@ -1,83 +0,0 @@ -from json import JSONDecoder, JSONEncoder - - -class RedisModuleCommands: - """This class contains the wrapper functions to bring supported redis - modules into the command namepsace. - """ - - def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()): - """Access the json namespace, providing support for redis json.""" - - from .json import JSON - - jj = JSON(client=self, encoder=encoder, decoder=decoder) - return jj - - def ft(self, index_name="idx"): - """Access the search namespace, providing support for redis search.""" - - from .search import Search - - s = Search(client=self, index_name=index_name) - return s - - def ts(self): - """Access the timeseries namespace, providing support for - redis timeseries data. - """ - - from .timeseries import TimeSeries - - s = TimeSeries(client=self) - return s - - def bf(self): - """Access the bloom namespace.""" - - from .bf import BFBloom - - bf = BFBloom(client=self) - return bf - - def cf(self): - """Access the bloom namespace.""" - - from .bf import CFBloom - - cf = CFBloom(client=self) - return cf - - def cms(self): - """Access the bloom namespace.""" - - from .bf import CMSBloom - - cms = CMSBloom(client=self) - return cms - - def topk(self): - """Access the bloom namespace.""" - - from .bf import TOPKBloom - - topk = TOPKBloom(client=self) - return topk - - def tdigest(self): - """Access the bloom namespace.""" - - from .bf import TDigestBloom - - tdigest = TDigestBloom(client=self) - return tdigest - - def graph(self, index_name="idx"): - """Access the timeseries namespace, providing support for - redis timeseries data. - """ - - from .graph import Graph - - g = Graph(client=self, name=index_name) - return g diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py deleted file mode 100644 index 94bc037..0000000 --- a/redis/commands/search/__init__.py +++ /dev/null @@ -1,96 +0,0 @@ -from .commands import SearchCommands - - -class Search(SearchCommands): - """ - Create a client for talking to search. - It abstracts the API of the module and lets you just use the engine. - """ - - class BatchIndexer: - """ - A batch indexer allows you to automatically batch - document indexing in pipelines, flushing it every N documents. - """ - - def __init__(self, client, chunk_size=1000): - - self.client = client - self.execute_command = client.execute_command - self.pipeline = client.pipeline(transaction=False, shard_hint=None) - self.total = 0 - self.chunk_size = chunk_size - self.current_chunk = 0 - - def __del__(self): - if self.current_chunk: - self.commit() - - def add_document( - self, - doc_id, - nosave=False, - score=1.0, - payload=None, - replace=False, - partial=False, - no_create=False, - **fields, - ): - """ - Add a document to the batch query - """ - self.client._add_document( - doc_id, - conn=self.pipeline, - nosave=nosave, - score=score, - payload=payload, - replace=replace, - partial=partial, - no_create=no_create, - **fields, - ) - self.current_chunk += 1 - self.total += 1 - if self.current_chunk >= self.chunk_size: - self.commit() - - def add_document_hash( - self, - doc_id, - score=1.0, - replace=False, - ): - """ - Add a hash to the batch query - """ - self.client._add_document_hash( - doc_id, - conn=self.pipeline, - score=score, - replace=replace, - ) - self.current_chunk += 1 - self.total += 1 - if self.current_chunk >= self.chunk_size: - self.commit() - - def commit(self): - """ - Manually commit and flush the batch indexing query - """ - self.pipeline.execute() - self.current_chunk = 0 - - def __init__(self, client, index_name="idx"): - """ - Create a new Client for the given index_name. - The default name is `idx` - - If conn is not None, we employ an already existing redis connection - """ - self.client = client - self.index_name = index_name - self.execute_command = client.execute_command - self.pipeline = client.pipeline diff --git a/redis/commands/search/_util.py b/redis/commands/search/_util.py deleted file mode 100644 index dd1dff3..0000000 --- a/redis/commands/search/_util.py +++ /dev/null @@ -1,7 +0,0 @@ -def to_string(s): - if isinstance(s, str): - return s - elif isinstance(s, bytes): - return s.decode("utf-8", "ignore") - else: - return s # Not a string we care about diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py deleted file mode 100644 index 061e69c..0000000 --- a/redis/commands/search/aggregation.py +++ /dev/null @@ -1,357 +0,0 @@ -FIELDNAME = object() - - -class Limit: - def __init__(self, offset=0, count=0): - self.offset = offset - self.count = count - - def build_args(self): - if self.count: - return ["LIMIT", str(self.offset), str(self.count)] - else: - return [] - - -class Reducer: - """ - Base reducer object for all reducers. - - See the `redisearch.reducers` module for the actual reducers. - """ - - NAME = None - - def __init__(self, *args): - self._args = args - self._field = None - self._alias = None - - def alias(self, alias): - """ - Set the alias for this reducer. - - ### Parameters - - - **alias**: The value of the alias for this reducer. If this is the - special value `aggregation.FIELDNAME` then this reducer will be - aliased using the same name as the field upon which it operates. - Note that using `FIELDNAME` is only possible on reducers which - operate on a single field value. - - This method returns the `Reducer` object making it suitable for - chaining. - """ - if alias is FIELDNAME: - if not self._field: - raise ValueError("Cannot use FIELDNAME alias with no field") - # Chop off initial '@' - alias = self._field[1:] - self._alias = alias - return self - - @property - def args(self): - return self._args - - -class SortDirection: - """ - This special class is used to indicate sort direction. - """ - - DIRSTRING = None - - def __init__(self, field): - self.field = field - - -class Asc(SortDirection): - """ - Indicate that the given field should be sorted in ascending order - """ - - DIRSTRING = "ASC" - - -class Desc(SortDirection): - """ - Indicate that the given field should be sorted in descending order - """ - - DIRSTRING = "DESC" - - -class AggregateRequest: - """ - Aggregation request which can be passed to `Client.aggregate`. - """ - - def __init__(self, query="*"): - """ - Create an aggregation request. This request may then be passed to - `client.aggregate()`. - - In order for the request to be usable, it must contain at least one - group. - - - **query** Query string for filtering records. - - All member methods (except `build_args()`) - return the object itself, making them useful for chaining. - """ - self._query = query - self._aggregateplan = [] - self._loadfields = [] - self._loadall = False - self._limit = Limit() - self._max = 0 - self._with_schema = False - self._verbatim = False - self._cursor = [] - - def load(self, *fields): - """ - Indicate the fields to be returned in the response. These fields are - returned in addition to any others implicitly specified. - - ### Parameters - - - **fields**: If fields not specified, all the fields will be loaded. - Otherwise, fields should be given in the format of `@field`. - """ - if fields: - self._loadfields.extend(fields) - else: - self._loadall = True - return self - - def group_by(self, fields, *reducers): - """ - Specify by which fields to group the aggregation. - - ### Parameters - - - **fields**: Fields to group by. This can either be a single string, - or a list of strings. both cases, the field should be specified as - `@field`. - - **reducers**: One or more reducers. Reducers may be found in the - `aggregation` module. - """ - fields = [fields] if isinstance(fields, str) else fields - reducers = [reducers] if isinstance(reducers, Reducer) else reducers - - ret = ["GROUPBY", str(len(fields)), *fields] - for reducer in reducers: - ret += ["REDUCE", reducer.NAME, str(len(reducer.args))] - ret.extend(reducer.args) - if reducer._alias is not None: - ret += ["AS", reducer._alias] - - self._aggregateplan.extend(ret) - return self - - def apply(self, **kwexpr): - """ - Specify one or more projection expressions to add to each result - - ### Parameters - - - **kwexpr**: One or more key-value pairs for a projection. The key is - the alias for the projection, and the value is the projection - expression itself, for example `apply(square_root="sqrt(@foo)")` - """ - for alias, expr in kwexpr.items(): - ret = ["APPLY", expr] - if alias is not None: - ret += ["AS", alias] - self._aggregateplan.extend(ret) - - return self - - def limit(self, offset, num): - """ - Sets the limit for the most recent group or query. - - If no group has been defined yet (via `group_by()`) then this sets - the limit for the initial pool of results from the query. Otherwise, - this limits the number of items operated on from the previous group. - - Setting a limit on the initial search results may be useful when - attempting to execute an aggregation on a sample of a large data set. - - ### Parameters - - - **offset**: Result offset from which to begin paging - - **num**: Number of results to return - - - Example of sorting the initial results: - - ``` - AggregateRequest("@sale_amount:[10000, inf]")\ - .limit(0, 10)\ - .group_by("@state", r.count()) - ``` - - Will only group by the states found in the first 10 results of the - query `@sale_amount:[10000, inf]`. On the other hand, - - ``` - AggregateRequest("@sale_amount:[10000, inf]")\ - .limit(0, 1000)\ - .group_by("@state", r.count()\ - .limit(0, 10) - ``` - - Will group all the results matching the query, but only return the - first 10 groups. - - If you only wish to return a *top-N* style query, consider using - `sort_by()` instead. - - """ - self._limit = Limit(offset, num) - return self - - def sort_by(self, *fields, **kwargs): - """ - Indicate how the results should be sorted. This can also be used for - *top-N* style queries - - ### Parameters - - - **fields**: The fields by which to sort. This can be either a single - field or a list of fields. If you wish to specify order, you can - use the `Asc` or `Desc` wrapper classes. - - **max**: Maximum number of results to return. This can be - used instead of `LIMIT` and is also faster. - - - Example of sorting by `foo` ascending and `bar` descending: - - ``` - sort_by(Asc("@foo"), Desc("@bar")) - ``` - - Return the top 10 customers: - - ``` - AggregateRequest()\ - .group_by("@customer", r.sum("@paid").alias(FIELDNAME))\ - .sort_by(Desc("@paid"), max=10) - ``` - """ - if isinstance(fields, (str, SortDirection)): - fields = [fields] - - fields_args = [] - for f in fields: - if isinstance(f, SortDirection): - fields_args += [f.field, f.DIRSTRING] - else: - fields_args += [f] - - ret = ["SORTBY", str(len(fields_args))] - ret.extend(fields_args) - max = kwargs.get("max", 0) - if max > 0: - ret += ["MAX", str(max)] - - self._aggregateplan.extend(ret) - return self - - def filter(self, expressions): - """ - Specify filter for post-query results using predicates relating to - values in the result set. - - ### Parameters - - - **fields**: Fields to group by. This can either be a single string, - or a list of strings. - """ - if isinstance(expressions, str): - expressions = [expressions] - - for expression in expressions: - self._aggregateplan.extend(["FILTER", expression]) - - return self - - def with_schema(self): - """ - If set, the `schema` property will contain a list of `[field, type]` - entries in the result object. - """ - self._with_schema = True - return self - - def verbatim(self): - self._verbatim = True - return self - - def cursor(self, count=0, max_idle=0.0): - args = ["WITHCURSOR"] - if count: - args += ["COUNT", str(count)] - if max_idle: - args += ["MAXIDLE", str(max_idle * 1000)] - self._cursor = args - return self - - def build_args(self): - # @foo:bar ... - ret = [self._query] - - if self._with_schema: - ret.append("WITHSCHEMA") - - if self._verbatim: - ret.append("VERBATIM") - - if self._cursor: - ret += self._cursor - - if self._loadall: - ret.append("LOAD") - ret.append("*") - elif self._loadfields: - ret.append("LOAD") - ret.append(str(len(self._loadfields))) - ret.extend(self._loadfields) - - ret.extend(self._aggregateplan) - - ret += self._limit.build_args() - - return ret - - -class Cursor: - def __init__(self, cid): - self.cid = cid - self.max_idle = 0 - self.count = 0 - - def build_args(self): - args = [str(self.cid)] - if self.max_idle: - args += ["MAXIDLE", str(self.max_idle)] - if self.count: - args += ["COUNT", str(self.count)] - return args - - -class AggregateResult: - def __init__(self, rows, cursor, schema): - self.rows = rows - self.cursor = cursor - self.schema = schema - - def __repr__(self): - cid = self.cursor.cid if self.cursor else -1 - return ( - f"<{self.__class__.__name__} at 0x{id(self):x} " - f"Rows={len(self.rows)}, Cursor={cid}>" - ) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py deleted file mode 100644 index 4ec6fc9..0000000 --- a/redis/commands/search/commands.py +++ /dev/null @@ -1,790 +0,0 @@ -import itertools -import time - -from ..helpers import parse_to_dict -from ._util import to_string -from .aggregation import AggregateRequest, AggregateResult, Cursor -from .document import Document -from .query import Query -from .result import Result -from .suggestion import SuggestionParser - -NUMERIC = "NUMERIC" - -CREATE_CMD = "FT.CREATE" -ALTER_CMD = "FT.ALTER" -SEARCH_CMD = "FT.SEARCH" -ADD_CMD = "FT.ADD" -ADDHASH_CMD = "FT.ADDHASH" -DROP_CMD = "FT.DROP" -EXPLAIN_CMD = "FT.EXPLAIN" -EXPLAINCLI_CMD = "FT.EXPLAINCLI" -DEL_CMD = "FT.DEL" -AGGREGATE_CMD = "FT.AGGREGATE" -PROFILE_CMD = "FT.PROFILE" -CURSOR_CMD = "FT.CURSOR" -SPELLCHECK_CMD = "FT.SPELLCHECK" -DICT_ADD_CMD = "FT.DICTADD" -DICT_DEL_CMD = "FT.DICTDEL" -DICT_DUMP_CMD = "FT.DICTDUMP" -GET_CMD = "FT.GET" -MGET_CMD = "FT.MGET" -CONFIG_CMD = "FT.CONFIG" -TAGVALS_CMD = "FT.TAGVALS" -ALIAS_ADD_CMD = "FT.ALIASADD" -ALIAS_UPDATE_CMD = "FT.ALIASUPDATE" -ALIAS_DEL_CMD = "FT.ALIASDEL" -INFO_CMD = "FT.INFO" -SUGADD_COMMAND = "FT.SUGADD" -SUGDEL_COMMAND = "FT.SUGDEL" -SUGLEN_COMMAND = "FT.SUGLEN" -SUGGET_COMMAND = "FT.SUGGET" -SYNUPDATE_CMD = "FT.SYNUPDATE" -SYNDUMP_CMD = "FT.SYNDUMP" - -NOOFFSETS = "NOOFFSETS" -NOFIELDS = "NOFIELDS" -STOPWORDS = "STOPWORDS" -WITHSCORES = "WITHSCORES" -FUZZY = "FUZZY" -WITHPAYLOADS = "WITHPAYLOADS" - - -class SearchCommands: - """Search commands.""" - - def batch_indexer(self, chunk_size=100): - """ - Create a new batch indexer from the client with a given chunk size - """ - return self.BatchIndexer(self, chunk_size=chunk_size) - - def create_index( - self, - fields, - no_term_offsets=False, - no_field_flags=False, - stopwords=None, - definition=None, - ): - """ - Create the search index. The index must not already exist. - - ### Parameters: - - - **fields**: a list of TextField or NumericField objects - - **no_term_offsets**: If true, we will not save term offsets in - the index - - **no_field_flags**: If true, we will not save field flags that - allow searching in specific fields - - **stopwords**: If not None, we create the index with this custom - stopword list. The list can be empty - - For more information: https://oss.redis.com/redisearch/Commands/#ftcreate - """ # noqa - - args = [CREATE_CMD, self.index_name] - if definition is not None: - args += definition.args - if no_term_offsets: - args.append(NOOFFSETS) - if no_field_flags: - args.append(NOFIELDS) - if stopwords is not None and isinstance(stopwords, (list, tuple, set)): - args += [STOPWORDS, len(stopwords)] - if len(stopwords) > 0: - args += list(stopwords) - - args.append("SCHEMA") - try: - args += list(itertools.chain(*(f.redis_args() for f in fields))) - except TypeError: - args += fields.redis_args() - - return self.execute_command(*args) - - def alter_schema_add(self, fields): - """ - Alter the existing search index by adding new fields. The index - must already exist. - - ### Parameters: - - - **fields**: a list of Field objects to add for the index - - For more information: https://oss.redis.com/redisearch/Commands/#ftalter_schema_add - """ # noqa - - args = [ALTER_CMD, self.index_name, "SCHEMA", "ADD"] - try: - args += list(itertools.chain(*(f.redis_args() for f in fields))) - except TypeError: - args += fields.redis_args() - - return self.execute_command(*args) - - def dropindex(self, delete_documents=False): - """ - Drop the index if it exists. - Replaced `drop_index` in RediSearch 2.0. - Default behavior was changed to not delete the indexed documents. - - ### Parameters: - - - **delete_documents**: If `True`, all documents will be deleted. - For more information: https://oss.redis.com/redisearch/Commands/#ftdropindex - """ # noqa - keep_str = "" if delete_documents else "KEEPDOCS" - return self.execute_command(DROP_CMD, self.index_name, keep_str) - - def _add_document( - self, - doc_id, - conn=None, - nosave=False, - score=1.0, - payload=None, - replace=False, - partial=False, - language=None, - no_create=False, - **fields, - ): - """ - Internal add_document used for both batch and single doc indexing - """ - if conn is None: - conn = self.client - - if partial or no_create: - replace = True - - args = [ADD_CMD, self.index_name, doc_id, score] - if nosave: - args.append("NOSAVE") - if payload is not None: - args.append("PAYLOAD") - args.append(payload) - if replace: - args.append("REPLACE") - if partial: - args.append("PARTIAL") - if no_create: - args.append("NOCREATE") - if language: - args += ["LANGUAGE", language] - args.append("FIELDS") - args += list(itertools.chain(*fields.items())) - return conn.execute_command(*args) - - def _add_document_hash( - self, - doc_id, - conn=None, - score=1.0, - language=None, - replace=False, - ): - """ - Internal add_document_hash used for both batch and single doc indexing - """ - if conn is None: - conn = self.client - - args = [ADDHASH_CMD, self.index_name, doc_id, score] - - if replace: - args.append("REPLACE") - - if language: - args += ["LANGUAGE", language] - - return conn.execute_command(*args) - - def add_document( - self, - doc_id, - nosave=False, - score=1.0, - payload=None, - replace=False, - partial=False, - language=None, - no_create=False, - **fields, - ): - """ - Add a single document to the index. - - ### Parameters - - - **doc_id**: the id of the saved document. - - **nosave**: if set to true, we just index the document, and don't - save a copy of it. This means that searches will just - return ids. - - **score**: the document ranking, between 0.0 and 1.0 - - **payload**: optional inner-index payload we can save for fast - i access in scoring functions - - **replace**: if True, and the document already is in the index, - we perform an update and reindex the document - - **partial**: if True, the fields specified will be added to the - existing document. - This has the added benefit that any fields specified - with `no_index` - will not be reindexed again. Implies `replace` - - **language**: Specify the language used for document tokenization. - - **no_create**: if True, the document is only updated and reindexed - if it already exists. - If the document does not exist, an error will be - returned. Implies `replace` - - **fields** kwargs dictionary of the document fields to be saved - and/or indexed. - NOTE: Geo points shoule be encoded as strings of "lon,lat" - - For more information: https://oss.redis.com/redisearch/Commands/#ftadd - """ # noqa - return self._add_document( - doc_id, - conn=None, - nosave=nosave, - score=score, - payload=payload, - replace=replace, - partial=partial, - language=language, - no_create=no_create, - **fields, - ) - - def add_document_hash( - self, - doc_id, - score=1.0, - language=None, - replace=False, - ): - """ - Add a hash document to the index. - - ### Parameters - - - **doc_id**: the document's id. This has to be an existing HASH key - in Redis that will hold the fields the index needs. - - **score**: the document ranking, between 0.0 and 1.0 - - **replace**: if True, and the document already is in the index, we - perform an update and reindex the document - - **language**: Specify the language used for document tokenization. - - For more information: https://oss.redis.com/redisearch/Commands/#ftaddhash - """ # noqa - return self._add_document_hash( - doc_id, - conn=None, - score=score, - language=language, - replace=replace, - ) - - def delete_document(self, doc_id, conn=None, delete_actual_document=False): - """ - Delete a document from index - Returns 1 if the document was deleted, 0 if not - - ### Parameters - - - **delete_actual_document**: if set to True, RediSearch also delete - the actual document if it is in the index - - For more information: https://oss.redis.com/redisearch/Commands/#ftdel - """ # noqa - args = [DEL_CMD, self.index_name, doc_id] - if conn is None: - conn = self.client - if delete_actual_document: - args.append("DD") - - return conn.execute_command(*args) - - def load_document(self, id): - """ - Load a single document by id - """ - fields = self.client.hgetall(id) - f2 = {to_string(k): to_string(v) for k, v in fields.items()} - fields = f2 - - try: - del fields["id"] - except KeyError: - pass - - return Document(id=id, **fields) - - def get(self, *ids): - """ - Returns the full contents of multiple documents. - - ### Parameters - - - **ids**: the ids of the saved documents. - - For more information https://oss.redis.com/redisearch/Commands/#ftget - """ - - return self.client.execute_command(MGET_CMD, self.index_name, *ids) - - def info(self): - """ - Get info an stats about the the current index, including the number of - documents, memory consumption, etc - - For more information https://oss.redis.com/redisearch/Commands/#ftinfo - """ - - res = self.client.execute_command(INFO_CMD, self.index_name) - it = map(to_string, res) - return dict(zip(it, it)) - - def _mk_query_args(self, query): - args = [self.index_name] - - if isinstance(query, str): - # convert the query from a text to a query object - query = Query(query) - if not isinstance(query, Query): - raise ValueError(f"Bad query type {type(query)}") - - args += query.get_args() - return args, query - - def search(self, query): - """ - Search the index for a given query, and return a result of documents - - ### Parameters - - - **query**: the search query. Either a text for simple queries with - default parameters, or a Query object for complex queries. - See RediSearch's documentation on query format - - For more information: https://oss.redis.com/redisearch/Commands/#ftsearch - """ # noqa - args, query = self._mk_query_args(query) - st = time.time() - res = self.execute_command(SEARCH_CMD, *args) - - return Result( - res, - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, - ) - - def explain(self, query): - """Returns the execution plan for a complex query. - - For more information: https://oss.redis.com/redisearch/Commands/#ftexplain - """ # noqa - args, query_text = self._mk_query_args(query) - return self.execute_command(EXPLAIN_CMD, *args) - - def explain_cli(self, query): # noqa - raise NotImplementedError("EXPLAINCLI will not be implemented.") - - def aggregate(self, query): - """ - Issue an aggregation query. - - ### Parameters - - **query**: This can be either an `AggregateRequest`, or a `Cursor` - - An `AggregateResult` object is returned. You can access the rows from - its `rows` property, which will always yield the rows of the result. - - Fpr more information: https://oss.redis.com/redisearch/Commands/#ftaggregate - """ # noqa - if isinstance(query, AggregateRequest): - has_cursor = bool(query._cursor) - cmd = [AGGREGATE_CMD, self.index_name] + query.build_args() - elif isinstance(query, Cursor): - has_cursor = True - cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args() - else: - raise ValueError("Bad query", query) - - raw = self.execute_command(*cmd) - return self._get_AggregateResult(raw, query, has_cursor) - - def _get_AggregateResult(self, raw, query, has_cursor): - if has_cursor: - if isinstance(query, Cursor): - query.cid = raw[1] - cursor = query - else: - cursor = Cursor(raw[1]) - raw = raw[0] - else: - cursor = None - - if isinstance(query, AggregateRequest) and query._with_schema: - schema = raw[0] - rows = raw[2:] - else: - schema = None - rows = raw[1:] - - return AggregateResult(rows, cursor, schema) - - def profile(self, query, limited=False): - """ - Performs a search or aggregate command and collects performance - information. - - ### Parameters - - **query**: This can be either an `AggregateRequest`, `Query` or - string. - **limited**: If set to True, removes details of reader iterator. - - """ - st = time.time() - cmd = [PROFILE_CMD, self.index_name, ""] - if limited: - cmd.append("LIMITED") - cmd.append("QUERY") - - if isinstance(query, AggregateRequest): - cmd[2] = "AGGREGATE" - cmd += query.build_args() - elif isinstance(query, Query): - cmd[2] = "SEARCH" - cmd += query.get_args() - else: - raise ValueError("Must provide AggregateRequest object or " "Query object.") - - res = self.execute_command(*cmd) - - if isinstance(query, AggregateRequest): - result = self._get_AggregateResult(res[0], query, query._cursor) - else: - result = Result( - res[0], - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, - ) - - return result, parse_to_dict(res[1]) - - def spellcheck(self, query, distance=None, include=None, exclude=None): - """ - Issue a spellcheck query - - ### Parameters - - **query**: search query. - **distance***: the maximal Levenshtein distance for spelling - suggestions (default: 1, max: 4). - **include**: specifies an inclusion custom dictionary. - **exclude**: specifies an exclusion custom dictionary. - - For more information: https://oss.redis.com/redisearch/Commands/#ftspellcheck - """ # noqa - cmd = [SPELLCHECK_CMD, self.index_name, query] - if distance: - cmd.extend(["DISTANCE", distance]) - - if include: - cmd.extend(["TERMS", "INCLUDE", include]) - - if exclude: - cmd.extend(["TERMS", "EXCLUDE", exclude]) - - raw = self.execute_command(*cmd) - - corrections = {} - if raw == 0: - return corrections - - for _correction in raw: - if isinstance(_correction, int) and _correction == 0: - continue - - if len(_correction) != 3: - continue - if not _correction[2]: - continue - if not _correction[2][0]: - continue - - # For spellcheck output - # 1) 1) "TERM" - # 2) "{term1}" - # 3) 1) 1) "{score1}" - # 2) "{suggestion1}" - # 2) 1) "{score2}" - # 2) "{suggestion2}" - # - # Following dictionary will be made - # corrections = { - # '{term1}': [ - # {'score': '{score1}', 'suggestion': '{suggestion1}'}, - # {'score': '{score2}', 'suggestion': '{suggestion2}'} - # ] - # } - corrections[_correction[1]] = [ - {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] - ] - - return corrections - - def dict_add(self, name, *terms): - """Adds terms to a dictionary. - - ### Parameters - - - **name**: Dictionary name. - - **terms**: List of items for adding to the dictionary. - - For more information: https://oss.redis.com/redisearch/Commands/#ftdictadd - """ # noqa - cmd = [DICT_ADD_CMD, name] - cmd.extend(terms) - return self.execute_command(*cmd) - - def dict_del(self, name, *terms): - """Deletes terms from a dictionary. - - ### Parameters - - - **name**: Dictionary name. - - **terms**: List of items for removing from the dictionary. - - For more information: https://oss.redis.com/redisearch/Commands/#ftdictdel - """ # noqa - cmd = [DICT_DEL_CMD, name] - cmd.extend(terms) - return self.execute_command(*cmd) - - def dict_dump(self, name): - """Dumps all terms in the given dictionary. - - ### Parameters - - - **name**: Dictionary name. - - For more information: https://oss.redis.com/redisearch/Commands/#ftdictdump - """ # noqa - cmd = [DICT_DUMP_CMD, name] - return self.execute_command(*cmd) - - def config_set(self, option, value): - """Set runtime configuration option. - - ### Parameters - - - **option**: the name of the configuration option. - - **value**: a value for the configuration option. - - For more information: https://oss.redis.com/redisearch/Commands/#ftconfig - """ # noqa - cmd = [CONFIG_CMD, "SET", option, value] - raw = self.execute_command(*cmd) - return raw == "OK" - - def config_get(self, option): - """Get runtime configuration option value. - - ### Parameters - - - **option**: the name of the configuration option. - - For more information: https://oss.redis.com/redisearch/Commands/#ftconfig - """ # noqa - cmd = [CONFIG_CMD, "GET", option] - res = {} - raw = self.execute_command(*cmd) - if raw: - for kvs in raw: - res[kvs[0]] = kvs[1] - return res - - def tagvals(self, tagfield): - """ - Return a list of all possible tag values - - ### Parameters - - - **tagfield**: Tag field name - - For more information: https://oss.redis.com/redisearch/Commands/#fttagvals - """ # noqa - - return self.execute_command(TAGVALS_CMD, self.index_name, tagfield) - - def aliasadd(self, alias): - """ - Alias a search index - will fail if alias already exists - - ### Parameters - - - **alias**: Name of the alias to create - - For more information: https://oss.redis.com/redisearch/Commands/#ftaliasadd - """ # noqa - - return self.execute_command(ALIAS_ADD_CMD, alias, self.index_name) - - def aliasupdate(self, alias): - """ - Updates an alias - will fail if alias does not already exist - - ### Parameters - - - **alias**: Name of the alias to create - - For more information: https://oss.redis.com/redisearch/Commands/#ftaliasupdate - """ # noqa - - return self.execute_command(ALIAS_UPDATE_CMD, alias, self.index_name) - - def aliasdel(self, alias): - """ - Removes an alias to a search index - - ### Parameters - - - **alias**: Name of the alias to delete - - For more information: https://oss.redis.com/redisearch/Commands/#ftaliasdel - """ # noqa - return self.execute_command(ALIAS_DEL_CMD, alias) - - def sugadd(self, key, *suggestions, **kwargs): - """ - Add suggestion terms to the AutoCompleter engine. Each suggestion has - a score and string. - If kwargs["increment"] is true and the terms are already in the - server's dictionary, we increment their scores. - - For more information: https://oss.redis.com/redisearch/master/Commands/#ftsugadd - """ # noqa - # If Transaction is not False it will MULTI/EXEC which will error - pipe = self.pipeline(transaction=False) - for sug in suggestions: - args = [SUGADD_COMMAND, key, sug.string, sug.score] - if kwargs.get("increment"): - args.append("INCR") - if sug.payload: - args.append("PAYLOAD") - args.append(sug.payload) - - pipe.execute_command(*args) - - return pipe.execute()[-1] - - def suglen(self, key): - """ - Return the number of entries in the AutoCompleter index. - - For more information https://oss.redis.com/redisearch/master/Commands/#ftsuglen - """ # noqa - return self.execute_command(SUGLEN_COMMAND, key) - - def sugdel(self, key, string): - """ - Delete a string from the AutoCompleter index. - Returns 1 if the string was found and deleted, 0 otherwise. - - For more information: https://oss.redis.com/redisearch/master/Commands/#ftsugdel - """ # noqa - return self.execute_command(SUGDEL_COMMAND, key, string) - - def sugget( - self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False - ): - """ - Get a list of suggestions from the AutoCompleter, for a given prefix. - - Parameters: - - prefix : str - The prefix we are searching. **Must be valid ascii or utf-8** - fuzzy : bool - If set to true, the prefix search is done in fuzzy mode. - **NOTE**: Running fuzzy searches on short (<3 letters) prefixes - can be very - slow, and even scan the entire index. - with_scores : bool - If set to true, we also return the (refactored) score of - each suggestion. - This is normally not needed, and is NOT the original score - inserted into the index. - with_payloads : bool - Return suggestion payloads - num : int - The maximum number of results we return. Note that we might - return less. The algorithm trims irrelevant suggestions. - - Returns: - - list: - A list of Suggestion objects. If with_scores was False, the - score of all suggestions is 1. - - For more information: https://oss.redis.com/redisearch/master/Commands/#ftsugget - """ # noqa - args = [SUGGET_COMMAND, key, prefix, "MAX", num] - if fuzzy: - args.append(FUZZY) - if with_scores: - args.append(WITHSCORES) - if with_payloads: - args.append(WITHPAYLOADS) - - ret = self.execute_command(*args) - results = [] - if not ret: - return results - - parser = SuggestionParser(with_scores, with_payloads, ret) - return [s for s in parser] - - def synupdate(self, groupid, skipinitial=False, *terms): - """ - Updates a synonym group. - The command is used to create or update a synonym group with - additional terms. - Only documents which were indexed after the update will be affected. - - Parameters: - - groupid : - Synonym group id. - skipinitial : bool - If set to true, we do not scan and index. - terms : - The terms. - - For more information: https://oss.redis.com/redisearch/Commands/#ftsynupdate - """ # noqa - cmd = [SYNUPDATE_CMD, self.index_name, groupid] - if skipinitial: - cmd.extend(["SKIPINITIALSCAN"]) - cmd.extend(terms) - return self.execute_command(*cmd) - - def syndump(self): - """ - Dumps the contents of a synonym group. - - The command is used to dump the synonyms data structure. - Returns a list of synonym terms and their synonym group ids. - - For more information: https://oss.redis.com/redisearch/Commands/#ftsyndump - """ # noqa - raw = self.execute_command(SYNDUMP_CMD, self.index_name) - return {raw[i]: raw[i + 1] for i in range(0, len(raw), 2)} diff --git a/redis/commands/search/document.py b/redis/commands/search/document.py deleted file mode 100644 index 5b30505..0000000 --- a/redis/commands/search/document.py +++ /dev/null @@ -1,13 +0,0 @@ -class Document: - """ - Represents a single document in a result set - """ - - def __init__(self, id, payload=None, **fields): - self.id = id - self.payload = payload - for k, v in fields.items(): - setattr(self, k, v) - - def __repr__(self): - return f"Document {self.__dict__}" diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py deleted file mode 100644 index 69e3908..0000000 --- a/redis/commands/search/field.py +++ /dev/null @@ -1,92 +0,0 @@ -class Field: - - NUMERIC = "NUMERIC" - TEXT = "TEXT" - WEIGHT = "WEIGHT" - GEO = "GEO" - TAG = "TAG" - SORTABLE = "SORTABLE" - NOINDEX = "NOINDEX" - AS = "AS" - - def __init__(self, name, args=[], sortable=False, no_index=False, as_name=None): - self.name = name - self.args = args - self.args_suffix = list() - self.as_name = as_name - - if sortable: - self.args_suffix.append(Field.SORTABLE) - if no_index: - self.args_suffix.append(Field.NOINDEX) - - if no_index and not sortable: - raise ValueError("Non-Sortable non-Indexable fields are ignored") - - def append_arg(self, value): - self.args.append(value) - - def redis_args(self): - args = [self.name] - if self.as_name: - args += [self.AS, self.as_name] - args += self.args - args += self.args_suffix - return args - - -class TextField(Field): - """ - TextField is used to define a text field in a schema definition - """ - - NOSTEM = "NOSTEM" - PHONETIC = "PHONETIC" - - def __init__( - self, name, weight=1.0, no_stem=False, phonetic_matcher=None, **kwargs - ): - Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs) - - if no_stem: - Field.append_arg(self, self.NOSTEM) - if phonetic_matcher and phonetic_matcher in [ - "dm:en", - "dm:fr", - "dm:pt", - "dm:es", - ]: - Field.append_arg(self, self.PHONETIC) - Field.append_arg(self, phonetic_matcher) - - -class NumericField(Field): - """ - NumericField is used to define a numeric field in a schema definition - """ - - def __init__(self, name, **kwargs): - Field.__init__(self, name, args=[Field.NUMERIC], **kwargs) - - -class GeoField(Field): - """ - GeoField is used to define a geo-indexing field in a schema definition - """ - - def __init__(self, name, **kwargs): - Field.__init__(self, name, args=[Field.GEO], **kwargs) - - -class TagField(Field): - """ - TagField is a tag-indexing field with simpler compression and tokenization. - See http://redisearch.io/Tags/ - """ - - SEPARATOR = "SEPARATOR" - - def __init__(self, name, separator=",", **kwargs): - Field.__init__( - self, name, args=[Field.TAG, self.SEPARATOR, separator], **kwargs - ) diff --git a/redis/commands/search/indexDefinition.py b/redis/commands/search/indexDefinition.py deleted file mode 100644 index 0c7a3b0..0000000 --- a/redis/commands/search/indexDefinition.py +++ /dev/null @@ -1,79 +0,0 @@ -from enum import Enum - - -class IndexType(Enum): - """Enum of the currently supported index types.""" - - HASH = 1 - JSON = 2 - - -class IndexDefinition: - """IndexDefinition is used to define a index definition for automatic - indexing on Hash or Json update.""" - - def __init__( - self, - prefix=[], - filter=None, - language_field=None, - language=None, - score_field=None, - score=1.0, - payload_field=None, - index_type=None, - ): - self.args = [] - self._appendIndexType(index_type) - self._appendPrefix(prefix) - self._appendFilter(filter) - self._appendLanguage(language_field, language) - self._appendScore(score_field, score) - self._appendPayload(payload_field) - - def _appendIndexType(self, index_type): - """Append `ON HASH` or `ON JSON` according to the enum.""" - if index_type is IndexType.HASH: - self.args.extend(["ON", "HASH"]) - elif index_type is IndexType.JSON: - self.args.extend(["ON", "JSON"]) - elif index_type is not None: - raise RuntimeError(f"index_type must be one of {list(IndexType)}") - - def _appendPrefix(self, prefix): - """Append PREFIX.""" - if len(prefix) > 0: - self.args.append("PREFIX") - self.args.append(len(prefix)) - for p in prefix: - self.args.append(p) - - def _appendFilter(self, filter): - """Append FILTER.""" - if filter is not None: - self.args.append("FILTER") - self.args.append(filter) - - def _appendLanguage(self, language_field, language): - """Append LANGUAGE_FIELD and LANGUAGE.""" - if language_field is not None: - self.args.append("LANGUAGE_FIELD") - self.args.append(language_field) - if language is not None: - self.args.append("LANGUAGE") - self.args.append(language) - - def _appendScore(self, score_field, score): - """Append SCORE_FIELD and SCORE.""" - if score_field is not None: - self.args.append("SCORE_FIELD") - self.args.append(score_field) - if score is not None: - self.args.append("SCORE") - self.args.append(score) - - def _appendPayload(self, payload_field): - """Append PAYLOAD_FIELD.""" - if payload_field is not None: - self.args.append("PAYLOAD_FIELD") - self.args.append(payload_field) diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py deleted file mode 100644 index 2bb8347..0000000 --- a/redis/commands/search/query.py +++ /dev/null @@ -1,322 +0,0 @@ -class Query: - """ - Query is used to build complex queries that have more parameters than just - the query string. The query string is set in the constructor, and other - options have setter functions. - - The setter functions return the query object, so they can be chained, - i.e. `Query("foo").verbatim().filter(...)` etc. - """ - - def __init__(self, query_string): - """ - Create a new query object. - The query string is set in the constructor, and other options have - setter functions. - """ - - self._query_string = query_string - self._offset = 0 - self._num = 10 - self._no_content = False - self._no_stopwords = False - self._fields = None - self._verbatim = False - self._with_payloads = False - self._with_scores = False - self._scorer = False - self._filters = list() - self._ids = None - self._slop = -1 - self._in_order = False - self._sortby = None - self._return_fields = [] - self._summarize_fields = [] - self._highlight_fields = [] - self._language = None - self._expander = None - - def query_string(self): - """Return the query string of this query only.""" - return self._query_string - - def limit_ids(self, *ids): - """Limit the results to a specific set of pre-known document - ids of any length.""" - self._ids = ids - return self - - def return_fields(self, *fields): - """Add fields to return fields.""" - self._return_fields += fields - return self - - def return_field(self, field, as_field=None): - """Add field to return fields (Optional: add 'AS' name - to the field).""" - self._return_fields.append(field) - if as_field is not None: - self._return_fields += ("AS", as_field) - return self - - def _mk_field_list(self, fields): - if not fields: - return [] - return [fields] if isinstance(fields, str) else list(fields) - - def summarize(self, fields=None, context_len=None, num_frags=None, sep=None): - """ - Return an abridged format of the field, containing only the segments of - the field which contain the matching term(s). - - If `fields` is specified, then only the mentioned fields are - summarized; otherwise all results are summarized. - - Server side defaults are used for each option (except `fields`) - if not specified - - - **fields** List of fields to summarize. All fields are summarized - if not specified - - **context_len** Amount of context to include with each fragment - - **num_frags** Number of fragments per document - - **sep** Separator string to separate fragments - """ - args = ["SUMMARIZE"] - fields = self._mk_field_list(fields) - if fields: - args += ["FIELDS", str(len(fields))] + fields - - if context_len is not None: - args += ["LEN", str(context_len)] - if num_frags is not None: - args += ["FRAGS", str(num_frags)] - if sep is not None: - args += ["SEPARATOR", sep] - - self._summarize_fields = args - return self - - def highlight(self, fields=None, tags=None): - """ - Apply specified markup to matched term(s) within the returned field(s). - - - **fields** If specified then only those mentioned fields are - highlighted, otherwise all fields are highlighted - - **tags** A list of two strings to surround the match. - """ - args = ["HIGHLIGHT"] - fields = self._mk_field_list(fields) - if fields: - args += ["FIELDS", str(len(fields))] + fields - if tags: - args += ["TAGS"] + list(tags) - - self._highlight_fields = args - return self - - def language(self, language): - """ - Analyze the query as being in the specified language. - - :param language: The language (e.g. `chinese` or `english`) - """ - self._language = language - return self - - def slop(self, slop): - """Allow a maximum of N intervening non matched terms between - phrase terms (0 means exact phrase). - """ - self._slop = slop - return self - - def in_order(self): - """ - Match only documents where the query terms appear in - the same order in the document. - i.e. for the query "hello world", we do not match "world hello" - """ - self._in_order = True - return self - - def scorer(self, scorer): - """ - Use a different scoring function to evaluate document relevance. - Default is `TFIDF`. - - :param scorer: The scoring function to use - (e.g. `TFIDF.DOCNORM` or `BM25`) - """ - self._scorer = scorer - return self - - def get_args(self): - """Format the redis arguments for this query and return them.""" - args = [self._query_string] - args += self._get_args_tags() - args += self._summarize_fields + self._highlight_fields - args += ["LIMIT", self._offset, self._num] - return args - - def _get_args_tags(self): - args = [] - if self._no_content: - args.append("NOCONTENT") - if self._fields: - args.append("INFIELDS") - args.append(len(self._fields)) - args += self._fields - if self._verbatim: - args.append("VERBATIM") - if self._no_stopwords: - args.append("NOSTOPWORDS") - if self._filters: - for flt in self._filters: - if not isinstance(flt, Filter): - raise AttributeError("Did not receive a Filter object.") - args += flt.args - if self._with_payloads: - args.append("WITHPAYLOADS") - if self._scorer: - args += ["SCORER", self._scorer] - if self._with_scores: - args.append("WITHSCORES") - if self._ids: - args.append("INKEYS") - args.append(len(self._ids)) - args += self._ids - if self._slop >= 0: - args += ["SLOP", self._slop] - if self._in_order: - args.append("INORDER") - if self._return_fields: - args.append("RETURN") - args.append(len(self._return_fields)) - args += self._return_fields - if self._sortby: - if not isinstance(self._sortby, SortbyField): - raise AttributeError("Did not receive a SortByField.") - args.append("SORTBY") - args += self._sortby.args - if self._language: - args += ["LANGUAGE", self._language] - if self._expander: - args += ["EXPANDER", self._expander] - - return args - - def paging(self, offset, num): - """ - Set the paging for the query (defaults to 0..10). - - - **offset**: Paging offset for the results. Defaults to 0 - - **num**: How many results do we want - """ - self._offset = offset - self._num = num - return self - - def verbatim(self): - """Set the query to be verbatim, i.e. use no query expansion - or stemming. - """ - self._verbatim = True - return self - - def no_content(self): - """Set the query to only return ids and not the document content.""" - self._no_content = True - return self - - def no_stopwords(self): - """ - Prevent the query from being filtered for stopwords. - Only useful in very big queries that you are certain contain - no stopwords. - """ - self._no_stopwords = True - return self - - def with_payloads(self): - """Ask the engine to return document payloads.""" - self._with_payloads = True - return self - - def with_scores(self): - """Ask the engine to return document search scores.""" - self._with_scores = True - return self - - def limit_fields(self, *fields): - """ - Limit the search to specific TEXT fields only. - - - **fields**: A list of strings, case sensitive field names - from the defined schema. - """ - self._fields = fields - return self - - def add_filter(self, flt): - """ - Add a numeric or geo filter to the query. - **Currently only one of each filter is supported by the engine** - - - **flt**: A NumericFilter or GeoFilter object, used on a - corresponding field - """ - - self._filters.append(flt) - return self - - def sort_by(self, field, asc=True): - """ - Add a sortby field to the query. - - - **field** - the name of the field to sort by - - **asc** - when `True`, sorting will be done in asceding order - """ - self._sortby = SortbyField(field, asc) - return self - - def expander(self, expander): - """ - Add a expander field to the query. - - - **expander** - the name of the expander - """ - self._expander = expander - return self - - -class Filter: - def __init__(self, keyword, field, *args): - self.args = [keyword, field] + list(args) - - -class NumericFilter(Filter): - INF = "+inf" - NEG_INF = "-inf" - - def __init__(self, field, minval, maxval, minExclusive=False, maxExclusive=False): - args = [ - minval if not minExclusive else f"({minval}", - maxval if not maxExclusive else f"({maxval}", - ] - - Filter.__init__(self, "FILTER", field, *args) - - -class GeoFilter(Filter): - METERS = "m" - KILOMETERS = "km" - FEET = "ft" - MILES = "mi" - - def __init__(self, field, lon, lat, radius, unit=KILOMETERS): - Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit) - - -class SortbyField: - def __init__(self, field, asc=True): - self.args = [field, "ASC" if asc else "DESC"] diff --git a/redis/commands/search/querystring.py b/redis/commands/search/querystring.py deleted file mode 100644 index 1da0387..0000000 --- a/redis/commands/search/querystring.py +++ /dev/null @@ -1,314 +0,0 @@ -def tags(*t): - """ - Indicate that the values should be matched to a tag field - - ### Parameters - - - **t**: Tags to search for - """ - if not t: - raise ValueError("At least one tag must be specified") - return TagValue(*t) - - -def between(a, b, inclusive_min=True, inclusive_max=True): - """ - Indicate that value is a numeric range - """ - return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max) - - -def equal(n): - """ - Match a numeric value - """ - return between(n, n) - - -def lt(n): - """ - Match any value less than n - """ - return between(None, n, inclusive_max=False) - - -def le(n): - """ - Match any value less or equal to n - """ - return between(None, n, inclusive_max=True) - - -def gt(n): - """ - Match any value greater than n - """ - return between(n, None, inclusive_min=False) - - -def ge(n): - """ - Match any value greater or equal to n - """ - return between(n, None, inclusive_min=True) - - -def geo(lat, lon, radius, unit="km"): - """ - Indicate that value is a geo region - """ - return GeoValue(lat, lon, radius, unit) - - -class Value: - @property - def combinable(self): - """ - Whether this type of value may be combined with other values - for the same field. This makes the filter potentially more efficient - """ - return False - - @staticmethod - def make_value(v): - """ - Convert an object to a value, if it is not a value already - """ - if isinstance(v, Value): - return v - return ScalarValue(v) - - def to_string(self): - raise NotImplementedError() - - def __str__(self): - return self.to_string() - - -class RangeValue(Value): - combinable = False - - def __init__(self, a, b, inclusive_min=False, inclusive_max=False): - if a is None: - a = "-inf" - if b is None: - b = "inf" - self.range = [str(a), str(b)] - self.inclusive_min = inclusive_min - self.inclusive_max = inclusive_max - - def to_string(self): - return "[{1}{0[0]} {2}{0[1]}]".format( - self.range, - "(" if not self.inclusive_min else "", - "(" if not self.inclusive_max else "", - ) - - -class ScalarValue(Value): - combinable = True - - def __init__(self, v): - self.v = str(v) - - def to_string(self): - return self.v - - -class TagValue(Value): - combinable = False - - def __init__(self, *tags): - self.tags = tags - - def to_string(self): - return "{" + " | ".join(str(t) for t in self.tags) + "}" - - -class GeoValue(Value): - def __init__(self, lon, lat, radius, unit="km"): - self.lon = lon - self.lat = lat - self.radius = radius - self.unit = unit - - -class Node: - def __init__(self, *children, **kwparams): - """ - Create a node - - ### Parameters - - - **children**: One or more sub-conditions. These can be additional - `intersect`, `disjunct`, `union`, `optional`, or any other `Node` - type. - - The semantics of multiple conditions are dependent on the type of - query. For an `intersection` node, this amounts to a logical AND, - for a `union` node, this amounts to a logical `OR`. - - - **kwparams**: key-value parameters. Each key is the name of a field, - and the value should be a field value. This can be one of the - following: - - - Simple string (for text field matches) - - value returned by one of the helper functions - - list of either a string or a value - - - ### Examples - - Field `num` should be between 1 and 10 - ``` - intersect(num=between(1, 10) - ``` - - Name can either be `bob` or `john` - - ``` - union(name=("bob", "john")) - ``` - - Don't select countries in Israel, Japan, or US - - ``` - disjunct_union(country=("il", "jp", "us")) - ``` - """ - - self.params = [] - - kvparams = {} - for k, v in kwparams.items(): - curvals = kvparams.setdefault(k, []) - if isinstance(v, (str, int, float)): - curvals.append(Value.make_value(v)) - elif isinstance(v, Value): - curvals.append(v) - else: - curvals.extend(Value.make_value(subv) for subv in v) - - self.params += [Node.to_node(p) for p in children] - - for k, v in kvparams.items(): - self.params.extend(self.join_fields(k, v)) - - def join_fields(self, key, vals): - if len(vals) == 1: - return [BaseNode(f"@{key}:{vals[0].to_string()}")] - if not vals[0].combinable: - return [BaseNode(f"@{key}:{v.to_string()}") for v in vals] - s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})") - return [s] - - @classmethod - def to_node(cls, obj): # noqa - if isinstance(obj, Node): - return obj - return BaseNode(obj) - - @property - def JOINSTR(self): - raise NotImplementedError() - - def to_string(self, with_parens=None): - with_parens = self._should_use_paren(with_parens) - pre, post = ("(", ")") if with_parens else ("", "") - return f"{pre}{self.JOINSTR.join(n.to_string() for n in self.params)}{post}" - - def _should_use_paren(self, optval): - if optval is not None: - return optval - return len(self.params) > 1 - - def __str__(self): - return self.to_string() - - -class BaseNode(Node): - def __init__(self, s): - super().__init__() - self.s = str(s) - - def to_string(self, with_parens=None): - return self.s - - -class IntersectNode(Node): - """ - Create an intersection node. All children need to be satisfied in order for - this node to evaluate as true - """ - - JOINSTR = " " - - -class UnionNode(Node): - """ - Create a union node. Any of the children need to be satisfied in order for - this node to evaluate as true - """ - - JOINSTR = "|" - - -class DisjunctNode(IntersectNode): - """ - Create a disjunct node. In order for this node to be true, all of its - children must evaluate to false - """ - - def to_string(self, with_parens=None): - with_parens = self._should_use_paren(with_parens) - ret = super().to_string(with_parens=False) - if with_parens: - return "(-" + ret + ")" - else: - return "-" + ret - - -class DistjunctUnion(DisjunctNode): - """ - This node is true if *all* of its children are false. This is equivalent to - ``` - disjunct(union(...)) - ``` - """ - - JOINSTR = "|" - - -class OptionalNode(IntersectNode): - """ - Create an optional node. If this nodes evaluates to true, then the document - will be rated higher in score/rank. - """ - - def to_string(self, with_parens=None): - with_parens = self._should_use_paren(with_parens) - ret = super().to_string(with_parens=False) - if with_parens: - return "(~" + ret + ")" - else: - return "~" + ret - - -def intersect(*args, **kwargs): - return IntersectNode(*args, **kwargs) - - -def union(*args, **kwargs): - return UnionNode(*args, **kwargs) - - -def disjunct(*args, **kwargs): - return DisjunctNode(*args, **kwargs) - - -def disjunct_union(*args, **kwargs): - return DistjunctUnion(*args, **kwargs) - - -def querystring(*args, **kwargs): - return intersect(*args, **kwargs).to_string() diff --git a/redis/commands/search/reducers.py b/redis/commands/search/reducers.py deleted file mode 100644 index 41ed11a..0000000 --- a/redis/commands/search/reducers.py +++ /dev/null @@ -1,178 +0,0 @@ -from .aggregation import Reducer, SortDirection - - -class FieldOnlyReducer(Reducer): - def __init__(self, field): - super().__init__(field) - self._field = field - - -class count(Reducer): - """ - Counts the number of results in the group - """ - - NAME = "COUNT" - - def __init__(self): - super().__init__() - - -class sum(FieldOnlyReducer): - """ - Calculates the sum of all the values in the given fields within the group - """ - - NAME = "SUM" - - def __init__(self, field): - super().__init__(field) - - -class min(FieldOnlyReducer): - """ - Calculates the smallest value in the given field within the group - """ - - NAME = "MIN" - - def __init__(self, field): - super().__init__(field) - - -class max(FieldOnlyReducer): - """ - Calculates the largest value in the given field within the group - """ - - NAME = "MAX" - - def __init__(self, field): - super().__init__(field) - - -class avg(FieldOnlyReducer): - """ - Calculates the mean value in the given field within the group - """ - - NAME = "AVG" - - def __init__(self, field): - super().__init__(field) - - -class tolist(FieldOnlyReducer): - """ - Returns all the matched properties in a list - """ - - NAME = "TOLIST" - - def __init__(self, field): - super().__init__(field) - - -class count_distinct(FieldOnlyReducer): - """ - Calculate the number of distinct values contained in all the results in - the group for the given field - """ - - NAME = "COUNT_DISTINCT" - - def __init__(self, field): - super().__init__(field) - - -class count_distinctish(FieldOnlyReducer): - """ - Calculate the number of distinct values contained in all the results in the - group for the given field. This uses a faster algorithm than - `count_distinct` but is less accurate - """ - - NAME = "COUNT_DISTINCTISH" - - -class quantile(Reducer): - """ - Return the value for the nth percentile within the range of values for the - field within the group. - """ - - NAME = "QUANTILE" - - def __init__(self, field, pct): - super().__init__(field, str(pct)) - self._field = field - - -class stddev(FieldOnlyReducer): - """ - Return the standard deviation for the values within the group - """ - - NAME = "STDDEV" - - def __init__(self, field): - super().__init__(field) - - -class first_value(Reducer): - """ - Selects the first value within the group according to sorting parameters - """ - - NAME = "FIRST_VALUE" - - def __init__(self, field, *byfields): - """ - Selects the first value of the given field within the group. - - ### Parameter - - - **field**: Source field used for the value - - **byfields**: How to sort the results. This can be either the - *class* of `aggregation.Asc` or `aggregation.Desc` in which - case the field `field` is also used as the sort input. - - `byfields` can also be one or more *instances* of `Asc` or `Desc` - indicating the sort order for these fields - """ - - fieldstrs = [] - if ( - len(byfields) == 1 - and isinstance(byfields[0], type) - and issubclass(byfields[0], SortDirection) - ): - byfields = [byfields[0](field)] - - for f in byfields: - fieldstrs += [f.field, f.DIRSTRING] - - args = [field] - if fieldstrs: - args += ["BY"] + fieldstrs - super().__init__(*args) - self._field = field - - -class random_sample(Reducer): - """ - Returns a random sample of items from the dataset, from the given property - """ - - NAME = "RANDOM_SAMPLE" - - def __init__(self, field, size): - """ - ### Parameter - - **field**: Field to sample from - **size**: Return this many items (can be less) - """ - args = [field, str(size)] - super().__init__(*args) - self._field = field diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py deleted file mode 100644 index 5f4aca6..0000000 --- a/redis/commands/search/result.py +++ /dev/null @@ -1,73 +0,0 @@ -from ._util import to_string -from .document import Document - - -class Result: - """ - Represents the result of a search query, and has an array of Document - objects - """ - - def __init__( - self, res, hascontent, duration=0, has_payload=False, with_scores=False - ): - """ - - **snippets**: An optional dictionary of the form - {field: snippet_size} for snippet formatting - """ - - self.total = res[0] - self.duration = duration - self.docs = [] - - step = 1 - if hascontent: - step = step + 1 - if has_payload: - step = step + 1 - if with_scores: - step = step + 1 - - offset = 2 if with_scores else 1 - - for i in range(1, len(res), step): - id = to_string(res[i]) - payload = to_string(res[i + offset]) if has_payload else None - # fields_offset = 2 if has_payload else 1 - fields_offset = offset + 1 if has_payload else offset - score = float(res[i + 1]) if with_scores else None - - fields = {} - if hascontent: - fields = ( - dict( - dict( - zip( - map(to_string, res[i + fields_offset][::2]), - map(to_string, res[i + fields_offset][1::2]), - ) - ) - ) - if hascontent - else {} - ) - try: - del fields["id"] - except KeyError: - pass - - try: - fields["json"] = fields["$"] - del fields["$"] - except KeyError: - pass - - doc = ( - Document(id, score=score, payload=payload, **fields) - if with_scores - else Document(id, payload=payload, **fields) - ) - self.docs.append(doc) - - def __repr__(self): - return f"Result{{{self.total} total, docs: {self.docs}}}" diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py deleted file mode 100644 index 5d1eba6..0000000 --- a/redis/commands/search/suggestion.py +++ /dev/null @@ -1,51 +0,0 @@ -from ._util import to_string - - -class Suggestion: - """ - Represents a single suggestion being sent or returned from the - autocomplete server - """ - - def __init__(self, string, score=1.0, payload=None): - self.string = to_string(string) - self.payload = to_string(payload) - self.score = score - - def __repr__(self): - return self.string - - -class SuggestionParser: - """ - Internal class used to parse results from the `SUGGET` command. - This needs to consume either 1, 2, or 3 values at a time from - the return value depending on what objects were requested - """ - - def __init__(self, with_scores, with_payloads, ret): - self.with_scores = with_scores - self.with_payloads = with_payloads - - if with_scores and with_payloads: - self.sugsize = 3 - self._scoreidx = 1 - self._payloadidx = 2 - elif with_scores: - self.sugsize = 2 - self._scoreidx = 1 - elif with_payloads: - self.sugsize = 2 - self._payloadidx = 1 - else: - self.sugsize = 1 - self._scoreidx = -1 - - self._sugs = ret - - def __iter__(self): - for i in range(0, len(self._sugs), self.sugsize): - ss = self._sugs[i] - score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0 - payload = self._sugs[i + self._payloadidx] if self.with_payloads else None - yield Suggestion(ss, score, payload) diff --git a/redis/commands/sentinel.py b/redis/commands/sentinel.py deleted file mode 100644 index a9b06c2..0000000 --- a/redis/commands/sentinel.py +++ /dev/null @@ -1,93 +0,0 @@ -import warnings - - -class SentinelCommands: - """ - A class containing the commands specific to redis sentinal. This class is - to be used as a mixin. - """ - - def sentinel(self, *args): - "Redis Sentinel's SENTINEL command." - warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) - - def sentinel_get_master_addr_by_name(self, service_name): - "Returns a (host, port) pair for the given ``service_name``" - return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) - - def sentinel_master(self, service_name): - "Returns a dictionary containing the specified masters state." - return self.execute_command("SENTINEL MASTER", service_name) - - def sentinel_masters(self): - "Returns a list of dictionaries containing each master's state." - return self.execute_command("SENTINEL MASTERS") - - def sentinel_monitor(self, name, ip, port, quorum): - "Add a new master to Sentinel to be monitored" - return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum) - - def sentinel_remove(self, name): - "Remove a master from Sentinel's monitoring" - return self.execute_command("SENTINEL REMOVE", name) - - def sentinel_sentinels(self, service_name): - "Returns a list of sentinels for ``service_name``" - return self.execute_command("SENTINEL SENTINELS", service_name) - - def sentinel_set(self, name, option, value): - "Set Sentinel monitoring parameters for a given master" - return self.execute_command("SENTINEL SET", name, option, value) - - def sentinel_slaves(self, service_name): - "Returns a list of slaves for ``service_name``" - return self.execute_command("SENTINEL SLAVES", service_name) - - def sentinel_reset(self, pattern): - """ - This command will reset all the masters with matching name. - The pattern argument is a glob-style pattern. - - The reset process clears any previous state in a master (including a - failover in progress), and removes every slave and sentinel already - discovered and associated with the master. - """ - return self.execute_command("SENTINEL RESET", pattern, once=True) - - def sentinel_failover(self, new_master_name): - """ - Force a failover as if the master was not reachable, and without - asking for agreement to other Sentinels (however a new version of the - configuration will be published so that the other Sentinels will - update their configurations). - """ - return self.execute_command("SENTINEL FAILOVER", new_master_name) - - def sentinel_ckquorum(self, new_master_name): - """ - Check if the current Sentinel configuration is able to reach the - quorum needed to failover a master, and the majority needed to - authorize the failover. - - This command should be used in monitoring systems to check if a - Sentinel deployment is ok. - """ - return self.execute_command("SENTINEL CKQUORUM", new_master_name, once=True) - - def sentinel_flushconfig(self): - """ - Force Sentinel to rewrite its configuration on disk, including the - current Sentinel state. - - Normally Sentinel rewrites the configuration every time something - changes in its state (in the context of the subset of the state which - is persisted on disk across restart). - However sometimes it is possible that the configuration file is lost - because of operation errors, disk failures, package upgrade scripts or - configuration managers. In those cases a way to to force Sentinel to - rewrite the configuration file is handy. - - This command works even if the previous configuration file is - completely missing. - """ - return self.execute_command("SENTINEL FLUSHCONFIG") diff --git a/redis/commands/timeseries/__init__.py b/redis/commands/timeseries/__init__.py deleted file mode 100644 index 5b1f151..0000000 --- a/redis/commands/timeseries/__init__.py +++ /dev/null @@ -1,80 +0,0 @@ -import redis.client - -from ..helpers import parse_to_list -from .commands import ( - ALTER_CMD, - CREATE_CMD, - CREATERULE_CMD, - DEL_CMD, - DELETERULE_CMD, - GET_CMD, - INFO_CMD, - MGET_CMD, - MRANGE_CMD, - MREVRANGE_CMD, - QUERYINDEX_CMD, - RANGE_CMD, - REVRANGE_CMD, - TimeSeriesCommands, -) -from .info import TSInfo -from .utils import parse_get, parse_m_get, parse_m_range, parse_range - - -class TimeSeries(TimeSeriesCommands): - """ - This class subclasses redis-py's `Redis` and implements RedisTimeSeries's - commands (prefixed with "ts"). - The client allows to interact with RedisTimeSeries and use all of it's - functionality. - """ - - def __init__(self, client=None, **kwargs): - """Create a new RedisTimeSeries client.""" - # Set the module commands' callbacks - self.MODULE_CALLBACKS = { - CREATE_CMD: redis.client.bool_ok, - ALTER_CMD: redis.client.bool_ok, - CREATERULE_CMD: redis.client.bool_ok, - DEL_CMD: int, - DELETERULE_CMD: redis.client.bool_ok, - RANGE_CMD: parse_range, - REVRANGE_CMD: parse_range, - MRANGE_CMD: parse_m_range, - MREVRANGE_CMD: parse_m_range, - GET_CMD: parse_get, - MGET_CMD: parse_m_get, - INFO_CMD: TSInfo, - QUERYINDEX_CMD: parse_to_list, - } - - self.client = client - self.execute_command = client.execute_command - - for key, value in self.MODULE_CALLBACKS.items(): - self.client.set_response_callback(key, value) - - def pipeline(self, transaction=True, shard_hint=None): - """Creates a pipeline for the TimeSeries module, that can be used - for executing only TimeSeries commands and core commands. - - Usage example: - - r = redis.Redis() - pipe = r.ts().pipeline() - for i in range(100): - pipeline.add("with_pipeline", i, 1.1 * i) - pipeline.execute() - - """ - p = Pipeline( - connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, - transaction=transaction, - shard_hint=shard_hint, - ) - return p - - -class Pipeline(TimeSeriesCommands, redis.client.Pipeline): - """Pipeline for the module.""" diff --git a/redis/commands/timeseries/commands.py b/redis/commands/timeseries/commands.py deleted file mode 100644 index c86e0b9..0000000 --- a/redis/commands/timeseries/commands.py +++ /dev/null @@ -1,768 +0,0 @@ -from redis.exceptions import DataError - -ADD_CMD = "TS.ADD" -ALTER_CMD = "TS.ALTER" -CREATERULE_CMD = "TS.CREATERULE" -CREATE_CMD = "TS.CREATE" -DECRBY_CMD = "TS.DECRBY" -DELETERULE_CMD = "TS.DELETERULE" -DEL_CMD = "TS.DEL" -GET_CMD = "TS.GET" -INCRBY_CMD = "TS.INCRBY" -INFO_CMD = "TS.INFO" -MADD_CMD = "TS.MADD" -MGET_CMD = "TS.MGET" -MRANGE_CMD = "TS.MRANGE" -MREVRANGE_CMD = "TS.MREVRANGE" -QUERYINDEX_CMD = "TS.QUERYINDEX" -RANGE_CMD = "TS.RANGE" -REVRANGE_CMD = "TS.REVRANGE" - - -class TimeSeriesCommands: - """RedisTimeSeries Commands.""" - - def create(self, key, **kwargs): - """ - Create a new time-series. - - Args: - - key: - time-series key - retention_msecs: - Maximum age for samples compared to last event time (in milliseconds). - If None or 0 is passed then the series is not trimmed at all. - uncompressed: - Since RedisTimeSeries v1.2, both timestamps and values are - compressed by default. - Adding this flag will keep data in an uncompressed form. - Compression not only saves - memory but usually improve performance due to lower number - of memory accesses. - labels: - Set of label-value pairs that represent metadata labels of the key. - chunk_size: - Each time-serie uses chunks of memory of fixed size for - time series samples. - You can alter the default TSDB chunk size by passing the - chunk_size argument (in Bytes). - duplicate_policy: - Since RedisTimeSeries v1.4 you can specify the duplicate sample policy - ( Configure what to do on duplicate sample. ) - Can be one of: - - 'block': an error will occur for any out of order sample. - - 'first': ignore the new value. - - 'last': override with latest value. - - 'min': only override if the value is lower than the existing value. - - 'max': only override if the value is higher than the existing value. - When this is not set, the server-wide default will be used. - - For more information: https://oss.redis.com/redistimeseries/commands/#tscreate - """ # noqa - retention_msecs = kwargs.get("retention_msecs", None) - uncompressed = kwargs.get("uncompressed", False) - labels = kwargs.get("labels", {}) - chunk_size = kwargs.get("chunk_size", None) - duplicate_policy = kwargs.get("duplicate_policy", None) - params = [key] - self._appendRetention(params, retention_msecs) - self._appendUncompressed(params, uncompressed) - self._appendChunkSize(params, chunk_size) - self._appendDuplicatePolicy(params, CREATE_CMD, duplicate_policy) - self._appendLabels(params, labels) - - return self.execute_command(CREATE_CMD, *params) - - def alter(self, key, **kwargs): - """ - Update the retention, labels of an existing key. - For more information see - - The parameters are the same as TS.CREATE. - - For more information: https://oss.redis.com/redistimeseries/commands/#tsalter - """ # noqa - retention_msecs = kwargs.get("retention_msecs", None) - labels = kwargs.get("labels", {}) - duplicate_policy = kwargs.get("duplicate_policy", None) - params = [key] - self._appendRetention(params, retention_msecs) - self._appendDuplicatePolicy(params, ALTER_CMD, duplicate_policy) - self._appendLabels(params, labels) - - return self.execute_command(ALTER_CMD, *params) - - def add(self, key, timestamp, value, **kwargs): - """ - Append (or create and append) a new sample to the series. - For more information see - - Args: - - key: - time-series key - timestamp: - Timestamp of the sample. * can be used for automatic timestamp (using the system clock). - value: - Numeric data value of the sample - retention_msecs: - Maximum age for samples compared to last event time (in milliseconds). - If None or 0 is passed then the series is not trimmed at all. - uncompressed: - Since RedisTimeSeries v1.2, both timestamps and values are compressed by default. - Adding this flag will keep data in an uncompressed form. Compression not only saves - memory but usually improve performance due to lower number of memory accesses. - labels: - Set of label-value pairs that represent metadata labels of the key. - chunk_size: - Each time-serie uses chunks of memory of fixed size for time series samples. - You can alter the default TSDB chunk size by passing the chunk_size argument (in Bytes). - duplicate_policy: - Since RedisTimeSeries v1.4 you can specify the duplicate sample policy - (Configure what to do on duplicate sample). - Can be one of: - - 'block': an error will occur for any out of order sample. - - 'first': ignore the new value. - - 'last': override with latest value. - - 'min': only override if the value is lower than the existing value. - - 'max': only override if the value is higher than the existing value. - When this is not set, the server-wide default will be used. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsadd - """ # noqa - retention_msecs = kwargs.get("retention_msecs", None) - uncompressed = kwargs.get("uncompressed", False) - labels = kwargs.get("labels", {}) - chunk_size = kwargs.get("chunk_size", None) - duplicate_policy = kwargs.get("duplicate_policy", None) - params = [key, timestamp, value] - self._appendRetention(params, retention_msecs) - self._appendUncompressed(params, uncompressed) - self._appendChunkSize(params, chunk_size) - self._appendDuplicatePolicy(params, ADD_CMD, duplicate_policy) - self._appendLabels(params, labels) - - return self.execute_command(ADD_CMD, *params) - - def madd(self, ktv_tuples): - """ - Append (or create and append) a new `value` to series - `key` with `timestamp`. - Expects a list of `tuples` as (`key`,`timestamp`, `value`). - Return value is an array with timestamps of insertions. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmadd - """ # noqa - params = [] - for ktv in ktv_tuples: - for item in ktv: - params.append(item) - - return self.execute_command(MADD_CMD, *params) - - def incrby(self, key, value, **kwargs): - """ - Increment (or create an time-series and increment) the latest - sample's of a series. - This command can be used as a counter or gauge that automatically gets - history as a time series. - - Args: - - key: - time-series key - value: - Numeric data value of the sample - timestamp: - Timestamp of the sample. None can be used for automatic timestamp (using the system clock). - retention_msecs: - Maximum age for samples compared to last event time (in milliseconds). - If None or 0 is passed then the series is not trimmed at all. - uncompressed: - Since RedisTimeSeries v1.2, both timestamps and values are compressed by default. - Adding this flag will keep data in an uncompressed form. Compression not only saves - memory but usually improve performance due to lower number of memory accesses. - labels: - Set of label-value pairs that represent metadata labels of the key. - chunk_size: - Each time-series uses chunks of memory of fixed size for time series samples. - You can alter the default TSDB chunk size by passing the chunk_size argument (in Bytes). - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsincrbytsdecrby - """ # noqa - timestamp = kwargs.get("timestamp", None) - retention_msecs = kwargs.get("retention_msecs", None) - uncompressed = kwargs.get("uncompressed", False) - labels = kwargs.get("labels", {}) - chunk_size = kwargs.get("chunk_size", None) - params = [key, value] - self._appendTimestamp(params, timestamp) - self._appendRetention(params, retention_msecs) - self._appendUncompressed(params, uncompressed) - self._appendChunkSize(params, chunk_size) - self._appendLabels(params, labels) - - return self.execute_command(INCRBY_CMD, *params) - - def decrby(self, key, value, **kwargs): - """ - Decrement (or create an time-series and decrement) the - latest sample's of a series. - This command can be used as a counter or gauge that - automatically gets history as a time series. - - Args: - - key: - time-series key - value: - Numeric data value of the sample - timestamp: - Timestamp of the sample. None can be used for automatic - timestamp (using the system clock). - retention_msecs: - Maximum age for samples compared to last event time (in milliseconds). - If None or 0 is passed then the series is not trimmed at all. - uncompressed: - Since RedisTimeSeries v1.2, both timestamps and values are - compressed by default. - Adding this flag will keep data in an uncompressed form. - Compression not only saves - memory but usually improve performance due to lower number - of memory accesses. - labels: - Set of label-value pairs that represent metadata labels of the key. - chunk_size: - Each time-series uses chunks of memory of fixed size for time series samples. - You can alter the default TSDB chunk size by passing the chunk_size argument (in Bytes). - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsincrbytsdecrby - """ # noqa - timestamp = kwargs.get("timestamp", None) - retention_msecs = kwargs.get("retention_msecs", None) - uncompressed = kwargs.get("uncompressed", False) - labels = kwargs.get("labels", {}) - chunk_size = kwargs.get("chunk_size", None) - params = [key, value] - self._appendTimestamp(params, timestamp) - self._appendRetention(params, retention_msecs) - self._appendUncompressed(params, uncompressed) - self._appendChunkSize(params, chunk_size) - self._appendLabels(params, labels) - - return self.execute_command(DECRBY_CMD, *params) - - def delete(self, key, from_time, to_time): - """ - Delete data points for a given timeseries and interval range - in the form of start and end delete timestamps. - The given timestamp interval is closed (inclusive), meaning start - and end data points will also be deleted. - Return the count for deleted items. - For more information see - - Args: - - key: - time-series key. - from_time: - Start timestamp for the range deletion. - to_time: - End timestamp for the range deletion. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsdel - """ # noqa - return self.execute_command(DEL_CMD, key, from_time, to_time) - - def createrule(self, source_key, dest_key, aggregation_type, bucket_size_msec): - """ - Create a compaction rule from values added to `source_key` into `dest_key`. - Aggregating for `bucket_size_msec` where an `aggregation_type` can be - [`avg`, `sum`, `min`, `max`, `range`, `count`, `first`, `last`, - `std.p`, `std.s`, `var.p`, `var.s`] - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tscreaterule - """ # noqa - params = [source_key, dest_key] - self._appendAggregation(params, aggregation_type, bucket_size_msec) - - return self.execute_command(CREATERULE_CMD, *params) - - def deleterule(self, source_key, dest_key): - """ - Delete a compaction rule. - For more information see - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsdeleterule - """ # noqa - return self.execute_command(DELETERULE_CMD, source_key, dest_key) - - def __range_params( - self, - key, - from_time, - to_time, - count, - aggregation_type, - bucket_size_msec, - filter_by_ts, - filter_by_min_value, - filter_by_max_value, - align, - ): - """Create TS.RANGE and TS.REVRANGE arguments.""" - params = [key, from_time, to_time] - self._appendFilerByTs(params, filter_by_ts) - self._appendFilerByValue(params, filter_by_min_value, filter_by_max_value) - self._appendCount(params, count) - self._appendAlign(params, align) - self._appendAggregation(params, aggregation_type, bucket_size_msec) - - return params - - def range( - self, - key, - from_time, - to_time, - count=None, - aggregation_type=None, - bucket_size_msec=0, - filter_by_ts=None, - filter_by_min_value=None, - filter_by_max_value=None, - align=None, - ): - """ - Query a range in forward direction for a specific time-serie. - - Args: - - key: - Key name for timeseries. - from_time: - Start timestamp for the range query. - can be used to express - the minimum possible timestamp (0). - to_time: - End timestamp for range query, + can be used to express the - maximum possible timestamp. - count: - Optional maximum number of returned results. - aggregation_type: - Optional aggregation type. Can be one of - [`avg`, `sum`, `min`, `max`, `range`, `count`, - `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`] - bucket_size_msec: - Time bucket for aggregation in milliseconds. - filter_by_ts: - List of timestamps to filter the result by specific timestamps. - filter_by_min_value: - Filter result by minimum value (must mention also filter - by_max_value). - filter_by_max_value: - Filter result by maximum value (must mention also filter - by_min_value). - align: - Timestamp for alignment control for aggregation. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsrangetsrevrange - """ # noqa - params = self.__range_params( - key, - from_time, - to_time, - count, - aggregation_type, - bucket_size_msec, - filter_by_ts, - filter_by_min_value, - filter_by_max_value, - align, - ) - return self.execute_command(RANGE_CMD, *params) - - def revrange( - self, - key, - from_time, - to_time, - count=None, - aggregation_type=None, - bucket_size_msec=0, - filter_by_ts=None, - filter_by_min_value=None, - filter_by_max_value=None, - align=None, - ): - """ - Query a range in reverse direction for a specific time-series. - - **Note**: This command is only available since RedisTimeSeries >= v1.4 - - Args: - - key: - Key name for timeseries. - from_time: - Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). - to_time: - End timestamp for range query, + can be used to express the maximum possible timestamp. - count: - Optional maximum number of returned results. - aggregation_type: - Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, `range`, `count`, - `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`] - bucket_size_msec: - Time bucket for aggregation in milliseconds. - filter_by_ts: - List of timestamps to filter the result by specific timestamps. - filter_by_min_value: - Filter result by minimum value (must mention also filter_by_max_value). - filter_by_max_value: - Filter result by maximum value (must mention also filter_by_min_value). - align: - Timestamp for alignment control for aggregation. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsrangetsrevrange - """ # noqa - params = self.__range_params( - key, - from_time, - to_time, - count, - aggregation_type, - bucket_size_msec, - filter_by_ts, - filter_by_min_value, - filter_by_max_value, - align, - ) - return self.execute_command(REVRANGE_CMD, *params) - - def __mrange_params( - self, - aggregation_type, - bucket_size_msec, - count, - filters, - from_time, - to_time, - with_labels, - filter_by_ts, - filter_by_min_value, - filter_by_max_value, - groupby, - reduce, - select_labels, - align, - ): - """Create TS.MRANGE and TS.MREVRANGE arguments.""" - params = [from_time, to_time] - self._appendFilerByTs(params, filter_by_ts) - self._appendFilerByValue(params, filter_by_min_value, filter_by_max_value) - self._appendCount(params, count) - self._appendAlign(params, align) - self._appendAggregation(params, aggregation_type, bucket_size_msec) - self._appendWithLabels(params, with_labels, select_labels) - params.extend(["FILTER"]) - params += filters - self._appendGroupbyReduce(params, groupby, reduce) - return params - - def mrange( - self, - from_time, - to_time, - filters, - count=None, - aggregation_type=None, - bucket_size_msec=0, - with_labels=False, - filter_by_ts=None, - filter_by_min_value=None, - filter_by_max_value=None, - groupby=None, - reduce=None, - select_labels=None, - align=None, - ): - """ - Query a range across multiple time-series by filters in forward direction. - - Args: - - from_time: - Start timestamp for the range query. `-` can be used to - express the minimum possible timestamp (0). - to_time: - End timestamp for range query, `+` can be used to express - the maximum possible timestamp. - filters: - filter to match the time-series labels. - count: - Optional maximum number of returned results. - aggregation_type: - Optional aggregation type. Can be one of - [`avg`, `sum`, `min`, `max`, `range`, `count`, - `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`] - bucket_size_msec: - Time bucket for aggregation in milliseconds. - with_labels: - Include in the reply the label-value pairs that represent metadata - labels of the time-series. - If this argument is not set, by default, an empty Array will be - replied on the labels array position. - filter_by_ts: - List of timestamps to filter the result by specific timestamps. - filter_by_min_value: - Filter result by minimum value (must mention also - filter_by_max_value). - filter_by_max_value: - Filter result by maximum value (must mention also - filter_by_min_value). - groupby: - Grouping by fields the results (must mention also reduce). - reduce: - Applying reducer functions on each group. Can be one - of [`sum`, `min`, `max`]. - select_labels: - Include in the reply only a subset of the key-value - pair labels of a series. - align: - Timestamp for alignment control for aggregation. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmrangetsmrevrange - """ # noqa - params = self.__mrange_params( - aggregation_type, - bucket_size_msec, - count, - filters, - from_time, - to_time, - with_labels, - filter_by_ts, - filter_by_min_value, - filter_by_max_value, - groupby, - reduce, - select_labels, - align, - ) - - return self.execute_command(MRANGE_CMD, *params) - - def mrevrange( - self, - from_time, - to_time, - filters, - count=None, - aggregation_type=None, - bucket_size_msec=0, - with_labels=False, - filter_by_ts=None, - filter_by_min_value=None, - filter_by_max_value=None, - groupby=None, - reduce=None, - select_labels=None, - align=None, - ): - """ - Query a range across multiple time-series by filters in reverse direction. - - Args: - - from_time: - Start timestamp for the range query. - can be used to express - the minimum possible timestamp (0). - to_time: - End timestamp for range query, + can be used to express - the maximum possible timestamp. - filters: - Filter to match the time-series labels. - count: - Optional maximum number of returned results. - aggregation_type: - Optional aggregation type. Can be one of - [`avg`, `sum`, `min`, `max`, `range`, `count`, - `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`] - bucket_size_msec: - Time bucket for aggregation in milliseconds. - with_labels: - Include in the reply the label-value pairs that represent - metadata labels - of the time-series. - If this argument is not set, by default, an empty Array - will be replied - on the labels array position. - filter_by_ts: - List of timestamps to filter the result by specific timestamps. - filter_by_min_value: - Filter result by minimum value (must mention also filter - by_max_value). - filter_by_max_value: - Filter result by maximum value (must mention also filter - by_min_value). - groupby: - Grouping by fields the results (must mention also reduce). - reduce: - Applying reducer functions on each group. Can be one - of [`sum`, `min`, `max`]. - select_labels: - Include in the reply only a subset of the key-value pair - labels of a series. - align: - Timestamp for alignment control for aggregation. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmrangetsmrevrange - """ # noqa - params = self.__mrange_params( - aggregation_type, - bucket_size_msec, - count, - filters, - from_time, - to_time, - with_labels, - filter_by_ts, - filter_by_min_value, - filter_by_max_value, - groupby, - reduce, - select_labels, - align, - ) - - return self.execute_command(MREVRANGE_CMD, *params) - - def get(self, key): - """# noqa - Get the last sample of `key`. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsget - """ # noqa - return self.execute_command(GET_CMD, key) - - def mget(self, filters, with_labels=False): - """# noqa - Get the last samples matching the specific `filter`. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsmget - """ # noqa - params = [] - self._appendWithLabels(params, with_labels) - params.extend(["FILTER"]) - params += filters - return self.execute_command(MGET_CMD, *params) - - def info(self, key): - """# noqa - Get information of `key`. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsinfo - """ # noqa - return self.execute_command(INFO_CMD, key) - - def queryindex(self, filters): - """# noqa - Get all the keys matching the `filter` list. - - For more information: https://oss.redis.com/redistimeseries/master/commands/#tsqueryindex - """ # noq - return self.execute_command(QUERYINDEX_CMD, *filters) - - @staticmethod - def _appendUncompressed(params, uncompressed): - """Append UNCOMPRESSED tag to params.""" - if uncompressed: - params.extend(["UNCOMPRESSED"]) - - @staticmethod - def _appendWithLabels(params, with_labels, select_labels=None): - """Append labels behavior to params.""" - if with_labels and select_labels: - raise DataError( - "with_labels and select_labels cannot be provided together." - ) - - if with_labels: - params.extend(["WITHLABELS"]) - if select_labels: - params.extend(["SELECTED_LABELS", *select_labels]) - - @staticmethod - def _appendGroupbyReduce(params, groupby, reduce): - """Append GROUPBY REDUCE property to params.""" - if groupby is not None and reduce is not None: - params.extend(["GROUPBY", groupby, "REDUCE", reduce.upper()]) - - @staticmethod - def _appendRetention(params, retention): - """Append RETENTION property to params.""" - if retention is not None: - params.extend(["RETENTION", retention]) - - @staticmethod - def _appendLabels(params, labels): - """Append LABELS property to params.""" - if labels: - params.append("LABELS") - for k, v in labels.items(): - params.extend([k, v]) - - @staticmethod - def _appendCount(params, count): - """Append COUNT property to params.""" - if count is not None: - params.extend(["COUNT", count]) - - @staticmethod - def _appendTimestamp(params, timestamp): - """Append TIMESTAMP property to params.""" - if timestamp is not None: - params.extend(["TIMESTAMP", timestamp]) - - @staticmethod - def _appendAlign(params, align): - """Append ALIGN property to params.""" - if align is not None: - params.extend(["ALIGN", align]) - - @staticmethod - def _appendAggregation(params, aggregation_type, bucket_size_msec): - """Append AGGREGATION property to params.""" - if aggregation_type is not None: - params.append("AGGREGATION") - params.extend([aggregation_type, bucket_size_msec]) - - @staticmethod - def _appendChunkSize(params, chunk_size): - """Append CHUNK_SIZE property to params.""" - if chunk_size is not None: - params.extend(["CHUNK_SIZE", chunk_size]) - - @staticmethod - def _appendDuplicatePolicy(params, command, duplicate_policy): - """Append DUPLICATE_POLICY property to params on CREATE - and ON_DUPLICATE on ADD. - """ - if duplicate_policy is not None: - if command == "TS.ADD": - params.extend(["ON_DUPLICATE", duplicate_policy]) - else: - params.extend(["DUPLICATE_POLICY", duplicate_policy]) - - @staticmethod - def _appendFilerByTs(params, ts_list): - """Append FILTER_BY_TS property to params.""" - if ts_list is not None: - params.extend(["FILTER_BY_TS", *ts_list]) - - @staticmethod - def _appendFilerByValue(params, min_value, max_value): - """Append FILTER_BY_VALUE property to params.""" - if min_value is not None and max_value is not None: - params.extend(["FILTER_BY_VALUE", min_value, max_value]) diff --git a/redis/commands/timeseries/info.py b/redis/commands/timeseries/info.py deleted file mode 100644 index fba7f09..0000000 --- a/redis/commands/timeseries/info.py +++ /dev/null @@ -1,82 +0,0 @@ -from ..helpers import nativestr -from .utils import list_to_dict - - -class TSInfo: - """ - Hold information and statistics on the time-series. - Can be created using ``tsinfo`` command - https://oss.redis.com/redistimeseries/commands/#tsinfo. - """ - - rules = [] - labels = [] - sourceKey = None - chunk_count = None - memory_usage = None - total_samples = None - retention_msecs = None - last_time_stamp = None - first_time_stamp = None - - max_samples_per_chunk = None - chunk_size = None - duplicate_policy = None - - def __init__(self, args): - """ - Hold information and statistics on the time-series. - - The supported params that can be passed as args: - - rules: - A list of compaction rules of the time series. - sourceKey: - Key name for source time series in case the current series - is a target of a rule. - chunkCount: - Number of Memory Chunks used for the time series. - memoryUsage: - Total number of bytes allocated for the time series. - totalSamples: - Total number of samples in the time series. - labels: - A list of label-value pairs that represent the metadata - labels of the time series. - retentionTime: - Retention time, in milliseconds, for the time series. - lastTimestamp: - Last timestamp present in the time series. - firstTimestamp: - First timestamp present in the time series. - maxSamplesPerChunk: - Deprecated. - chunkSize: - Amount of memory, in bytes, allocated for data. - duplicatePolicy: - Policy that will define handling of duplicate samples. - - Can read more about on - https://oss.redis.com/redistimeseries/configuration/#duplicate_policy - """ - response = dict(zip(map(nativestr, args[::2]), args[1::2])) - self.rules = response["rules"] - self.source_key = response["sourceKey"] - self.chunk_count = response["chunkCount"] - self.memory_usage = response["memoryUsage"] - self.total_samples = response["totalSamples"] - self.labels = list_to_dict(response["labels"]) - self.retention_msecs = response["retentionTime"] - self.lastTimeStamp = response["lastTimestamp"] - self.first_time_stamp = response["firstTimestamp"] - if "maxSamplesPerChunk" in response: - self.max_samples_per_chunk = response["maxSamplesPerChunk"] - self.chunk_size = ( - self.max_samples_per_chunk * 16 - ) # backward compatible changes - if "chunkSize" in response: - self.chunk_size = response["chunkSize"] - if "duplicatePolicy" in response: - self.duplicate_policy = response["duplicatePolicy"] - if type(self.duplicate_policy) == bytes: - self.duplicate_policy = self.duplicate_policy.decode() diff --git a/redis/commands/timeseries/utils.py b/redis/commands/timeseries/utils.py deleted file mode 100644 index c49b040..0000000 --- a/redis/commands/timeseries/utils.py +++ /dev/null @@ -1,44 +0,0 @@ -from ..helpers import nativestr - - -def list_to_dict(aList): - return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))} - - -def parse_range(response): - """Parse range response. Used by TS.RANGE and TS.REVRANGE.""" - return [tuple((r[0], float(r[1]))) for r in response] - - -def parse_m_range(response): - """Parse multi range response. Used by TS.MRANGE and TS.MREVRANGE.""" - res = [] - for item in response: - res.append({nativestr(item[0]): [list_to_dict(item[1]), parse_range(item[2])]}) - return sorted(res, key=lambda d: list(d.keys())) - - -def parse_get(response): - """Parse get response. Used by TS.GET.""" - if not response: - return None - return int(response[0]), float(response[1]) - - -def parse_m_get(response): - """Parse multi get response. Used by TS.MGET.""" - res = [] - for item in response: - if not item[2]: - res.append({nativestr(item[0]): [list_to_dict(item[1]), None, None]}) - else: - res.append( - { - nativestr(item[0]): [ - list_to_dict(item[1]), - int(item[2][0]), - float(item[2][1]), - ] - } - ) - return sorted(res, key=lambda d: list(d.keys())) diff --git a/redis/sentinel.py b/redis/sentinel.py deleted file mode 100644 index c9383d3..0000000 --- a/redis/sentinel.py +++ /dev/null @@ -1,337 +0,0 @@ -import random -import weakref - -from redis.client import Redis -from redis.commands import SentinelCommands -from redis.connection import Connection, ConnectionPool, SSLConnection -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes - - -class MasterNotFoundError(ConnectionError): - pass - - -class SlaveNotFoundError(ConnectionError): - pass - - -class SentinelManagedConnection(Connection): - def __init__(self, **kwargs): - self.connection_pool = kwargs.pop("connection_pool") - super().__init__(**kwargs) - - def __repr__(self): - pool = self.connection_pool - s = f"{type(self).__name__}" - if self.host: - host_info = f",host={self.host},port={self.port}" - s = s % host_info - return s - - def connect_to(self, address): - self.host, self.port = address - super().connect() - if self.connection_pool.check_connection: - self.send_command("PING") - if str_if_bytes(self.read_response()) != "PONG": - raise ConnectionError("PING failed") - - def connect(self): - if self._sock: - return # already connected - if self.connection_pool.is_master: - self.connect_to(self.connection_pool.get_master_address()) - else: - for slave in self.connection_pool.rotate_slaves(): - try: - return self.connect_to(slave) - except ConnectionError: - continue - raise SlaveNotFoundError # Never be here - - def read_response(self, disable_decoding=False): - try: - return super().read_response(disable_decoding=disable_decoding) - except ReadOnlyError: - if self.connection_pool.is_master: - # When talking to a master, a ReadOnlyError when likely - # indicates that the previous master that we're still connected - # to has been demoted to a slave and there's a new master. - # calling disconnect will force the connection to re-query - # sentinel during the next connect() attempt. - self.disconnect() - raise ConnectionError("The previous master is now a slave") - raise - - -class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): - pass - - -class SentinelConnectionPool(ConnectionPool): - """ - Sentinel backed connection pool. - - If ``check_connection`` flag is set to True, SentinelManagedConnection - sends a PING command right after establishing the connection. - """ - - def __init__(self, service_name, sentinel_manager, **kwargs): - kwargs["connection_class"] = kwargs.get( - "connection_class", - SentinelManagedSSLConnection - if kwargs.pop("ssl", False) - else SentinelManagedConnection, - ) - self.is_master = kwargs.pop("is_master", True) - self.check_connection = kwargs.pop("check_connection", False) - super().__init__(**kwargs) - self.connection_kwargs["connection_pool"] = weakref.proxy(self) - self.service_name = service_name - self.sentinel_manager = sentinel_manager - - def __repr__(self): - role = "master" if self.is_master else "slave" - return f"{type(self).__name__}>> from redis.sentinel import Sentinel - >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) - >>> master = sentinel.master_for('mymaster', socket_timeout=0.1) - >>> master.set('foo', 'bar') - >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) - >>> slave.get('foo') - b'bar' - - ``sentinels`` is a list of sentinel nodes. Each node is represented by - a pair (hostname, port). - - ``min_other_sentinels`` defined a minimum number of peers for a sentinel. - When querying a sentinel, if it doesn't meet this threshold, responses - from that sentinel won't be considered valid. - - ``sentinel_kwargs`` is a dictionary of connection arguments used when - connecting to sentinel instances. Any argument that can be passed to - a normal Redis connection can be specified here. If ``sentinel_kwargs`` is - not specified, any socket_timeout and socket_keepalive options specified - in ``connection_kwargs`` will be used. - - ``connection_kwargs`` are keyword arguments that will be used when - establishing a connection to a Redis server. - """ - - def __init__( - self, - sentinels, - min_other_sentinels=0, - sentinel_kwargs=None, - **connection_kwargs, - ): - # if sentinel_kwargs isn't defined, use the socket_* options from - # connection_kwargs - if sentinel_kwargs is None: - sentinel_kwargs = { - k: v for k, v in connection_kwargs.items() if k.startswith("socket_") - } - self.sentinel_kwargs = sentinel_kwargs - - self.sentinels = [ - Redis(hostname, port, **self.sentinel_kwargs) - for hostname, port in sentinels - ] - self.min_other_sentinels = min_other_sentinels - self.connection_kwargs = connection_kwargs - - def execute_command(self, *args, **kwargs): - """ - Execute Sentinel command in sentinel nodes. - once - If set to True, then execute the resulting command on a single - node at random, rather than across the entire sentinel cluster. - """ - once = bool(kwargs.get("once", False)) - if "once" in kwargs.keys(): - kwargs.pop("once") - - if once: - for sentinel in self.sentinels: - sentinel.execute_command(*args, **kwargs) - else: - random.choice(self.sentinels).execute_command(*args, **kwargs) - return True - - def __repr__(self): - sentinel_addresses = [] - for sentinel in self.sentinels: - sentinel_addresses.append( - "{host}:{port}".format_map( - sentinel.connection_pool.connection_kwargs, - ) - ) - return f'{type(self).__name__}' - - def check_master_state(self, state, service_name): - if not state["is_master"] or state["is_sdown"] or state["is_odown"]: - return False - # Check if our sentinel doesn't see other nodes - if state["num-other-sentinels"] < self.min_other_sentinels: - return False - return True - - def discover_master(self, service_name): - """ - Asks sentinel servers for the Redis master's address corresponding - to the service labeled ``service_name``. - - Returns a pair (address, port) or raises MasterNotFoundError if no - master is found. - """ - for sentinel_no, sentinel in enumerate(self.sentinels): - try: - masters = sentinel.sentinel_masters() - except (ConnectionError, TimeoutError): - continue - state = masters.get(service_name) - if state and self.check_master_state(state, service_name): - # Put this sentinel at the top of the list - self.sentinels[0], self.sentinels[sentinel_no] = ( - sentinel, - self.sentinels[0], - ) - return state["ip"], state["port"] - raise MasterNotFoundError(f"No master found for {service_name!r}") - - def filter_slaves(self, slaves): - "Remove slaves that are in an ODOWN or SDOWN state" - slaves_alive = [] - for slave in slaves: - if slave["is_odown"] or slave["is_sdown"]: - continue - slaves_alive.append((slave["ip"], slave["port"])) - return slaves_alive - - def discover_slaves(self, service_name): - "Returns a list of alive slaves for service ``service_name``" - for sentinel in self.sentinels: - try: - slaves = sentinel.sentinel_slaves(service_name) - except (ConnectionError, ResponseError, TimeoutError): - continue - slaves = self.filter_slaves(slaves) - if slaves: - return slaves - return [] - - def master_for( - self, - service_name, - redis_class=Redis, - connection_pool_class=SentinelConnectionPool, - **kwargs, - ): - """ - Returns a redis client instance for the ``service_name`` master. - - A :py:class:`~redis.sentinel.SentinelConnectionPool` class is - used to retrieve the master's address before establishing a new - connection. - - NOTE: If the master's address has changed, any cached connections to - the old master are closed. - - By default clients will be a :py:class:`~redis.Redis` instance. - Specify a different class to the ``redis_class`` argument if you - desire something different. - - The ``connection_pool_class`` specifies the connection pool to - use. The :py:class:`~redis.sentinel.SentinelConnectionPool` - will be used by default. - - All other keyword arguments are merged with any connection_kwargs - passed to this class and passed to the connection pool as keyword - arguments to be used to initialize Redis connections. - """ - kwargs["is_master"] = True - connection_kwargs = dict(self.connection_kwargs) - connection_kwargs.update(kwargs) - return redis_class( - connection_pool=connection_pool_class( - service_name, self, **connection_kwargs - ) - ) - - def slave_for( - self, - service_name, - redis_class=Redis, - connection_pool_class=SentinelConnectionPool, - **kwargs, - ): - """ - Returns redis client instance for the ``service_name`` slave(s). - - A SentinelConnectionPool class is used to retrieve the slave's - address before establishing a new connection. - - By default clients will be a :py:class:`~redis.Redis` instance. - Specify a different class to the ``redis_class`` argument if you - desire something different. - - The ``connection_pool_class`` specifies the connection pool to use. - The SentinelConnectionPool will be used by default. - - All other keyword arguments are merged with any connection_kwargs - passed to this class and passed to the connection pool as keyword - arguments to be used to initialize Redis connections. - """ - kwargs["is_master"] = False - connection_kwargs = dict(self.connection_kwargs) - connection_kwargs.update(kwargs) - return redis_class( - connection_pool=connection_pool_class( - service_name, self, **connection_kwargs - ) - ) diff --git a/setup.py b/setup.py index 7733220..d980041 100644 --- a/setup.py +++ b/setup.py @@ -13,11 +13,6 @@ setup( include=[ "redis", "redis.commands", - "redis.commands.bf", - "redis.commands.json", - "redis.commands.search", - "redis.commands.timeseries", - "redis.commands.graph", ] ), url="https://github.com/redis/redis-py", diff --git a/tests/test_bloom.py b/tests/test_bloom.py deleted file mode 100644 index 8936584..0000000 --- a/tests/test_bloom.py +++ /dev/null @@ -1,383 +0,0 @@ -import pytest - -import redis.commands.bf -from redis.exceptions import ModuleError, RedisError -from redis.utils import HIREDIS_AVAILABLE - - -def intlist(obj): - return [int(v) for v in obj] - - -@pytest.fixture -def client(modclient): - assert isinstance(modclient.bf(), redis.commands.bf.BFBloom) - assert isinstance(modclient.cf(), redis.commands.bf.CFBloom) - assert isinstance(modclient.cms(), redis.commands.bf.CMSBloom) - assert isinstance(modclient.tdigest(), redis.commands.bf.TDigestBloom) - assert isinstance(modclient.topk(), redis.commands.bf.TOPKBloom) - - modclient.flushdb() - return modclient - - -@pytest.mark.redismod -def test_create(client): - """Test CREATE/RESERVE calls""" - assert client.bf().create("bloom", 0.01, 1000) - assert client.bf().create("bloom_e", 0.01, 1000, expansion=1) - assert client.bf().create("bloom_ns", 0.01, 1000, noScale=True) - assert client.cf().create("cuckoo", 1000) - assert client.cf().create("cuckoo_e", 1000, expansion=1) - assert client.cf().create("cuckoo_bs", 1000, bucket_size=4) - assert client.cf().create("cuckoo_mi", 1000, max_iterations=10) - assert client.cms().initbydim("cmsDim", 100, 5) - assert client.cms().initbyprob("cmsProb", 0.01, 0.01) - assert client.topk().reserve("topk", 5, 100, 5, 0.9) - assert client.tdigest().create("tDigest", 100) - - -# region Test Bloom Filter -@pytest.mark.redismod -def test_bf_add(client): - assert client.bf().create("bloom", 0.01, 1000) - assert 1 == client.bf().add("bloom", "foo") - assert 0 == client.bf().add("bloom", "foo") - assert [0] == intlist(client.bf().madd("bloom", "foo")) - assert [0, 1] == client.bf().madd("bloom", "foo", "bar") - assert [0, 0, 1] == client.bf().madd("bloom", "foo", "bar", "baz") - assert 1 == client.bf().exists("bloom", "foo") - assert 0 == client.bf().exists("bloom", "noexist") - assert [1, 0] == intlist(client.bf().mexists("bloom", "foo", "noexist")) - - -@pytest.mark.redismod -def test_bf_insert(client): - assert client.bf().create("bloom", 0.01, 1000) - assert [1] == intlist(client.bf().insert("bloom", ["foo"])) - assert [0, 1] == intlist(client.bf().insert("bloom", ["foo", "bar"])) - assert [1] == intlist(client.bf().insert("captest", ["foo"], capacity=10)) - assert [1] == intlist(client.bf().insert("errtest", ["foo"], error=0.01)) - assert 1 == client.bf().exists("bloom", "foo") - assert 0 == client.bf().exists("bloom", "noexist") - assert [1, 0] == intlist(client.bf().mexists("bloom", "foo", "noexist")) - info = client.bf().info("bloom") - assert 2 == info.insertedNum - assert 1000 == info.capacity - assert 1 == info.filterNum - - -@pytest.mark.redismod -def test_bf_scandump_and_loadchunk(client): - # Store a filter - client.bf().create("myBloom", "0.0001", "1000") - - # test is probabilistic and might fail. It is OK to change variables if - # certain to not break anything - def do_verify(): - res = 0 - for x in range(1000): - client.bf().add("myBloom", x) - rv = client.bf().exists("myBloom", x) - assert rv - rv = client.bf().exists("myBloom", f"nonexist_{x}") - res += rv == x - assert res < 5 - - do_verify() - cmds = [] - if HIREDIS_AVAILABLE: - with pytest.raises(ModuleError): - cur = client.bf().scandump("myBloom", 0) - return - - cur = client.bf().scandump("myBloom", 0) - first = cur[0] - cmds.append(cur) - - while True: - cur = client.bf().scandump("myBloom", first) - first = cur[0] - if first == 0: - break - else: - cmds.append(cur) - prev_info = client.bf().execute_command("bf.debug", "myBloom") - - # Remove the filter - client.bf().client.delete("myBloom") - - # Now, load all the commands: - for cmd in cmds: - client.bf().loadchunk("myBloom", *cmd) - - cur_info = client.bf().execute_command("bf.debug", "myBloom") - assert prev_info == cur_info - do_verify() - - client.bf().client.delete("myBloom") - client.bf().create("myBloom", "0.0001", "10000000") - - -@pytest.mark.redismod -def test_bf_info(client): - expansion = 4 - # Store a filter - client.bf().create("nonscaling", "0.0001", "1000", noScale=True) - info = client.bf().info("nonscaling") - assert info.expansionRate is None - - client.bf().create("expanding", "0.0001", "1000", expansion=expansion) - info = client.bf().info("expanding") - assert info.expansionRate == 4 - - try: - # noScale mean no expansion - client.bf().create( - "myBloom", "0.0001", "1000", expansion=expansion, noScale=True - ) - assert False - except RedisError: - assert True - - -# region Test Cuckoo Filter -@pytest.mark.redismod -def test_cf_add_and_insert(client): - assert client.cf().create("cuckoo", 1000) - assert client.cf().add("cuckoo", "filter") - assert not client.cf().addnx("cuckoo", "filter") - assert 1 == client.cf().addnx("cuckoo", "newItem") - assert [1] == client.cf().insert("captest", ["foo"]) - assert [1] == client.cf().insert("captest", ["foo"], capacity=1000) - assert [1] == client.cf().insertnx("captest", ["bar"]) - assert [1] == client.cf().insertnx("captest", ["food"], nocreate="1") - assert [0, 0, 1] == client.cf().insertnx("captest", ["foo", "bar", "baz"]) - assert [0] == client.cf().insertnx("captest", ["bar"], capacity=1000) - assert [1] == client.cf().insert("empty1", ["foo"], capacity=1000) - assert [1] == client.cf().insertnx("empty2", ["bar"], capacity=1000) - info = client.cf().info("captest") - assert 5 == info.insertedNum - assert 0 == info.deletedNum - assert 1 == info.filterNum - - -@pytest.mark.redismod -def test_cf_exists_and_del(client): - assert client.cf().create("cuckoo", 1000) - assert client.cf().add("cuckoo", "filter") - assert client.cf().exists("cuckoo", "filter") - assert not client.cf().exists("cuckoo", "notexist") - assert 1 == client.cf().count("cuckoo", "filter") - assert 0 == client.cf().count("cuckoo", "notexist") - assert client.cf().delete("cuckoo", "filter") - assert 0 == client.cf().count("cuckoo", "filter") - - -# region Test Count-Min Sketch -@pytest.mark.redismod -def test_cms(client): - assert client.cms().initbydim("dim", 1000, 5) - assert client.cms().initbyprob("prob", 0.01, 0.01) - assert client.cms().incrby("dim", ["foo"], [5]) - assert [0] == client.cms().query("dim", "notexist") - assert [5] == client.cms().query("dim", "foo") - assert [10, 15] == client.cms().incrby("dim", ["foo", "bar"], [5, 15]) - assert [10, 15] == client.cms().query("dim", "foo", "bar") - info = client.cms().info("dim") - assert 1000 == info.width - assert 5 == info.depth - assert 25 == info.count - - -@pytest.mark.redismod -def test_cms_merge(client): - assert client.cms().initbydim("A", 1000, 5) - assert client.cms().initbydim("B", 1000, 5) - assert client.cms().initbydim("C", 1000, 5) - assert client.cms().incrby("A", ["foo", "bar", "baz"], [5, 3, 9]) - assert client.cms().incrby("B", ["foo", "bar", "baz"], [2, 3, 1]) - assert [5, 3, 9] == client.cms().query("A", "foo", "bar", "baz") - assert [2, 3, 1] == client.cms().query("B", "foo", "bar", "baz") - assert client.cms().merge("C", 2, ["A", "B"]) - assert [7, 6, 10] == client.cms().query("C", "foo", "bar", "baz") - assert client.cms().merge("C", 2, ["A", "B"], ["1", "2"]) - assert [9, 9, 11] == client.cms().query("C", "foo", "bar", "baz") - assert client.cms().merge("C", 2, ["A", "B"], ["2", "3"]) - assert [16, 15, 21] == client.cms().query("C", "foo", "bar", "baz") - - -# endregion - - -# region Test Top-K -@pytest.mark.redismod -def test_topk(client): - # test list with empty buckets - assert client.topk().reserve("topk", 3, 50, 4, 0.9) - assert [ - None, - None, - None, - "A", - "C", - "D", - None, - None, - "E", - None, - "B", - "C", - None, - None, - None, - "D", - None, - ] == client.topk().add( - "topk", - "A", - "B", - "C", - "D", - "E", - "A", - "A", - "B", - "C", - "G", - "D", - "B", - "D", - "A", - "E", - "E", - 1, - ) - assert [1, 1, 0, 0, 1, 0, 0] == client.topk().query( - "topk", "A", "B", "C", "D", "E", "F", "G" - ) - assert [4, 3, 2, 3, 3, 0, 1] == client.topk().count( - "topk", "A", "B", "C", "D", "E", "F", "G" - ) - - # test full list - assert client.topk().reserve("topklist", 3, 50, 3, 0.9) - assert client.topk().add( - "topklist", - "A", - "B", - "C", - "D", - "E", - "A", - "A", - "B", - "C", - "G", - "D", - "B", - "D", - "A", - "E", - "E", - ) - assert ["A", "B", "E"] == client.topk().list("topklist") - assert ["A", 4, "B", 3, "E", 3] == client.topk().list("topklist", withcount=True) - info = client.topk().info("topklist") - assert 3 == info.k - assert 50 == info.width - assert 3 == info.depth - assert 0.9 == round(float(info.decay), 1) - - -@pytest.mark.redismod -def test_topk_incrby(client): - client.flushdb() - assert client.topk().reserve("topk", 3, 10, 3, 1) - assert [None, None, None] == client.topk().incrby( - "topk", ["bar", "baz", "42"], [3, 6, 2] - ) - assert [None, "bar"] == client.topk().incrby("topk", ["42", "xyzzy"], [8, 4]) - assert [3, 6, 10, 4, 0] == client.topk().count( - "topk", "bar", "baz", "42", "xyzzy", 4 - ) - - -# region Test T-Digest -@pytest.mark.redismod -def test_tdigest_reset(client): - assert client.tdigest().create("tDigest", 10) - # reset on empty histogram - assert client.tdigest().reset("tDigest") - # insert data-points into sketch - assert client.tdigest().add("tDigest", list(range(10)), [1.0] * 10) - - assert client.tdigest().reset("tDigest") - # assert we have 0 unmerged nodes - assert 0 == client.tdigest().info("tDigest").unmergedNodes - - -@pytest.mark.redismod -def test_tdigest_merge(client): - assert client.tdigest().create("to-tDigest", 10) - assert client.tdigest().create("from-tDigest", 10) - # insert data-points into sketch - assert client.tdigest().add("from-tDigest", [1.0] * 10, [1.0] * 10) - assert client.tdigest().add("to-tDigest", [2.0] * 10, [10.0] * 10) - # merge from-tdigest into to-tdigest - assert client.tdigest().merge("to-tDigest", "from-tDigest") - # we should now have 110 weight on to-histogram - info = client.tdigest().info("to-tDigest") - total_weight_to = float(info.mergedWeight) + float(info.unmergedWeight) - assert 110 == total_weight_to - - -@pytest.mark.redismod -def test_tdigest_min_and_max(client): - assert client.tdigest().create("tDigest", 100) - # insert data-points into sketch - assert client.tdigest().add("tDigest", [1, 2, 3], [1.0] * 3) - # min/max - assert 3 == client.tdigest().max("tDigest") - assert 1 == client.tdigest().min("tDigest") - - -@pytest.mark.redismod -def test_tdigest_quantile(client): - assert client.tdigest().create("tDigest", 500) - # insert data-points into sketch - assert client.tdigest().add( - "tDigest", list([x * 0.01 for x in range(1, 10000)]), [1.0] * 10000 - ) - # assert min min/max have same result as quantile 0 and 1 - assert client.tdigest().max("tDigest") == client.tdigest().quantile("tDigest", 1.0) - assert client.tdigest().min("tDigest") == client.tdigest().quantile("tDigest", 0.0) - - assert 1.0 == round(client.tdigest().quantile("tDigest", 0.01), 2) - assert 99.0 == round(client.tdigest().quantile("tDigest", 0.99), 2) - - -@pytest.mark.redismod -def test_tdigest_cdf(client): - assert client.tdigest().create("tDigest", 100) - # insert data-points into sketch - assert client.tdigest().add("tDigest", list(range(1, 10)), [1.0] * 10) - assert 0.1 == round(client.tdigest().cdf("tDigest", 1.0), 1) - assert 0.9 == round(client.tdigest().cdf("tDigest", 9.0), 1) - - -# @pytest.mark.redismod -# def test_pipeline(client): -# pipeline = client.bf().pipeline() -# assert not client.bf().execute_command("get pipeline") -# -# assert client.bf().create("pipeline", 0.01, 1000) -# for i in range(100): -# pipeline.add("pipeline", i) -# for i in range(100): -# assert not (client.bf().exists("pipeline", i)) -# -# pipeline.execute() -# -# for i in range(100): -# assert client.bf().exists("pipeline", i) diff --git a/tests/test_cluster.py b/tests/test_cluster.py deleted file mode 100644 index 496ed98..0000000 --- a/tests/test_cluster.py +++ /dev/null @@ -1,2664 +0,0 @@ -import binascii -import datetime -import warnings -from time import sleep -from unittest.mock import DEFAULT, Mock, call, patch - -import pytest - -from redis import Redis -from redis.cluster import ( - PRIMARY, - REDIS_CLUSTER_HASH_SLOTS, - REPLICA, - ClusterNode, - NodesManager, - RedisCluster, - get_node_name, -) -from redis.commands import CommandsParser -from redis.connection import Connection -from redis.crc import key_slot -from redis.exceptions import ( - AskError, - ClusterDownError, - ConnectionError, - DataError, - MovedError, - NoPermissionError, - RedisClusterException, - RedisError, -) -from redis.utils import str_if_bytes -from tests.test_pubsub import wait_for_message - -from .conftest import ( - _get_client, - skip_if_redis_enterprise, - skip_if_server_version_lt, - skip_unless_arch_bits, - wait_for_command, -) - -default_host = "127.0.0.1" -default_port = 7000 -default_cluster_slots = [ - [ - 0, - 8191, - ["127.0.0.1", 7000, "node_0"], - ["127.0.0.1", 7003, "node_3"], - ], - [8192, 16383, ["127.0.0.1", 7001, "node_1"], ["127.0.0.1", 7002, "node_2"]], -] - - -@pytest.fixture() -def slowlog(request, r): - """ - Set the slowlog threshold to 0, and the - max length to 128. This will force every - command into the slowlog and allow us - to test it - """ - # Save old values - current_config = r.config_get(target_nodes=r.get_primaries()[0]) - old_slower_than_value = current_config["slowlog-log-slower-than"] - old_max_legnth_value = current_config["slowlog-max-len"] - - # Function to restore the old values - def cleanup(): - r.config_set("slowlog-log-slower-than", old_slower_than_value) - r.config_set("slowlog-max-len", old_max_legnth_value) - - request.addfinalizer(cleanup) - - # Set the new values - r.config_set("slowlog-log-slower-than", 0) - r.config_set("slowlog-max-len", 128) - - -def get_mocked_redis_client(func=None, *args, **kwargs): - """ - Return a stable RedisCluster object that have deterministic - nodes and slots setup to remove the problem of different IP addresses - on different installations and machines. - """ - cluster_slots = kwargs.pop("cluster_slots", default_cluster_slots) - coverage_res = kwargs.pop("coverage_result", "yes") - cluster_enabled = kwargs.pop("cluster_enabled", True) - with patch.object(Redis, "execute_command") as execute_command_mock: - - def execute_command(*_args, **_kwargs): - if _args[0] == "CLUSTER SLOTS": - mock_cluster_slots = cluster_slots - return mock_cluster_slots - elif _args[0] == "COMMAND": - return {"get": [], "set": []} - elif _args[0] == "INFO": - return {"cluster_enabled": cluster_enabled} - elif len(_args) > 1 and _args[1] == "cluster-require-full-coverage": - return {"cluster-require-full-coverage": coverage_res} - elif func is not None: - return func(*args, **kwargs) - else: - return execute_command_mock(*_args, **_kwargs) - - execute_command_mock.side_effect = execute_command - - with patch.object( - CommandsParser, "initialize", autospec=True - ) as cmd_parser_initialize: - - def cmd_init_mock(self, r): - self.commands = { - "get": { - "name": "get", - "arity": 2, - "flags": ["readonly", "fast"], - "first_key_pos": 1, - "last_key_pos": 1, - "step_count": 1, - } - } - - cmd_parser_initialize.side_effect = cmd_init_mock - - return RedisCluster(*args, **kwargs) - - -def mock_node_resp(node, response): - connection = Mock() - connection.read_response.return_value = response - node.redis_connection.connection = connection - return node - - -def mock_node_resp_func(node, func): - connection = Mock() - connection.read_response.side_effect = func - node.redis_connection.connection = connection - return node - - -def mock_all_nodes_resp(rc, response): - for node in rc.get_nodes(): - mock_node_resp(node, response) - return rc - - -def find_node_ip_based_on_port(cluster_client, port): - for node in cluster_client.get_nodes(): - if node.port == port: - return node.host - - -def moved_redirection_helper(request, failover=False): - """ - Test that the client handles MOVED response after a failover. - Redirection after a failover means that the redirection address is of a - replica that was promoted to a primary. - - At first call it should return a MOVED ResponseError that will point - the client to the next server it should talk to. - - Verify that: - 1. it tries to talk to the redirected node - 2. it updates the slot's primary to the redirected node - - For a failover, also verify: - 3. the redirected node's server type updated to 'primary' - 4. the server type of the previous slot owner updated to 'replica' - """ - rc = _get_client(RedisCluster, request, flushdb=False) - slot = 12182 - redirect_node = None - # Get the current primary that holds this slot - prev_primary = rc.nodes_manager.get_node_from_slot(slot) - if failover: - if len(rc.nodes_manager.slots_cache[slot]) < 2: - warnings.warn("Skipping this test since it requires to have a " "replica") - return - redirect_node = rc.nodes_manager.slots_cache[slot][1] - else: - # Use one of the primaries to be the redirected node - redirect_node = rc.get_primaries()[0] - r_host = redirect_node.host - r_port = redirect_node.port - with patch.object(Redis, "parse_response") as parse_response: - - def moved_redirect_effect(connection, *args, **options): - def ok_response(connection, *args, **options): - assert connection.host == r_host - assert connection.port == r_port - - return "MOCK_OK" - - parse_response.side_effect = ok_response - raise MovedError(f"{slot} {r_host}:{r_port}") - - parse_response.side_effect = moved_redirect_effect - assert rc.execute_command("SET", "foo", "bar") == "MOCK_OK" - slot_primary = rc.nodes_manager.slots_cache[slot][0] - assert slot_primary == redirect_node - if failover: - assert rc.get_node(host=r_host, port=r_port).server_type == PRIMARY - assert prev_primary.server_type == REPLICA - - -@pytest.mark.onlycluster -class TestRedisClusterObj: - """ - Tests for the RedisCluster class - """ - - def test_host_port_startup_node(self): - """ - Test that it is possible to use host & port arguments as startup node - args - """ - cluster = get_mocked_redis_client(host=default_host, port=default_port) - assert cluster.get_node(host=default_host, port=default_port) is not None - - def test_startup_nodes(self): - """ - Test that it is possible to use startup_nodes - argument to init the cluster - """ - port_1 = 7000 - port_2 = 7001 - startup_nodes = [ - ClusterNode(default_host, port_1), - ClusterNode(default_host, port_2), - ] - cluster = get_mocked_redis_client(startup_nodes=startup_nodes) - assert ( - cluster.get_node(host=default_host, port=port_1) is not None - and cluster.get_node(host=default_host, port=port_2) is not None - ) - - def test_empty_startup_nodes(self): - """ - Test that exception is raised when empty providing empty startup_nodes - """ - with pytest.raises(RedisClusterException) as ex: - RedisCluster(startup_nodes=[]) - - assert str(ex.value).startswith( - "RedisCluster requires at least one node to discover the " "cluster" - ), str_if_bytes(ex.value) - - def test_from_url(self, r): - redis_url = f"redis://{default_host}:{default_port}/0" - with patch.object(RedisCluster, "from_url") as from_url: - - def from_url_mocked(_url, **_kwargs): - return get_mocked_redis_client(url=_url, **_kwargs) - - from_url.side_effect = from_url_mocked - cluster = RedisCluster.from_url(redis_url) - assert cluster.get_node(host=default_host, port=default_port) is not None - - def test_execute_command_errors(self, r): - """ - Test that if no key is provided then exception should be raised. - """ - with pytest.raises(RedisClusterException) as ex: - r.execute_command("GET") - assert str(ex.value).startswith( - "No way to dispatch this command to " "Redis Cluster. Missing key." - ) - - def test_execute_command_node_flag_primaries(self, r): - """ - Test command execution with nodes flag PRIMARIES - """ - primaries = r.get_primaries() - replicas = r.get_replicas() - mock_all_nodes_resp(r, "PONG") - assert r.ping(target_nodes=RedisCluster.PRIMARIES) is True - for primary in primaries: - conn = primary.redis_connection.connection - assert conn.read_response.called is True - for replica in replicas: - conn = replica.redis_connection.connection - assert conn.read_response.called is not True - - def test_execute_command_node_flag_replicas(self, r): - """ - Test command execution with nodes flag REPLICAS - """ - replicas = r.get_replicas() - if not replicas: - r = get_mocked_redis_client(default_host, default_port) - primaries = r.get_primaries() - mock_all_nodes_resp(r, "PONG") - assert r.ping(target_nodes=RedisCluster.REPLICAS) is True - for replica in replicas: - conn = replica.redis_connection.connection - assert conn.read_response.called is True - for primary in primaries: - conn = primary.redis_connection.connection - assert conn.read_response.called is not True - - def test_execute_command_node_flag_all_nodes(self, r): - """ - Test command execution with nodes flag ALL_NODES - """ - mock_all_nodes_resp(r, "PONG") - assert r.ping(target_nodes=RedisCluster.ALL_NODES) is True - for node in r.get_nodes(): - conn = node.redis_connection.connection - assert conn.read_response.called is True - - def test_execute_command_node_flag_random(self, r): - """ - Test command execution with nodes flag RANDOM - """ - mock_all_nodes_resp(r, "PONG") - assert r.ping(target_nodes=RedisCluster.RANDOM) is True - called_count = 0 - for node in r.get_nodes(): - conn = node.redis_connection.connection - if conn.read_response.called is True: - called_count += 1 - assert called_count == 1 - - def test_execute_command_default_node(self, r): - """ - Test command execution without node flag is being executed on the - default node - """ - def_node = r.get_default_node() - mock_node_resp(def_node, "PONG") - assert r.ping() is True - conn = def_node.redis_connection.connection - assert conn.read_response.called - - def test_ask_redirection(self, r): - """ - Test that the server handles ASK response. - - At first call it should return a ASK ResponseError that will point - the client to the next server it should talk to. - - Important thing to verify is that it tries to talk to the second node. - """ - redirect_node = r.get_nodes()[0] - with patch.object(Redis, "parse_response") as parse_response: - - def ask_redirect_effect(connection, *args, **options): - def ok_response(connection, *args, **options): - assert connection.host == redirect_node.host - assert connection.port == redirect_node.port - - return "MOCK_OK" - - parse_response.side_effect = ok_response - raise AskError(f"12182 {redirect_node.host}:{redirect_node.port}") - - parse_response.side_effect = ask_redirect_effect - - assert r.execute_command("SET", "foo", "bar") == "MOCK_OK" - - def test_moved_redirection(self, request): - """ - Test that the client handles MOVED response. - """ - moved_redirection_helper(request, failover=False) - - def test_moved_redirection_after_failover(self, request): - """ - Test that the client handles MOVED response after a failover. - """ - moved_redirection_helper(request, failover=True) - - def test_refresh_using_specific_nodes(self, request): - """ - Test making calls on specific nodes when the cluster has failed over to - another node - """ - node_7006 = ClusterNode(host=default_host, port=7006, server_type=PRIMARY) - node_7007 = ClusterNode(host=default_host, port=7007, server_type=PRIMARY) - with patch.object(Redis, "parse_response") as parse_response: - with patch.object(NodesManager, "initialize", autospec=True) as initialize: - with patch.multiple( - Connection, send_command=DEFAULT, connect=DEFAULT, can_read=DEFAULT - ) as mocks: - # simulate 7006 as a failed node - def parse_response_mock(connection, command_name, **options): - if connection.port == 7006: - parse_response.failed_calls += 1 - raise ClusterDownError( - "CLUSTERDOWN The cluster is " - "down. Use CLUSTER INFO for " - "more information" - ) - elif connection.port == 7007: - parse_response.successful_calls += 1 - - def initialize_mock(self): - # start with all slots mapped to 7006 - self.nodes_cache = {node_7006.name: node_7006} - self.default_node = node_7006 - self.slots_cache = {} - - for i in range(0, 16383): - self.slots_cache[i] = [node_7006] - - # After the first connection fails, a reinitialize - # should follow the cluster to 7007 - def map_7007(self): - self.nodes_cache = {node_7007.name: node_7007} - self.default_node = node_7007 - self.slots_cache = {} - - for i in range(0, 16383): - self.slots_cache[i] = [node_7007] - - # Change initialize side effect for the second call - initialize.side_effect = map_7007 - - parse_response.side_effect = parse_response_mock - parse_response.successful_calls = 0 - parse_response.failed_calls = 0 - initialize.side_effect = initialize_mock - mocks["can_read"].return_value = False - mocks["send_command"].return_value = "MOCK_OK" - mocks["connect"].return_value = None - with patch.object( - CommandsParser, "initialize", autospec=True - ) as cmd_parser_initialize: - - def cmd_init_mock(self, r): - self.commands = { - "get": { - "name": "get", - "arity": 2, - "flags": ["readonly", "fast"], - "first_key_pos": 1, - "last_key_pos": 1, - "step_count": 1, - } - } - - cmd_parser_initialize.side_effect = cmd_init_mock - - rc = _get_client(RedisCluster, request, flushdb=False) - assert len(rc.get_nodes()) == 1 - assert rc.get_node(node_name=node_7006.name) is not None - - rc.get("foo") - - # Cluster should now point to 7007, and there should be - # one failed and one successful call - assert len(rc.get_nodes()) == 1 - assert rc.get_node(node_name=node_7007.name) is not None - assert rc.get_node(node_name=node_7006.name) is None - assert parse_response.failed_calls == 1 - assert parse_response.successful_calls == 1 - - def test_reading_from_replicas_in_round_robin(self): - with patch.multiple( - Connection, - send_command=DEFAULT, - read_response=DEFAULT, - _connect=DEFAULT, - can_read=DEFAULT, - on_connect=DEFAULT, - ) as mocks: - with patch.object(Redis, "parse_response") as parse_response: - - def parse_response_mock_first(connection, *args, **options): - # Primary - assert connection.port == 7001 - parse_response.side_effect = parse_response_mock_second - return "MOCK_OK" - - def parse_response_mock_second(connection, *args, **options): - # Replica - assert connection.port == 7002 - parse_response.side_effect = parse_response_mock_third - return "MOCK_OK" - - def parse_response_mock_third(connection, *args, **options): - # Primary - assert connection.port == 7001 - return "MOCK_OK" - - # We don't need to create a real cluster connection but we - # do want RedisCluster.on_connect function to get called, - # so we'll mock some of the Connection's functions to allow it - parse_response.side_effect = parse_response_mock_first - mocks["send_command"].return_value = True - mocks["read_response"].return_value = "OK" - mocks["_connect"].return_value = True - mocks["can_read"].return_value = False - mocks["on_connect"].return_value = True - - # Create a cluster with reading from replications - read_cluster = get_mocked_redis_client( - host=default_host, port=default_port, read_from_replicas=True - ) - assert read_cluster.read_from_replicas is True - # Check that we read from the slot's nodes in a round robin - # matter. - # 'foo' belongs to slot 12182 and the slot's nodes are: - # [(127.0.0.1,7001,primary), (127.0.0.1,7002,replica)] - read_cluster.get("foo") - read_cluster.get("foo") - read_cluster.get("foo") - mocks["send_command"].assert_has_calls([call("READONLY")]) - - def test_keyslot(self, r): - """ - Test that method will compute correct key in all supported cases - """ - assert r.keyslot("foo") == 12182 - assert r.keyslot("{foo}bar") == 12182 - assert r.keyslot("{foo}") == 12182 - assert r.keyslot(1337) == 4314 - - assert r.keyslot(125) == r.keyslot(b"125") - assert r.keyslot(125) == r.keyslot("\x31\x32\x35") - assert r.keyslot("大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96") - assert r.keyslot("大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96") - assert r.keyslot(1337.1234) == r.keyslot("1337.1234") - assert r.keyslot(1337) == r.keyslot("1337") - assert r.keyslot(b"abc") == r.keyslot("abc") - - def test_get_node_name(self): - assert ( - get_node_name(default_host, default_port) - == f"{default_host}:{default_port}" - ) - - def test_all_nodes(self, r): - """ - Set a list of nodes and it should be possible to iterate over all - """ - nodes = [node for node in r.nodes_manager.nodes_cache.values()] - - for i, node in enumerate(r.get_nodes()): - assert node in nodes - - def test_all_nodes_masters(self, r): - """ - Set a list of nodes with random primaries/replicas config and it shold - be possible to iterate over all of them. - """ - nodes = [ - node - for node in r.nodes_manager.nodes_cache.values() - if node.server_type == PRIMARY - ] - - for node in r.get_primaries(): - assert node in nodes - - @pytest.mark.parametrize("error", RedisCluster.ERRORS_ALLOW_RETRY) - def test_cluster_down_overreaches_retry_attempts(self, error): - """ - When error that allows retry is thrown, test that we retry executing - the command as many times as configured in cluster_error_retry_attempts - and then raise the exception - """ - with patch.object(RedisCluster, "_execute_command") as execute_command: - - def raise_error(target_node, *args, **kwargs): - execute_command.failed_calls += 1 - raise error("mocked error") - - execute_command.side_effect = raise_error - - rc = get_mocked_redis_client(host=default_host, port=default_port) - - with pytest.raises(error): - rc.get("bar") - assert execute_command.failed_calls == rc.cluster_error_retry_attempts - - def test_user_on_connect_function(self, request): - """ - Test support in passing on_connect function by the user - """ - - def on_connect(connection): - assert connection is not None - - mock = Mock(side_effect=on_connect) - - _get_client(RedisCluster, request, redis_connect_func=mock) - assert mock.called is True - - def test_set_default_node_success(self, r): - """ - test successful replacement of the default cluster node - """ - default_node = r.get_default_node() - # get a different node - new_def_node = None - for node in r.get_nodes(): - if node != default_node: - new_def_node = node - break - assert r.set_default_node(new_def_node) is True - assert r.get_default_node() == new_def_node - - def test_set_default_node_failure(self, r): - """ - test failed replacement of the default cluster node - """ - default_node = r.get_default_node() - new_def_node = ClusterNode("1.1.1.1", 1111) - assert r.set_default_node(None) is False - assert r.set_default_node(new_def_node) is False - assert r.get_default_node() == default_node - - def test_get_node_from_key(self, r): - """ - Test that get_node_from_key function returns the correct node - """ - key = "bar" - slot = r.keyslot(key) - slot_nodes = r.nodes_manager.slots_cache.get(slot) - primary = slot_nodes[0] - assert r.get_node_from_key(key, replica=False) == primary - replica = r.get_node_from_key(key, replica=True) - if replica is not None: - assert replica.server_type == REPLICA - assert replica in slot_nodes - - -@pytest.mark.onlycluster -class TestClusterRedisCommands: - """ - Tests for RedisCluster unique commands - """ - - def test_case_insensitive_command_names(self, r): - assert ( - r.cluster_response_callbacks["cluster addslots"] - == r.cluster_response_callbacks["CLUSTER ADDSLOTS"] - ) - - def test_get_and_set(self, r): - # get and set can't be tested independently of each other - assert r.get("a") is None - byte_string = b"value" - integer = 5 - unicode_string = chr(3456) + "abcd" + chr(3421) - assert r.set("byte_string", byte_string) - assert r.set("integer", 5) - assert r.set("unicode_string", unicode_string) - assert r.get("byte_string") == byte_string - assert r.get("integer") == str(integer).encode() - assert r.get("unicode_string").decode("utf-8") == unicode_string - - def test_mget_nonatomic(self, r): - assert r.mget_nonatomic([]) == [] - assert r.mget_nonatomic(["a", "b"]) == [None, None] - r["a"] = "1" - r["b"] = "2" - r["c"] = "3" - - assert r.mget_nonatomic("a", "other", "b", "c") == [b"1", None, b"2", b"3"] - - def test_mset_nonatomic(self, r): - d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - assert r.mset_nonatomic(d) - for k, v in d.items(): - assert r[k] == v - - def test_config_set(self, r): - assert r.config_set("slowlog-log-slower-than", 0) - - def test_cluster_config_resetstat(self, r): - r.ping(target_nodes="all") - all_info = r.info(target_nodes="all") - prior_commands_processed = -1 - for node_info in all_info.values(): - prior_commands_processed = node_info["total_commands_processed"] - assert prior_commands_processed >= 1 - r.config_resetstat(target_nodes="all") - all_info = r.info(target_nodes="all") - for node_info in all_info.values(): - reset_commands_processed = node_info["total_commands_processed"] - assert reset_commands_processed < prior_commands_processed - - def test_client_setname(self, r): - node = r.get_random_node() - r.client_setname("redis_py_test", target_nodes=node) - client_name = r.client_getname(target_nodes=node) - assert client_name == "redis_py_test" - - def test_exists(self, r): - d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - r.mset_nonatomic(d) - assert r.exists(*d.keys()) == len(d) - - def test_delete(self, r): - d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - r.mset_nonatomic(d) - assert r.delete(*d.keys()) == len(d) - assert r.delete(*d.keys()) == 0 - - def test_touch(self, r): - d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - r.mset_nonatomic(d) - assert r.touch(*d.keys()) == len(d) - - def test_unlink(self, r): - d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - r.mset_nonatomic(d) - assert r.unlink(*d.keys()) == len(d) - # Unlink is non-blocking so we sleep before - # verifying the deletion - sleep(0.1) - assert r.unlink(*d.keys()) == 0 - - def test_pubsub_channels_merge_results(self, r): - nodes = r.get_nodes() - channels = [] - pubsub_nodes = [] - i = 0 - for node in nodes: - channel = f"foo{i}" - # We will create different pubsub clients where each one is - # connected to a different node - p = r.pubsub(node) - pubsub_nodes.append(p) - p.subscribe(channel) - b_channel = channel.encode("utf-8") - channels.append(b_channel) - # Assert that each node returns only the channel it subscribed to - sub_channels = node.redis_connection.pubsub_channels() - if not sub_channels: - # Try again after a short sleep - sleep(0.3) - sub_channels = node.redis_connection.pubsub_channels() - assert sub_channels == [b_channel] - i += 1 - # Assert that the cluster's pubsub_channels function returns ALL of - # the cluster's channels - result = r.pubsub_channels(target_nodes="all") - result.sort() - assert result == channels - - def test_pubsub_numsub_merge_results(self, r): - nodes = r.get_nodes() - pubsub_nodes = [] - channel = "foo" - b_channel = channel.encode("utf-8") - for node in nodes: - # We will create different pubsub clients where each one is - # connected to a different node - p = r.pubsub(node) - pubsub_nodes.append(p) - p.subscribe(channel) - # Assert that each node returns that only one client is subscribed - sub_chann_num = node.redis_connection.pubsub_numsub(channel) - if sub_chann_num == [(b_channel, 0)]: - sleep(0.3) - sub_chann_num = node.redis_connection.pubsub_numsub(channel) - assert sub_chann_num == [(b_channel, 1)] - # Assert that the cluster's pubsub_numsub function returns ALL clients - # subscribed to this channel in the entire cluster - assert r.pubsub_numsub(channel, target_nodes="all") == [(b_channel, len(nodes))] - - def test_pubsub_numpat_merge_results(self, r): - nodes = r.get_nodes() - pubsub_nodes = [] - pattern = "foo*" - for node in nodes: - # We will create different pubsub clients where each one is - # connected to a different node - p = r.pubsub(node) - pubsub_nodes.append(p) - p.psubscribe(pattern) - # Assert that each node returns that only one client is subscribed - sub_num_pat = node.redis_connection.pubsub_numpat() - if sub_num_pat == 0: - sleep(0.3) - sub_num_pat = node.redis_connection.pubsub_numpat() - assert sub_num_pat == 1 - # Assert that the cluster's pubsub_numsub function returns ALL clients - # subscribed to this channel in the entire cluster - assert r.pubsub_numpat(target_nodes="all") == len(nodes) - - @skip_if_server_version_lt("2.8.0") - def test_cluster_pubsub_channels(self, r): - p = r.pubsub() - p.subscribe("foo", "bar", "baz", "quux") - for i in range(4): - assert wait_for_message(p, timeout=0.5)["type"] == "subscribe" - expected = [b"bar", b"baz", b"foo", b"quux"] - assert all( - [channel in r.pubsub_channels(target_nodes="all") for channel in expected] - ) - - @skip_if_server_version_lt("2.8.0") - def test_cluster_pubsub_numsub(self, r): - p1 = r.pubsub() - p1.subscribe("foo", "bar", "baz") - for i in range(3): - assert wait_for_message(p1, timeout=0.5)["type"] == "subscribe" - p2 = r.pubsub() - p2.subscribe("bar", "baz") - for i in range(2): - assert wait_for_message(p2, timeout=0.5)["type"] == "subscribe" - p3 = r.pubsub() - p3.subscribe("baz") - assert wait_for_message(p3, timeout=0.5)["type"] == "subscribe" - - channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] - assert r.pubsub_numsub("foo", "bar", "baz", target_nodes="all") == channels - - def test_cluster_slots(self, r): - mock_all_nodes_resp(r, default_cluster_slots) - cluster_slots = r.cluster_slots() - assert isinstance(cluster_slots, dict) - assert len(default_cluster_slots) == len(cluster_slots) - assert cluster_slots.get((0, 8191)) is not None - assert cluster_slots.get((0, 8191)).get("primary") == ("127.0.0.1", 7000) - - def test_cluster_addslots(self, r): - node = r.get_random_node() - mock_node_resp(node, "OK") - assert r.cluster_addslots(node, 1, 2, 3) is True - - def test_cluster_countkeysinslot(self, r): - node = r.nodes_manager.get_node_from_slot(1) - mock_node_resp(node, 2) - assert r.cluster_countkeysinslot(1) == 2 - - def test_cluster_count_failure_report(self, r): - mock_all_nodes_resp(r, 0) - assert r.cluster_count_failure_report("node_0") == 0 - - def test_cluster_delslots(self): - cluster_slots = [ - [ - 0, - 8191, - ["127.0.0.1", 7000, "node_0"], - ], - [ - 8192, - 16383, - ["127.0.0.1", 7001, "node_1"], - ], - ] - r = get_mocked_redis_client( - host=default_host, port=default_port, cluster_slots=cluster_slots - ) - mock_all_nodes_resp(r, "OK") - node0 = r.get_node(default_host, 7000) - node1 = r.get_node(default_host, 7001) - assert r.cluster_delslots(0, 8192) == [True, True] - assert node0.redis_connection.connection.read_response.called - assert node1.redis_connection.connection.read_response.called - - def test_cluster_failover(self, r): - node = r.get_random_node() - mock_node_resp(node, "OK") - assert r.cluster_failover(node) is True - assert r.cluster_failover(node, "FORCE") is True - assert r.cluster_failover(node, "TAKEOVER") is True - with pytest.raises(RedisError): - r.cluster_failover(node, "FORCT") - - def test_cluster_info(self, r): - info = r.cluster_info() - assert isinstance(info, dict) - assert info["cluster_state"] == "ok" - - def test_cluster_keyslot(self, r): - mock_all_nodes_resp(r, 12182) - assert r.cluster_keyslot("foo") == 12182 - - def test_cluster_meet(self, r): - node = r.get_default_node() - mock_node_resp(node, "OK") - assert r.cluster_meet("127.0.0.1", 6379) is True - - def test_cluster_nodes(self, r): - response = ( - "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " - "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " - "1447836263059 5 connected\n" - "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " - "master - 0 1447836264065 0 connected\n" - "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " - "myself,master - 0 0 2 connected 5461-10922\n" - "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " - "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " - "1447836262556 3 connected\n" - "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " - "master - 0 1447836262555 7 connected 0-5460\n" - "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " - "master - 0 1447836263562 3 connected 10923-16383\n" - "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " - "master,fail - 1447829446956 1447829444948 1 disconnected\n" - ) - mock_all_nodes_resp(r, response) - nodes = r.cluster_nodes() - assert len(nodes) == 7 - assert nodes.get("172.17.0.7:7006") is not None - assert ( - nodes.get("172.17.0.7:7006").get("node_id") - == "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" - ) - - def test_cluster_replicate(self, r): - node = r.get_random_node() - all_replicas = r.get_replicas() - mock_all_nodes_resp(r, "OK") - assert r.cluster_replicate(node, "c8253bae761cb61857d") is True - results = r.cluster_replicate(all_replicas, "c8253bae761cb61857d") - if isinstance(results, dict): - for res in results.values(): - assert res is True - else: - assert results is True - - def test_cluster_reset(self, r): - mock_all_nodes_resp(r, "OK") - assert r.cluster_reset() is True - assert r.cluster_reset(False) is True - all_results = r.cluster_reset(False, target_nodes="all") - for res in all_results.values(): - assert res is True - - def test_cluster_save_config(self, r): - node = r.get_random_node() - all_nodes = r.get_nodes() - mock_all_nodes_resp(r, "OK") - assert r.cluster_save_config(node) is True - all_results = r.cluster_save_config(all_nodes) - for res in all_results.values(): - assert res is True - - def test_cluster_get_keys_in_slot(self, r): - response = [b"{foo}1", b"{foo}2"] - node = r.nodes_manager.get_node_from_slot(12182) - mock_node_resp(node, response) - keys = r.cluster_get_keys_in_slot(12182, 4) - assert keys == response - - def test_cluster_set_config_epoch(self, r): - mock_all_nodes_resp(r, "OK") - assert r.cluster_set_config_epoch(3) is True - all_results = r.cluster_set_config_epoch(3, target_nodes="all") - for res in all_results.values(): - assert res is True - - def test_cluster_setslot(self, r): - node = r.get_random_node() - mock_node_resp(node, "OK") - assert r.cluster_setslot(node, "node_0", 1218, "IMPORTING") is True - assert r.cluster_setslot(node, "node_0", 1218, "NODE") is True - assert r.cluster_setslot(node, "node_0", 1218, "MIGRATING") is True - with pytest.raises(RedisError): - r.cluster_failover(node, "STABLE") - with pytest.raises(RedisError): - r.cluster_failover(node, "STATE") - - def test_cluster_setslot_stable(self, r): - node = r.nodes_manager.get_node_from_slot(12182) - mock_node_resp(node, "OK") - assert r.cluster_setslot_stable(12182) is True - assert node.redis_connection.connection.read_response.called - - def test_cluster_replicas(self, r): - response = [ - b"01eca22229cf3c652b6fca0d09ff6941e0d2e3 " - b"127.0.0.1:6377@16377 slave " - b"52611e796814b78e90ad94be9d769a4f668f9a 0 " - b"1634550063436 4 connected", - b"r4xfga22229cf3c652b6fca0d09ff69f3e0d4d " - b"127.0.0.1:6378@16378 slave " - b"52611e796814b78e90ad94be9d769a4f668f9a 0 " - b"1634550063436 4 connected", - ] - mock_all_nodes_resp(r, response) - replicas = r.cluster_replicas("52611e796814b78e90ad94be9d769a4f668f9a") - assert replicas.get("127.0.0.1:6377") is not None - assert replicas.get("127.0.0.1:6378") is not None - assert ( - replicas.get("127.0.0.1:6378").get("node_id") - == "r4xfga22229cf3c652b6fca0d09ff69f3e0d4d" - ) - - def test_readonly(self): - r = get_mocked_redis_client(host=default_host, port=default_port) - mock_all_nodes_resp(r, "OK") - assert r.readonly() is True - all_replicas_results = r.readonly(target_nodes="replicas") - for res in all_replicas_results.values(): - assert res is True - for replica in r.get_replicas(): - assert replica.redis_connection.connection.read_response.called - - def test_readwrite(self): - r = get_mocked_redis_client(host=default_host, port=default_port) - mock_all_nodes_resp(r, "OK") - assert r.readwrite() is True - all_replicas_results = r.readwrite(target_nodes="replicas") - for res in all_replicas_results.values(): - assert res is True - for replica in r.get_replicas(): - assert replica.redis_connection.connection.read_response.called - - def test_bgsave(self, r): - assert r.bgsave() - sleep(0.3) - assert r.bgsave(True) - - def test_info(self, r): - # Map keys to same slot - r.set("x{1}", 1) - r.set("y{1}", 2) - r.set("z{1}", 3) - # Get node that handles the slot - slot = r.keyslot("x{1}") - node = r.nodes_manager.get_node_from_slot(slot) - # Run info on that node - info = r.info(target_nodes=node) - assert isinstance(info, dict) - assert info["db0"]["keys"] == 3 - - def _init_slowlog_test(self, r, node): - slowlog_lim = r.config_get("slowlog-log-slower-than", target_nodes=node) - assert r.config_set("slowlog-log-slower-than", 0, target_nodes=node) is True - return slowlog_lim["slowlog-log-slower-than"] - - def _teardown_slowlog_test(self, r, node, prev_limit): - assert ( - r.config_set("slowlog-log-slower-than", prev_limit, target_nodes=node) - is True - ) - - def test_slowlog_get(self, r, slowlog): - unicode_string = chr(3456) + "abcd" + chr(3421) - node = r.get_node_from_key(unicode_string) - slowlog_limit = self._init_slowlog_test(r, node) - assert r.slowlog_reset(target_nodes=node) - r.get(unicode_string) - slowlog = r.slowlog_get(target_nodes=node) - assert isinstance(slowlog, list) - commands = [log["command"] for log in slowlog] - - get_command = b" ".join((b"GET", unicode_string.encode("utf-8"))) - assert get_command in commands - assert b"SLOWLOG RESET" in commands - - # the order should be ['GET ', 'SLOWLOG RESET'], - # but if other clients are executing commands at the same time, there - # could be commands, before, between, or after, so just check that - # the two we care about are in the appropriate order. - assert commands.index(get_command) < commands.index(b"SLOWLOG RESET") - - # make sure other attributes are typed correctly - assert isinstance(slowlog[0]["start_time"], int) - assert isinstance(slowlog[0]["duration"], int) - # rollback the slowlog limit to its original value - self._teardown_slowlog_test(r, node, slowlog_limit) - - def test_slowlog_get_limit(self, r, slowlog): - assert r.slowlog_reset() - node = r.get_node_from_key("foo") - slowlog_limit = self._init_slowlog_test(r, node) - r.get("foo") - slowlog = r.slowlog_get(1, target_nodes=node) - assert isinstance(slowlog, list) - # only one command, based on the number we passed to slowlog_get() - assert len(slowlog) == 1 - self._teardown_slowlog_test(r, node, slowlog_limit) - - def test_slowlog_length(self, r, slowlog): - r.get("foo") - node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) - slowlog_len = r.slowlog_len(target_nodes=node) - assert isinstance(slowlog_len, int) - - def test_time(self, r): - t = r.time(target_nodes=r.get_primaries()[0]) - assert len(t) == 2 - assert isinstance(t[0], int) - assert isinstance(t[1], int) - - @skip_if_server_version_lt("4.0.0") - def test_memory_usage(self, r): - r.set("foo", "bar") - assert isinstance(r.memory_usage("foo"), int) - - @skip_if_server_version_lt("4.0.0") - def test_memory_malloc_stats(self, r): - assert r.memory_malloc_stats() - - @skip_if_server_version_lt("4.0.0") - def test_memory_stats(self, r): - # put a key into the current db to make sure that "db." - # has data - r.set("foo", "bar") - node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) - stats = r.memory_stats(target_nodes=node) - assert isinstance(stats, dict) - for key, value in stats.items(): - if key.startswith("db."): - assert isinstance(value, dict) - - @skip_if_server_version_lt("4.0.0") - def test_memory_help(self, r): - with pytest.raises(NotImplementedError): - r.memory_help() - - @skip_if_server_version_lt("4.0.0") - def test_memory_doctor(self, r): - with pytest.raises(NotImplementedError): - r.memory_doctor() - - def test_lastsave(self, r): - node = r.get_primaries()[0] - assert isinstance(r.lastsave(target_nodes=node), datetime.datetime) - - def test_cluster_echo(self, r): - node = r.get_primaries()[0] - assert r.echo("foo bar", target_nodes=node) == b"foo bar" - - @skip_if_server_version_lt("1.0.0") - def test_debug_segfault(self, r): - with pytest.raises(NotImplementedError): - r.debug_segfault() - - def test_config_resetstat(self, r): - node = r.get_primaries()[0] - r.ping(target_nodes=node) - prior_commands_processed = int( - r.info(target_nodes=node)["total_commands_processed"] - ) - assert prior_commands_processed >= 1 - r.config_resetstat(target_nodes=node) - reset_commands_processed = int( - r.info(target_nodes=node)["total_commands_processed"] - ) - assert reset_commands_processed < prior_commands_processed - - @skip_if_server_version_lt("6.2.0") - def test_client_trackinginfo(self, r): - node = r.get_primaries()[0] - res = r.client_trackinginfo(target_nodes=node) - assert len(res) > 2 - assert "prefixes" in res - - @skip_if_server_version_lt("2.9.50") - def test_client_pause(self, r): - node = r.get_primaries()[0] - assert r.client_pause(1, target_nodes=node) - assert r.client_pause(timeout=1, target_nodes=node) - with pytest.raises(RedisError): - r.client_pause(timeout="not an integer", target_nodes=node) - - @skip_if_server_version_lt("6.2.0") - def test_client_unpause(self, r): - assert r.client_unpause() - - @skip_if_server_version_lt("5.0.0") - def test_client_id(self, r): - node = r.get_primaries()[0] - assert r.client_id(target_nodes=node) > 0 - - @skip_if_server_version_lt("5.0.0") - def test_client_unblock(self, r): - node = r.get_primaries()[0] - myid = r.client_id(target_nodes=node) - assert not r.client_unblock(myid, target_nodes=node) - assert not r.client_unblock(myid, error=True, target_nodes=node) - assert not r.client_unblock(myid, error=False, target_nodes=node) - - @skip_if_server_version_lt("6.0.0") - def test_client_getredir(self, r): - node = r.get_primaries()[0] - assert isinstance(r.client_getredir(target_nodes=node), int) - assert r.client_getredir(target_nodes=node) == -1 - - @skip_if_server_version_lt("6.2.0") - def test_client_info(self, r): - node = r.get_primaries()[0] - info = r.client_info(target_nodes=node) - assert isinstance(info, dict) - assert "addr" in info - - @skip_if_server_version_lt("2.6.9") - def test_client_kill(self, r, r2): - node = r.get_primaries()[0] - r.client_setname("redis-py-c1", target_nodes="all") - r2.client_setname("redis-py-c2", target_nodes="all") - clients = [ - client - for client in r.client_list(target_nodes=node) - if client.get("name") in ["redis-py-c1", "redis-py-c2"] - ] - assert len(clients) == 2 - clients_by_name = {client.get("name"): client for client in clients} - - client_addr = clients_by_name["redis-py-c2"].get("addr") - assert r.client_kill(client_addr, target_nodes=node) is True - - clients = [ - client - for client in r.client_list(target_nodes=node) - if client.get("name") in ["redis-py-c1", "redis-py-c2"] - ] - assert len(clients) == 1 - assert clients[0].get("name") == "redis-py-c1" - - @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_not_empty_string(self, r): - r["{foo}a"] = "" - r.bitop("not", "{foo}r", "{foo}a") - assert r.get("{foo}r") is None - - @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_not(self, r): - test_str = b"\xAA\x00\xFF\x55" - correct = ~0xAA00FF55 & 0xFFFFFFFF - r["{foo}a"] = test_str - r.bitop("not", "{foo}r", "{foo}a") - assert int(binascii.hexlify(r["{foo}r"]), 16) == correct - - @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_not_in_place(self, r): - test_str = b"\xAA\x00\xFF\x55" - correct = ~0xAA00FF55 & 0xFFFFFFFF - r["{foo}a"] = test_str - r.bitop("not", "{foo}a", "{foo}a") - assert int(binascii.hexlify(r["{foo}a"]), 16) == correct - - @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_single_string(self, r): - test_str = b"\x01\x02\xFF" - r["{foo}a"] = test_str - r.bitop("and", "{foo}res1", "{foo}a") - r.bitop("or", "{foo}res2", "{foo}a") - r.bitop("xor", "{foo}res3", "{foo}a") - assert r["{foo}res1"] == test_str - assert r["{foo}res2"] == test_str - assert r["{foo}res3"] == test_str - - @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_string_operands(self, r): - r["{foo}a"] = b"\x01\x02\xFF\xFF" - r["{foo}b"] = b"\x01\x02\xFF" - r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") - r.bitop("or", "{foo}res2", "{foo}a", "{foo}b") - r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b") - assert int(binascii.hexlify(r["{foo}res1"]), 16) == 0x0102FF00 - assert int(binascii.hexlify(r["{foo}res2"]), 16) == 0x0102FFFF - assert int(binascii.hexlify(r["{foo}res3"]), 16) == 0x000000FF - - @skip_if_server_version_lt("6.2.0") - def test_cluster_copy(self, r): - assert r.copy("{foo}a", "{foo}b") == 0 - r.set("{foo}a", "bar") - assert r.copy("{foo}a", "{foo}b") == 1 - assert r.get("{foo}a") == b"bar" - assert r.get("{foo}b") == b"bar" - - @skip_if_server_version_lt("6.2.0") - def test_cluster_copy_and_replace(self, r): - r.set("{foo}a", "foo1") - r.set("{foo}b", "foo2") - assert r.copy("{foo}a", "{foo}b") == 0 - assert r.copy("{foo}a", "{foo}b", replace=True) == 1 - - @skip_if_server_version_lt("6.2.0") - def test_cluster_lmove(self, r): - r.rpush("{foo}a", "one", "two", "three", "four") - assert r.lmove("{foo}a", "{foo}b") - assert r.lmove("{foo}a", "{foo}b", "right", "left") - - @skip_if_server_version_lt("6.2.0") - def test_cluster_blmove(self, r): - r.rpush("{foo}a", "one", "two", "three", "four") - assert r.blmove("{foo}a", "{foo}b", 5) - assert r.blmove("{foo}a", "{foo}b", 1, "RIGHT", "LEFT") - - def test_cluster_msetnx(self, r): - d = {"{foo}a": b"1", "{foo}b": b"2", "{foo}c": b"3"} - assert r.msetnx(d) - d2 = {"{foo}a": b"x", "{foo}d": b"4"} - assert not r.msetnx(d2) - for k, v in d.items(): - assert r[k] == v - assert r.get("{foo}d") is None - - def test_cluster_rename(self, r): - r["{foo}a"] = "1" - assert r.rename("{foo}a", "{foo}b") - assert r.get("{foo}a") is None - assert r["{foo}b"] == b"1" - - def test_cluster_renamenx(self, r): - r["{foo}a"] = "1" - r["{foo}b"] = "2" - assert not r.renamenx("{foo}a", "{foo}b") - assert r["{foo}a"] == b"1" - assert r["{foo}b"] == b"2" - - # LIST COMMANDS - def test_cluster_blpop(self, r): - r.rpush("{foo}a", "1", "2") - r.rpush("{foo}b", "3", "4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) is None - r.rpush("{foo}c", "1") - assert r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") - - def test_cluster_brpop(self, r): - r.rpush("{foo}a", "1", "2") - r.rpush("{foo}b", "3", "4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) is None - r.rpush("{foo}c", "1") - assert r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") - - def test_cluster_brpoplpush(self, r): - r.rpush("{foo}a", "1", "2") - r.rpush("{foo}b", "3", "4") - assert r.brpoplpush("{foo}a", "{foo}b") == b"2" - assert r.brpoplpush("{foo}a", "{foo}b") == b"1" - assert r.brpoplpush("{foo}a", "{foo}b", timeout=1) is None - assert r.lrange("{foo}a", 0, -1) == [] - assert r.lrange("{foo}b", 0, -1) == [b"1", b"2", b"3", b"4"] - - def test_cluster_brpoplpush_empty_string(self, r): - r.rpush("{foo}a", "") - assert r.brpoplpush("{foo}a", "{foo}b") == b"" - - def test_cluster_rpoplpush(self, r): - r.rpush("{foo}a", "a1", "a2", "a3") - r.rpush("{foo}b", "b1", "b2", "b3") - assert r.rpoplpush("{foo}a", "{foo}b") == b"a3" - assert r.lrange("{foo}a", 0, -1) == [b"a1", b"a2"] - assert r.lrange("{foo}b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] - - def test_cluster_sdiff(self, r): - r.sadd("{foo}a", "1", "2", "3") - assert r.sdiff("{foo}a", "{foo}b") == {b"1", b"2", b"3"} - r.sadd("{foo}b", "2", "3") - assert r.sdiff("{foo}a", "{foo}b") == {b"1"} - - def test_cluster_sdiffstore(self, r): - r.sadd("{foo}a", "1", "2", "3") - assert r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 3 - assert r.smembers("{foo}c") == {b"1", b"2", b"3"} - r.sadd("{foo}b", "2", "3") - assert r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 1 - assert r.smembers("{foo}c") == {b"1"} - - def test_cluster_sinter(self, r): - r.sadd("{foo}a", "1", "2", "3") - assert r.sinter("{foo}a", "{foo}b") == set() - r.sadd("{foo}b", "2", "3") - assert r.sinter("{foo}a", "{foo}b") == {b"2", b"3"} - - def test_cluster_sinterstore(self, r): - r.sadd("{foo}a", "1", "2", "3") - assert r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 0 - assert r.smembers("{foo}c") == set() - r.sadd("{foo}b", "2", "3") - assert r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 2 - assert r.smembers("{foo}c") == {b"2", b"3"} - - def test_cluster_smove(self, r): - r.sadd("{foo}a", "a1", "a2") - r.sadd("{foo}b", "b1", "b2") - assert r.smove("{foo}a", "{foo}b", "a1") - assert r.smembers("{foo}a") == {b"a2"} - assert r.smembers("{foo}b") == {b"b1", b"b2", b"a1"} - - def test_cluster_sunion(self, r): - r.sadd("{foo}a", "1", "2") - r.sadd("{foo}b", "2", "3") - assert r.sunion("{foo}a", "{foo}b") == {b"1", b"2", b"3"} - - def test_cluster_sunionstore(self, r): - r.sadd("{foo}a", "1", "2") - r.sadd("{foo}b", "2", "3") - assert r.sunionstore("{foo}c", "{foo}a", "{foo}b") == 3 - assert r.smembers("{foo}c") == {b"1", b"2", b"3"} - - @skip_if_server_version_lt("6.2.0") - def test_cluster_zdiff(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - r.zadd("{foo}b", {"a1": 1, "a2": 2}) - assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] - - @skip_if_server_version_lt("6.2.0") - def test_cluster_zdiffstore(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - r.zadd("{foo}b", {"a1": 1, "a2": 2}) - assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) - assert r.zrange("{foo}out", 0, -1) == [b"a3"] - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] - - @skip_if_server_version_lt("6.2.0") - def test_cluster_zinter(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"]) == [b"a3", b"a1"] - # invalid aggregation - with pytest.raises(DataError): - r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] - # aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] - # with weights - assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] - - def test_cluster_zinterstore_sum(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] - - def test_cluster_zinterstore_max(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert ( - r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") - == 2 - ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] - - def test_cluster_zinterstore_min(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - r.zadd("{foo}b", {"a1": 2, "a2": 3, "a3": 5}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert ( - r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") - == 2 - ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] - - def test_cluster_zinterstore_with_weight(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] - - @skip_if_server_version_lt("4.9.0") - def test_cluster_bzpopmax(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2}) - r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None - r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) - - @skip_if_server_version_lt("4.9.0") - def test_cluster_bzpopmin(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2}) - r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None - r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) - - @skip_if_server_version_lt("6.2.0") - def test_cluster_zrangestore(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zrangestore("{foo}b", "{foo}a", 0, 1) - assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] - assert r.zrangestore("{foo}b", "{foo}a", 1, 2) - assert r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("{foo}b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] - # reversed order - assert r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) - assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] - # by score - assert r.zrangestore( - "{foo}b", "{foo}a", 2, 1, byscore=True, offset=0, num=1, desc=True - ) - assert r.zrange("{foo}b", 0, -1) == [b"a2"] - # by lex - assert r.zrangestore( - "{foo}b", "{foo}a", "[a2", "(a3", bylex=True, offset=0, num=1 - ) - assert r.zrange("{foo}b", 0, -1) == [b"a2"] - - @skip_if_server_version_lt("6.2.0") - def test_cluster_zunion(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - # sum - assert r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] - # max - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] - # min - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] - # with weight - assert r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] - - def test_cluster_zunionstore_sum(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] - - def test_cluster_zunionstore_max(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert ( - r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") - == 4 - ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] - - def test_cluster_zunionstore_min(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 4}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert ( - r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") - == 4 - ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] - - def test_cluster_zunionstore_with_weight(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] - - @skip_if_server_version_lt("2.8.9") - def test_cluster_pfcount(self, r): - members = {b"1", b"2", b"3"} - r.pfadd("{foo}a", *members) - assert r.pfcount("{foo}a") == len(members) - members_b = {b"2", b"3", b"4"} - r.pfadd("{foo}b", *members_b) - assert r.pfcount("{foo}b") == len(members_b) - assert r.pfcount("{foo}a", "{foo}b") == len(members_b.union(members)) - - @skip_if_server_version_lt("2.8.9") - def test_cluster_pfmerge(self, r): - mema = {b"1", b"2", b"3"} - memb = {b"2", b"3", b"4"} - memc = {b"5", b"6", b"7"} - r.pfadd("{foo}a", *mema) - r.pfadd("{foo}b", *memb) - r.pfadd("{foo}c", *memc) - r.pfmerge("{foo}d", "{foo}c", "{foo}a") - assert r.pfcount("{foo}d") == 6 - r.pfmerge("{foo}d", "{foo}b") - assert r.pfcount("{foo}d") == 7 - - def test_cluster_sort_store(self, r): - r.rpush("{foo}a", "2", "3", "1") - assert r.sort("{foo}a", store="{foo}sorted_values") == 3 - assert r.lrange("{foo}sorted_values", 0, -1) == [b"1", b"2", b"3"] - - # GEO COMMANDS - @skip_if_server_version_lt("6.2.0") - def test_cluster_geosearchstore(self, r): - values = (2.1909389952632, 41.433791470673, "place1") + ( - 2.1873744593677, - 41.406342043777, - "place2", - ) - - r.geoadd("{foo}barcelona", values) - r.geosearchstore( - "{foo}places_barcelona", - "{foo}barcelona", - longitude=2.191, - latitude=41.433, - radius=1000, - ) - assert r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] - - @skip_unless_arch_bits(64) - @skip_if_server_version_lt("6.2.0") - def test_geosearchstore_dist(self, r): - values = (2.1909389952632, 41.433791470673, "place1") + ( - 2.1873744593677, - 41.406342043777, - "place2", - ) - - r.geoadd("{foo}barcelona", values) - r.geosearchstore( - "{foo}places_barcelona", - "{foo}barcelona", - longitude=2.191, - latitude=41.433, - radius=1000, - storedist=True, - ) - # instead of save the geo score, the distance is saved. - assert r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 - - @skip_if_server_version_lt("3.2.0") - def test_cluster_georadius_store(self, r): - values = (2.1909389952632, 41.433791470673, "place1") + ( - 2.1873744593677, - 41.406342043777, - "place2", - ) - - r.geoadd("{foo}barcelona", values) - r.georadius( - "{foo}barcelona", 2.191, 41.433, 1000, store="{foo}places_barcelona" - ) - assert r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] - - @skip_unless_arch_bits(64) - @skip_if_server_version_lt("3.2.0") - def test_cluster_georadius_store_dist(self, r): - values = (2.1909389952632, 41.433791470673, "place1") + ( - 2.1873744593677, - 41.406342043777, - "place2", - ) - - r.geoadd("{foo}barcelona", values) - r.georadius( - "{foo}barcelona", 2.191, 41.433, 1000, store_dist="{foo}places_barcelona" - ) - # instead of save the geo score, the distance is saved. - assert r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 - - def test_cluster_dbsize(self, r): - d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - assert r.mset_nonatomic(d) - assert r.dbsize(target_nodes="primaries") == len(d) - - def test_cluster_keys(self, r): - assert r.keys() == [] - keys_with_underscores = {b"test_a", b"test_b"} - keys = keys_with_underscores.union({b"testc"}) - for key in keys: - r[key] = 1 - assert ( - set(r.keys(pattern="test_*", target_nodes="primaries")) - == keys_with_underscores - ) - assert set(r.keys(pattern="test*", target_nodes="primaries")) == keys - - # SCAN COMMANDS - @skip_if_server_version_lt("2.8.0") - def test_cluster_scan(self, r): - r.set("a", 1) - r.set("b", 2) - r.set("c", 3) - cursor, keys = r.scan(target_nodes="primaries") - assert cursor == 0 - assert set(keys) == {b"a", b"b", b"c"} - _, keys = r.scan(match="a", target_nodes="primaries") - assert set(keys) == {b"a"} - - @skip_if_server_version_lt("6.0.0") - def test_cluster_scan_type(self, r): - r.sadd("a-set", 1) - r.hset("a-hash", "foo", 2) - r.lpush("a-list", "aux", 3) - _, keys = r.scan(match="a*", _type="SET", target_nodes="primaries") - assert set(keys) == {b"a-set"} - - @skip_if_server_version_lt("2.8.0") - def test_cluster_scan_iter(self, r): - r.set("a", 1) - r.set("b", 2) - r.set("c", 3) - keys = list(r.scan_iter(target_nodes="primaries")) - assert set(keys) == {b"a", b"b", b"c"} - keys = list(r.scan_iter(match="a", target_nodes="primaries")) - assert set(keys) == {b"a"} - - def test_cluster_randomkey(self, r): - node = r.get_node_from_key("{foo}") - assert r.randomkey(target_nodes=node) is None - for key in ("{foo}a", "{foo}b", "{foo}c"): - r[key] = 1 - assert r.randomkey(target_nodes=node) in (b"{foo}a", b"{foo}b", b"{foo}c") - - @skip_if_server_version_lt("6.0.0") - @skip_if_redis_enterprise() - def test_acl_log(self, r, request): - key = "{cache}:" - node = r.get_node_from_key(key) - username = "redis-py-user" - - def teardown(): - r.acl_deluser(username, target_nodes="primaries") - - request.addfinalizer(teardown) - r.acl_setuser( - username, - enabled=True, - reset=True, - commands=["+get", "+set", "+select", "+cluster", "+command", "+info"], - keys=["{cache}:*"], - nopass=True, - target_nodes="primaries", - ) - r.acl_log_reset(target_nodes=node) - - user_client = _get_client( - RedisCluster, request, flushdb=False, username=username - ) - - # Valid operation and key - assert user_client.set("{cache}:0", 1) - assert user_client.get("{cache}:0") == b"1" - - # Invalid key - with pytest.raises(NoPermissionError): - user_client.get("{cache}violated_cache:0") - - # Invalid operation - with pytest.raises(NoPermissionError): - user_client.hset("{cache}:0", "hkey", "hval") - - assert isinstance(r.acl_log(target_nodes=node), list) - assert len(r.acl_log(target_nodes=node)) == 2 - assert len(r.acl_log(count=1, target_nodes=node)) == 1 - assert isinstance(r.acl_log(target_nodes=node)[0], dict) - assert "client-info" in r.acl_log(count=1, target_nodes=node)[0] - assert r.acl_log_reset(target_nodes=node) - - -@pytest.mark.onlycluster -class TestNodesManager: - """ - Tests for the NodesManager class - """ - - def test_load_balancer(self, r): - n_manager = r.nodes_manager - lb = n_manager.read_load_balancer - slot_1 = 1257 - slot_2 = 8975 - node_1 = ClusterNode(default_host, 6379, PRIMARY) - node_2 = ClusterNode(default_host, 6378, REPLICA) - node_3 = ClusterNode(default_host, 6377, REPLICA) - node_4 = ClusterNode(default_host, 6376, PRIMARY) - node_5 = ClusterNode(default_host, 6375, REPLICA) - n_manager.slots_cache = { - slot_1: [node_1, node_2, node_3], - slot_2: [node_4, node_5], - } - primary1_name = n_manager.slots_cache[slot_1][0].name - primary2_name = n_manager.slots_cache[slot_2][0].name - list1_size = len(n_manager.slots_cache[slot_1]) - list2_size = len(n_manager.slots_cache[slot_2]) - # slot 1 - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary1_name, list1_size) == 1 - assert lb.get_server_index(primary1_name, list1_size) == 2 - assert lb.get_server_index(primary1_name, list1_size) == 0 - # slot 2 - assert lb.get_server_index(primary2_name, list2_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 1 - assert lb.get_server_index(primary2_name, list2_size) == 0 - - lb.reset() - assert lb.get_server_index(primary1_name, list1_size) == 0 - assert lb.get_server_index(primary2_name, list2_size) == 0 - - def test_init_slots_cache_not_all_slots_covered(self): - """ - Test that if not all slots are covered it should raise an exception - """ - # Missing slot 5460 - cluster_slots = [ - [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], - [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], - [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], - ] - with pytest.raises(RedisClusterException) as ex: - get_mocked_redis_client( - host=default_host, port=default_port, cluster_slots=cluster_slots - ) - assert str(ex.value).startswith( - "All slots are not covered after query all startup_nodes." - ) - - def test_init_slots_cache_not_require_full_coverage_error(self): - """ - When require_full_coverage is set to False and not all slots are - covered, if one of the nodes has 'cluster-require_full_coverage' - config set to 'yes' the cluster initialization should fail - """ - # Missing slot 5460 - cluster_slots = [ - [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], - [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], - [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], - ] - - with pytest.raises(RedisClusterException): - get_mocked_redis_client( - host=default_host, - port=default_port, - cluster_slots=cluster_slots, - require_full_coverage=False, - coverage_result="yes", - ) - - def test_init_slots_cache_not_require_full_coverage_success(self): - """ - When require_full_coverage is set to False and not all slots are - covered, if all of the nodes has 'cluster-require_full_coverage' - config set to 'no' the cluster initialization should succeed - """ - # Missing slot 5460 - cluster_slots = [ - [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], - [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], - [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], - ] - - rc = get_mocked_redis_client( - host=default_host, - port=default_port, - cluster_slots=cluster_slots, - require_full_coverage=False, - coverage_result="no", - ) - - assert 5460 not in rc.nodes_manager.slots_cache - - def test_init_slots_cache_not_require_full_coverage_skips_check(self): - """ - Test that when require_full_coverage is set to False and - skip_full_coverage_check is set to true, the cluster initialization - succeed without checking the nodes' Redis configurations - """ - # Missing slot 5460 - cluster_slots = [ - [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], - [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], - [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], - ] - - with patch.object( - NodesManager, "cluster_require_full_coverage" - ) as conf_check_mock: - rc = get_mocked_redis_client( - host=default_host, - port=default_port, - cluster_slots=cluster_slots, - require_full_coverage=False, - skip_full_coverage_check=True, - coverage_result="no", - ) - - assert conf_check_mock.called is False - assert 5460 not in rc.nodes_manager.slots_cache - - def test_init_slots_cache(self): - """ - Test that slots cache can in initialized and all slots are covered - """ - good_slots_resp = [ - [0, 5460, ["127.0.0.1", 7000], ["127.0.0.2", 7003]], - [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.2", 7004]], - [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.2", 7005]], - ] - - rc = get_mocked_redis_client( - host=default_host, port=default_port, cluster_slots=good_slots_resp - ) - n_manager = rc.nodes_manager - assert len(n_manager.slots_cache) == REDIS_CLUSTER_HASH_SLOTS - for slot_info in good_slots_resp: - all_hosts = ["127.0.0.1", "127.0.0.2"] - all_ports = [7000, 7001, 7002, 7003, 7004, 7005] - slot_start = slot_info[0] - slot_end = slot_info[1] - for i in range(slot_start, slot_end + 1): - assert len(n_manager.slots_cache[i]) == len(slot_info[2:]) - assert n_manager.slots_cache[i][0].host in all_hosts - assert n_manager.slots_cache[i][1].host in all_hosts - assert n_manager.slots_cache[i][0].port in all_ports - assert n_manager.slots_cache[i][1].port in all_ports - - assert len(n_manager.nodes_cache) == 6 - - def test_init_slots_cache_cluster_mode_disabled(self): - """ - Test that creating a RedisCluster failes if one of the startup nodes - has cluster mode disabled - """ - with pytest.raises(RedisClusterException) as e: - get_mocked_redis_client( - host=default_host, port=default_port, cluster_enabled=False - ) - assert "Cluster mode is not enabled on this node" in str(e.value) - - def test_empty_startup_nodes(self): - """ - It should not be possible to create a node manager with no nodes - specified - """ - with pytest.raises(RedisClusterException): - NodesManager([]) - - def test_wrong_startup_nodes_type(self): - """ - If something other then a list type itteratable is provided it should - fail - """ - with pytest.raises(RedisClusterException): - NodesManager({}) - - def test_init_slots_cache_slots_collision(self, request): - """ - Test that if 2 nodes do not agree on the same slots setup it should - raise an error. In this test both nodes will say that the first - slots block should be bound to different servers. - """ - with patch.object(NodesManager, "create_redis_node") as create_redis_node: - - def create_mocked_redis_node(host, port, **kwargs): - """ - Helper function to return custom slots cache data from - different redis nodes - """ - if port == 7000: - result = [ - [ - 0, - 5460, - ["127.0.0.1", 7000], - ["127.0.0.1", 7003], - ], - [ - 5461, - 10922, - ["127.0.0.1", 7001], - ["127.0.0.1", 7004], - ], - ] - - elif port == 7001: - result = [ - [ - 0, - 5460, - ["127.0.0.1", 7001], - ["127.0.0.1", 7003], - ], - [ - 5461, - 10922, - ["127.0.0.1", 7000], - ["127.0.0.1", 7004], - ], - ] - else: - result = [] - - r_node = Redis(host=host, port=port) - - orig_execute_command = r_node.execute_command - - def execute_command(*args, **kwargs): - if args[0] == "CLUSTER SLOTS": - return result - elif args[0] == "INFO": - return {"cluster_enabled": True} - elif args[1] == "cluster-require-full-coverage": - return {"cluster-require-full-coverage": "yes"} - else: - return orig_execute_command(*args, **kwargs) - - r_node.execute_command = execute_command - return r_node - - create_redis_node.side_effect = create_mocked_redis_node - - with pytest.raises(RedisClusterException) as ex: - node_1 = ClusterNode("127.0.0.1", 7000) - node_2 = ClusterNode("127.0.0.1", 7001) - RedisCluster(startup_nodes=[node_1, node_2]) - assert str(ex.value).startswith( - "startup_nodes could not agree on a valid slots cache" - ), str(ex.value) - - def test_cluster_one_instance(self): - """ - If the cluster exists of only 1 node then there is some hacks that must - be validated they work. - """ - node = ClusterNode(default_host, default_port) - cluster_slots = [[0, 16383, ["", default_port]]] - rc = get_mocked_redis_client(startup_nodes=[node], cluster_slots=cluster_slots) - - n = rc.nodes_manager - assert len(n.nodes_cache) == 1 - n_node = rc.get_node(node_name=node.name) - assert n_node is not None - assert n_node == node - assert n_node.server_type == PRIMARY - assert len(n.slots_cache) == REDIS_CLUSTER_HASH_SLOTS - for i in range(0, REDIS_CLUSTER_HASH_SLOTS): - assert n.slots_cache[i] == [n_node] - - def test_init_with_down_node(self): - """ - If I can't connect to one of the nodes, everything should still work. - But if I can't connect to any of the nodes, exception should be thrown. - """ - with patch.object(NodesManager, "create_redis_node") as create_redis_node: - - def create_mocked_redis_node(host, port, **kwargs): - if port == 7000: - raise ConnectionError("mock connection error for 7000") - - r_node = Redis(host=host, port=port, decode_responses=True) - - def execute_command(*args, **kwargs): - if args[0] == "CLUSTER SLOTS": - return [ - [ - 0, - 8191, - ["127.0.0.1", 7001, "node_1"], - ], - [ - 8192, - 16383, - ["127.0.0.1", 7002, "node_2"], - ], - ] - elif args[0] == "INFO": - return {"cluster_enabled": True} - elif args[1] == "cluster-require-full-coverage": - return {"cluster-require-full-coverage": "yes"} - - r_node.execute_command = execute_command - - return r_node - - create_redis_node.side_effect = create_mocked_redis_node - - node_1 = ClusterNode("127.0.0.1", 7000) - node_2 = ClusterNode("127.0.0.1", 7001) - - # If all startup nodes fail to connect, connection error should be - # thrown - with pytest.raises(RedisClusterException) as e: - RedisCluster(startup_nodes=[node_1]) - assert "Redis Cluster cannot be connected" in str(e.value) - - with patch.object( - CommandsParser, "initialize", autospec=True - ) as cmd_parser_initialize: - - def cmd_init_mock(self, r): - self.commands = { - "get": { - "name": "get", - "arity": 2, - "flags": ["readonly", "fast"], - "first_key_pos": 1, - "last_key_pos": 1, - "step_count": 1, - } - } - - cmd_parser_initialize.side_effect = cmd_init_mock - # When at least one startup node is reachable, the cluster - # initialization should succeeds - rc = RedisCluster(startup_nodes=[node_1, node_2]) - assert rc.get_node(host=default_host, port=7001) is not None - assert rc.get_node(host=default_host, port=7002) is not None - - -@pytest.mark.onlycluster -class TestClusterPubSubObject: - """ - Tests for the ClusterPubSub class - """ - - def test_init_pubsub_with_host_and_port(self, r): - """ - Test creation of pubsub instance with passed host and port - """ - node = r.get_default_node() - p = r.pubsub(host=node.host, port=node.port) - assert p.get_pubsub_node() == node - - def test_init_pubsub_with_node(self, r): - """ - Test creation of pubsub instance with passed node - """ - node = r.get_default_node() - p = r.pubsub(node=node) - assert p.get_pubsub_node() == node - - def test_init_pubusub_without_specifying_node(self, r): - """ - Test creation of pubsub instance without specifying a node. The node - should be determined based on the keyslot of the first command - execution. - """ - channel_name = "foo" - node = r.get_node_from_key(channel_name) - p = r.pubsub() - assert p.get_pubsub_node() is None - p.subscribe(channel_name) - assert p.get_pubsub_node() == node - - def test_init_pubsub_with_a_non_existent_node(self, r): - """ - Test creation of pubsub instance with node that doesn't exists in the - cluster. RedisClusterException should be raised. - """ - node = ClusterNode("1.1.1.1", 1111) - with pytest.raises(RedisClusterException): - r.pubsub(node) - - def test_init_pubsub_with_a_non_existent_host_port(self, r): - """ - Test creation of pubsub instance with host and port that don't belong - to a node in the cluster. - RedisClusterException should be raised. - """ - with pytest.raises(RedisClusterException): - r.pubsub(host="1.1.1.1", port=1111) - - def test_init_pubsub_host_or_port(self, r): - """ - Test creation of pubsub instance with host but without port, and vice - versa. DataError should be raised. - """ - with pytest.raises(DataError): - r.pubsub(host="localhost") - - with pytest.raises(DataError): - r.pubsub(port=16379) - - def test_get_redis_connection(self, r): - """ - Test that get_redis_connection() returns the redis connection of the - set pubsub node - """ - node = r.get_default_node() - p = r.pubsub(node=node) - assert p.get_redis_connection() == node.redis_connection - - -@pytest.mark.onlycluster -class TestClusterPipeline: - """ - Tests for the ClusterPipeline class - """ - - def test_blocked_methods(self, r): - """ - Currently some method calls on a Cluster pipeline - is blocked when using in cluster mode. - They maybe implemented in the future. - """ - pipe = r.pipeline() - with pytest.raises(RedisClusterException): - pipe.multi() - - with pytest.raises(RedisClusterException): - pipe.immediate_execute_command() - - with pytest.raises(RedisClusterException): - pipe._execute_transaction(None, None, None) - - with pytest.raises(RedisClusterException): - pipe.load_scripts() - - with pytest.raises(RedisClusterException): - pipe.watch() - - with pytest.raises(RedisClusterException): - pipe.unwatch() - - with pytest.raises(RedisClusterException): - pipe.script_load_for_pipeline(None) - - with pytest.raises(RedisClusterException): - pipe.eval() - - def test_blocked_arguments(self, r): - """ - Currently some arguments is blocked when using in cluster mode. - They maybe implemented in the future. - """ - with pytest.raises(RedisClusterException) as ex: - r.pipeline(transaction=True) - - assert ( - str(ex.value).startswith("transaction is deprecated in cluster mode") - is True - ) - - with pytest.raises(RedisClusterException) as ex: - r.pipeline(shard_hint=True) - - assert ( - str(ex.value).startswith("shard_hint is deprecated in cluster mode") is True - ) - - def test_redis_cluster_pipeline(self, r): - """ - Test that we can use a pipeline with the RedisCluster class - """ - with r.pipeline() as pipe: - pipe.set("foo", "bar") - pipe.get("foo") - assert pipe.execute() == [True, b"bar"] - - def test_mget_disabled(self, r): - """ - Test that mget is disabled for ClusterPipeline - """ - with r.pipeline() as pipe: - with pytest.raises(RedisClusterException): - pipe.mget(["a"]) - - def test_mset_disabled(self, r): - """ - Test that mset is disabled for ClusterPipeline - """ - with r.pipeline() as pipe: - with pytest.raises(RedisClusterException): - pipe.mset({"a": 1, "b": 2}) - - def test_rename_disabled(self, r): - """ - Test that rename is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.rename("a", "b") - - def test_renamenx_disabled(self, r): - """ - Test that renamenx is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.renamenx("a", "b") - - def test_delete_single(self, r): - """ - Test a single delete operation - """ - r["a"] = 1 - with r.pipeline(transaction=False) as pipe: - pipe.delete("a") - assert pipe.execute() == [1] - - def test_multi_delete_unsupported(self, r): - """ - Test that multi delete operation is unsupported - """ - with r.pipeline(transaction=False) as pipe: - r["a"] = 1 - r["b"] = 2 - with pytest.raises(RedisClusterException): - pipe.delete("a", "b") - - def test_brpoplpush_disabled(self, r): - """ - Test that brpoplpush is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.brpoplpush() - - def test_rpoplpush_disabled(self, r): - """ - Test that rpoplpush is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.rpoplpush() - - def test_sort_disabled(self, r): - """ - Test that sort is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sort() - - def test_sdiff_disabled(self, r): - """ - Test that sdiff is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sdiff() - - def test_sdiffstore_disabled(self, r): - """ - Test that sdiffstore is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sdiffstore() - - def test_sinter_disabled(self, r): - """ - Test that sinter is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sinter() - - def test_sinterstore_disabled(self, r): - """ - Test that sinterstore is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sinterstore() - - def test_smove_disabled(self, r): - """ - Test that move is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.smove() - - def test_sunion_disabled(self, r): - """ - Test that sunion is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sunion() - - def test_sunionstore_disabled(self, r): - """ - Test that sunionstore is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sunionstore() - - def test_spfmerge_disabled(self, r): - """ - Test that spfmerge is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.pfmerge() - - def test_multi_key_operation_with_a_single_slot(self, r): - """ - Test multi key operation with a single slot - """ - pipe = r.pipeline(transaction=False) - pipe.set("a{foo}", 1) - pipe.set("b{foo}", 2) - pipe.set("c{foo}", 3) - pipe.get("a{foo}") - pipe.get("b{foo}") - pipe.get("c{foo}") - - res = pipe.execute() - assert res == [True, True, True, b"1", b"2", b"3"] - - def test_multi_key_operation_with_multi_slots(self, r): - """ - Test multi key operation with more than one slot - """ - pipe = r.pipeline(transaction=False) - pipe.set("a{foo}", 1) - pipe.set("b{foo}", 2) - pipe.set("c{foo}", 3) - pipe.set("bar", 4) - pipe.set("bazz", 5) - pipe.get("a{foo}") - pipe.get("b{foo}") - pipe.get("c{foo}") - pipe.get("bar") - pipe.get("bazz") - res = pipe.execute() - assert res == [True, True, True, True, True, b"1", b"2", b"3", b"4", b"5"] - - def test_connection_error_not_raised(self, r): - """ - Test that the pipeline doesn't raise an error on connection error when - raise_on_error=False - """ - key = "foo" - node = r.get_node_from_key(key, False) - - def raise_connection_error(): - e = ConnectionError("error") - return e - - with r.pipeline() as pipe: - mock_node_resp_func(node, raise_connection_error) - res = pipe.get(key).get(key).execute(raise_on_error=False) - assert node.redis_connection.connection.read_response.called - assert isinstance(res[0], ConnectionError) - - def test_connection_error_raised(self, r): - """ - Test that the pipeline raises an error on connection error when - raise_on_error=True - """ - key = "foo" - node = r.get_node_from_key(key, False) - - def raise_connection_error(): - e = ConnectionError("error") - return e - - with r.pipeline() as pipe: - mock_node_resp_func(node, raise_connection_error) - with pytest.raises(ConnectionError): - pipe.get(key).get(key).execute(raise_on_error=True) - - def test_asking_error(self, r): - """ - Test redirection on ASK error - """ - key = "foo" - first_node = r.get_node_from_key(key, False) - ask_node = None - for node in r.get_nodes(): - if node != first_node: - ask_node = node - break - if ask_node is None: - warnings.warn("skipping this test since the cluster has only one " "node") - return - ask_msg = f"{r.keyslot(key)} {ask_node.host}:{ask_node.port}" - - def raise_ask_error(): - raise AskError(ask_msg) - - with r.pipeline() as pipe: - mock_node_resp_func(first_node, raise_ask_error) - mock_node_resp(ask_node, "MOCK_OK") - res = pipe.get(key).execute() - assert first_node.redis_connection.connection.read_response.called - assert ask_node.redis_connection.connection.read_response.called - assert res == ["MOCK_OK"] - - def test_empty_stack(self, r): - """ - If pipeline is executed with no commands it should - return a empty list. - """ - p = r.pipeline() - result = p.execute() - assert result == [] - - -@pytest.mark.onlycluster -class TestReadOnlyPipeline: - """ - Tests for ClusterPipeline class in readonly mode - """ - - def test_pipeline_readonly(self, r): - """ - On readonly mode, we supports get related stuff only. - """ - r.readonly(target_nodes="all") - r.set("foo71", "a1") # we assume this key is set on 127.0.0.1:7001 - r.zadd("foo88", {"z1": 1}) # we assume this key is set on 127.0.0.1:7002 - r.zadd("foo88", {"z2": 4}) - - with r.pipeline() as readonly_pipe: - readonly_pipe.get("foo71").zrange("foo88", 0, 5, withscores=True) - assert readonly_pipe.execute() == [ - b"a1", - [(b"z1", 1.0), (b"z2", 4)], - ] - - def test_moved_redirection_on_slave_with_default(self, r): - """ - On Pipeline, we redirected once and finally get from master with - readonly client when data is completely moved. - """ - key = "bar" - r.set(key, "foo") - # set read_from_replicas to True - r.read_from_replicas = True - primary = r.get_node_from_key(key, False) - replica = r.get_node_from_key(key, True) - with r.pipeline() as readwrite_pipe: - mock_node_resp(primary, "MOCK_FOO") - if replica is not None: - moved_error = f"{r.keyslot(key)} {primary.host}:{primary.port}" - - def raise_moved_error(): - raise MovedError(moved_error) - - mock_node_resp_func(replica, raise_moved_error) - assert readwrite_pipe.reinitialize_counter == 0 - readwrite_pipe.get(key).get(key) - assert readwrite_pipe.execute() == ["MOCK_FOO", "MOCK_FOO"] - if replica is not None: - # the slot has a replica as well, so MovedError should have - # occurred. If MovedError occurs, we should see the - # reinitialize_counter increase. - assert readwrite_pipe.reinitialize_counter == 1 - conn = replica.redis_connection.connection - assert conn.read_response.called is True - - def test_readonly_pipeline_from_readonly_client(self, request): - """ - Test that the pipeline is initialized with readonly mode if the client - has it enabled - """ - # Create a cluster with reading from replications - ro = _get_client(RedisCluster, request, read_from_replicas=True) - key = "bar" - ro.set(key, "foo") - import time - - time.sleep(0.2) - with ro.pipeline() as readonly_pipe: - mock_all_nodes_resp(ro, "MOCK_OK") - assert readonly_pipe.read_from_replicas is True - assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] - slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] - if len(slot_nodes) > 1: - executed_on_replica = False - for node in slot_nodes: - if node.server_type == REPLICA: - conn = node.redis_connection.connection - executed_on_replica = conn.read_response.called - if executed_on_replica: - break - assert executed_on_replica is True - - -@pytest.mark.onlycluster -class TestClusterMonitor: - def test_wait_command_not_found(self, r): - "Make sure the wait_for_command func works when command is not found" - key = "foo" - node = r.get_node_from_key(key) - with r.monitor(target_node=node) as m: - response = wait_for_command(r, m, "nothing", key=key) - assert response is None - - def test_response_values(self, r): - db = 0 - key = "foo" - node = r.get_node_from_key(key) - with r.monitor(target_node=node) as m: - r.ping(target_nodes=node) - response = wait_for_command(r, m, "PING", key=key) - assert isinstance(response["time"], float) - assert response["db"] == db - assert response["client_type"] in ("tcp", "unix") - assert isinstance(response["client_address"], str) - assert isinstance(response["client_port"], str) - assert response["command"] == "PING" - - def test_command_with_quoted_key(self, r): - key = "{foo}1" - node = r.get_node_from_key(key) - with r.monitor(node) as m: - r.get('{foo}"bar') - response = wait_for_command(r, m, 'GET {foo}"bar', key=key) - assert response["command"] == 'GET {foo}"bar' - - def test_command_with_binary_data(self, r): - key = "{foo}1" - node = r.get_node_from_key(key) - with r.monitor(target_node=node) as m: - byte_string = b"{foo}bar\x92" - r.get(byte_string) - response = wait_for_command(r, m, "GET {foo}bar\\x92", key=key) - assert response["command"] == "GET {foo}bar\\x92" - - def test_command_with_escaped_data(self, r): - key = "{foo}1" - node = r.get_node_from_key(key) - with r.monitor(target_node=node) as m: - byte_string = b"{foo}bar\\x92" - r.get(byte_string) - response = wait_for_command(r, m, "GET {foo}bar\\\\x92", key=key) - assert response["command"] == "GET {foo}bar\\\\x92" diff --git a/tests/test_commands.py b/tests/test_commands.py index b28b63e..744697f 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -4193,18 +4193,6 @@ class TestRedisCommands: assert r.replicaof("NO ONE") assert r.replicaof("NO", "ONE") - @skip_if_server_version_lt("2.8.0") - def test_sync(self, r): - r2 = redis.Redis(port=6380, decode_responses=False) - res = r2.sync() - assert b"REDIS" in res - - @skip_if_server_version_lt("2.8.0") - def test_psync(self, r): - r2 = redis.Redis(port=6380, decode_responses=False) - res = r2.psync(r2.client_id(), 1) - assert b"FULLRESYNC" in res - @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_connection.py b/tests/test_connection.py index d94a815..7da8789 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -21,28 +21,6 @@ def test_invalid_response(r): assert str(cm.value) == f"Protocol Error: {raw!r}" -@skip_if_server_version_lt("4.0.0") -@pytest.mark.redismod -def test_loading_external_modules(modclient): - def inner(): - pass - - modclient.load_external_module("myfuncname", inner) - assert getattr(modclient, "myfuncname") == inner - assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) - - # and call it - from redis.commands import RedisModuleCommands - - j = RedisModuleCommands.json - modclient.load_external_module("sometestfuncname", j) - - # d = {'hello': 'world!'} - # mod = j(modclient) - # mod.set("fookey", ".", d) - # assert mod.get('fookey') == d - - class TestConnection: def test_disconnect(self): conn = Connection() diff --git a/tests/test_graph.py b/tests/test_graph.py deleted file mode 100644 index c6dc9a4..0000000 --- a/tests/test_graph.py +++ /dev/null @@ -1,477 +0,0 @@ -import pytest - -from redis.commands.graph import Edge, Node, Path -from redis.exceptions import ResponseError - - -@pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient - - -@pytest.mark.redismod -def test_bulk(client): - with pytest.raises(NotImplementedError): - client.graph().bulk() - client.graph().bulk(foo="bar!") - - -@pytest.mark.redismod -def test_graph_creation(client): - graph = client.graph() - - john = Node( - label="person", - properties={ - "name": "John Doe", - "age": 33, - "gender": "male", - "status": "single", - }, - ) - graph.add_node(john) - japan = Node(label="country", properties={"name": "Japan"}) - - graph.add_node(japan) - edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) - graph.add_edge(edge) - - graph.commit() - - query = ( - 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) ' - "RETURN p, v, c" - ) - - result = graph.query(query) - - person = result.result_set[0][0] - visit = result.result_set[0][1] - country = result.result_set[0][2] - - assert person == john - assert visit.properties == edge.properties - assert country == japan - - query = """RETURN [1, 2.3, "4", true, false, null]""" - result = graph.query(query) - assert [1, 2.3, "4", True, False, None] == result.result_set[0][0] - - # All done, remove graph. - graph.delete() - - -@pytest.mark.redismod -def test_array_functions(client): - query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})""" - client.graph().query(query) - - query = """WITH [0,1,2] as x return x""" - result = client.graph().query(query) - assert [0, 1, 2] == result.result_set[0][0] - - query = """MATCH(n) return collect(n)""" - result = client.graph().query(query) - - a = Node( - node_id=0, - label="person", - properties={"name": "a", "age": 32, "array": [0, 1, 2]}, - ) - - assert [a] == result.result_set[0][0] - - -@pytest.mark.redismod -def test_path(client): - node0 = Node(node_id=0, label="L1") - node1 = Node(node_id=1, label="L1") - edge01 = Edge(node0, "R1", node1, edge_id=0, properties={"value": 1}) - - graph = client.graph() - graph.add_node(node0) - graph.add_node(node1) - graph.add_edge(edge01) - graph.flush() - - path01 = Path.new_empty_path().add_node(node0).add_edge(edge01).add_node(node1) - expected_results = [[path01]] - - query = "MATCH p=(:L1)-[:R1]->(:L1) RETURN p ORDER BY p" - result = graph.query(query) - assert expected_results == result.result_set - - -@pytest.mark.redismod -def test_param(client): - params = [1, 2.3, "str", True, False, None, [0, 1, 2]] - query = "RETURN $param" - for param in params: - result = client.graph().query(query, {"param": param}) - expected_results = [[param]] - assert expected_results == result.result_set - - -@pytest.mark.redismod -def test_map(client): - query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" - - actual = client.graph().query(query).result_set[0][0] - expected = { - "a": 1, - "b": "str", - "c": None, - "d": [1, 2, 3], - "e": True, - "f": {"x": 1, "y": 2}, - } - - assert actual == expected - - -@pytest.mark.redismod -def test_point(client): - query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" - expected_lat = 32.070794860 - expected_lon = 34.820751118 - actual = client.graph().query(query).result_set[0][0] - assert abs(actual["latitude"] - expected_lat) < 0.001 - assert abs(actual["longitude"] - expected_lon) < 0.001 - - query = "RETURN point({latitude: 32, longitude: 34.0})" - expected_lat = 32 - expected_lon = 34 - actual = client.graph().query(query).result_set[0][0] - assert abs(actual["latitude"] - expected_lat) < 0.001 - assert abs(actual["longitude"] - expected_lon) < 0.001 - - -@pytest.mark.redismod -def test_index_response(client): - result_set = client.graph().query("CREATE INDEX ON :person(age)") - assert 1 == result_set.indices_created - - result_set = client.graph().query("CREATE INDEX ON :person(age)") - assert 0 == result_set.indices_created - - result_set = client.graph().query("DROP INDEX ON :person(age)") - assert 1 == result_set.indices_deleted - - with pytest.raises(ResponseError): - client.graph().query("DROP INDEX ON :person(age)") - - -@pytest.mark.redismod -def test_stringify_query_result(client): - graph = client.graph() - - john = Node( - alias="a", - label="person", - properties={ - "name": "John Doe", - "age": 33, - "gender": "male", - "status": "single", - }, - ) - graph.add_node(john) - - japan = Node(alias="b", label="country", properties={"name": "Japan"}) - graph.add_node(japan) - - edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) - graph.add_edge(edge) - - assert ( - str(john) - == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - ) - assert ( - str(edge) - == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - + """-[:visited{purpose:"pleasure"}]->""" - + """(b:country{name:"Japan"})""" - ) - assert str(japan) == """(b:country{name:"Japan"})""" - - graph.commit() - - query = """MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) - RETURN p, v, c""" - - result = client.graph().query(query) - person = result.result_set[0][0] - visit = result.result_set[0][1] - country = result.result_set[0][2] - - assert ( - str(person) - == """(:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - ) - assert str(visit) == """()-[:visited{purpose:"pleasure"}]->()""" - assert str(country) == """(:country{name:"Japan"})""" - - graph.delete() - - -@pytest.mark.redismod -def test_optional_match(client): - # Build a graph of form (a)-[R]->(b) - node0 = Node(node_id=0, label="L1", properties={"value": "a"}) - node1 = Node(node_id=1, label="L1", properties={"value": "b"}) - - edge01 = Edge(node0, "R", node1, edge_id=0) - - graph = client.graph() - graph.add_node(node0) - graph.add_node(node1) - graph.add_edge(edge01) - graph.flush() - - # Issue a query that collects all outgoing edges from both nodes - # (the second has none) - query = """MATCH (a) OPTIONAL MATCH (a)-[e]->(b) RETURN a, e, b ORDER BY a.value""" # noqa - expected_results = [[node0, edge01, node1], [node1, None, None]] - - result = client.graph().query(query) - assert expected_results == result.result_set - - graph.delete() - - -@pytest.mark.redismod -def test_cached_execution(client): - client.graph().query("CREATE ()") - - uncached_result = client.graph().query("MATCH (n) RETURN n, $param", {"param": [0]}) - assert uncached_result.cached_execution is False - - # loop to make sure the query is cached on each thread on server - for x in range(0, 64): - cached_result = client.graph().query( - "MATCH (n) RETURN n, $param", {"param": [0]} - ) - assert uncached_result.result_set == cached_result.result_set - - # should be cached on all threads by now - assert cached_result.cached_execution - - -@pytest.mark.redismod -def test_explain(client): - create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), - (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), - (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - client.graph().query(create_query) - - result = client.graph().explain( - "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa - {"name": "Yehuda"}, - ) - expected = "Results\n Project\n Conditional Traverse | (t:Team)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa - assert result == expected - - -@pytest.mark.redismod -def test_slowlog(client): - create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), - (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), - (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - client.graph().query(create_query) - - results = client.graph().slowlog() - assert results[0][1] == "GRAPH.QUERY" - assert results[0][2] == create_query - - -@pytest.mark.redismod -def test_query_timeout(client): - # Build a sample graph with 1000 nodes. - client.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") - # Issue a long-running query with a 1-millisecond timeout. - with pytest.raises(ResponseError): - client.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) - assert False is False - - with pytest.raises(Exception): - client.graph().query("RETURN 1", timeout="str") - assert False is False - - -@pytest.mark.redismod -def test_read_only_query(client): - with pytest.raises(Exception): - # Issue a write query, specifying read-only true, - # this call should fail. - client.graph().query("CREATE (p:person {name:'a'})", read_only=True) - assert False is False - - -@pytest.mark.redismod -def test_profile(client): - q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})""" - profile = client.graph().profile(q).result_set - assert "Create | Records produced: 3" in profile - assert "Unwind | Records produced: 3" in profile - - q = "MATCH (p:Person) WHERE p.v > 1 RETURN p" - profile = client.graph().profile(q).result_set - assert "Results | Records produced: 2" in profile - assert "Project | Records produced: 2" in profile - assert "Filter | Records produced: 2" in profile - assert "Node By Label Scan | (p:Person) | Records produced: 3" in profile - - -@pytest.mark.redismod -def test_config(client): - config_name = "RESULTSET_SIZE" - config_value = 3 - - # Set configuration - response = client.graph().config(config_name, config_value, set=True) - assert response == "OK" - - # Make sure config been updated. - response = client.graph().config(config_name, set=False) - expected_response = [config_name, config_value] - assert response == expected_response - - config_name = "QUERY_MEM_CAPACITY" - config_value = 1 << 20 # 1MB - - # Set configuration - response = client.graph().config(config_name, config_value, set=True) - assert response == "OK" - - # Make sure config been updated. - response = client.graph().config(config_name, set=False) - expected_response = [config_name, config_value] - assert response == expected_response - - # reset to default - client.graph().config("QUERY_MEM_CAPACITY", 0, set=True) - client.graph().config("RESULTSET_SIZE", -100, set=True) - - -@pytest.mark.redismod -def test_list_keys(client): - result = client.graph().list_keys() - assert result == [] - - client.execute_command("GRAPH.EXPLAIN", "G", "RETURN 1") - result = client.graph().list_keys() - assert result == ["G"] - - client.execute_command("GRAPH.EXPLAIN", "X", "RETURN 1") - result = client.graph().list_keys() - assert result == ["G", "X"] - - client.delete("G") - client.rename("X", "Z") - result = client.graph().list_keys() - assert result == ["Z"] - - client.delete("Z") - result = client.graph().list_keys() - assert result == [] - - -@pytest.mark.redismod -def test_multi_label(client): - redis_graph = client.graph("g") - - node = Node(label=["l", "ll"]) - redis_graph.add_node(node) - redis_graph.commit() - - query = "MATCH (n) RETURN n" - result = redis_graph.query(query) - result_node = result.result_set[0][0] - assert result_node == node - - try: - Node(label=1) - assert False - except AssertionError: - assert True - - try: - Node(label=["l", 1]) - assert False - except AssertionError: - assert True - - -@pytest.mark.redismod -def test_cache_sync(client): - pass - return - # This test verifies that client internal graph schema cache stays - # in sync with the graph schema - # - # Client B will try to get Client A out of sync by: - # 1. deleting the graph - # 2. reconstructing the graph in a different order, this will casuse - # a differance in the current mapping between string IDs and the - # mapping Client A is aware of - # - # Client A should pick up on the changes by comparing graph versions - # and resyncing its cache. - - A = client.graph("cache-sync") - B = client.graph("cache-sync") - - # Build order: - # 1. introduce label 'L' and 'K' - # 2. introduce attribute 'x' and 'q' - # 3. introduce relationship-type 'R' and 'S' - - A.query("CREATE (:L)") - B.query("CREATE (:K)") - A.query("MATCH (n) SET n.x = 1") - B.query("MATCH (n) SET n.q = 1") - A.query("MATCH (n) CREATE (n)-[:R]->()") - B.query("MATCH (n) CREATE (n)-[:S]->()") - - # Cause client A to populate its cache - A.query("MATCH (n)-[e]->() RETURN n, e") - - assert len(A._labels) == 2 - assert len(A._properties) == 2 - assert len(A._relationshipTypes) == 2 - assert A._labels[0] == "L" - assert A._labels[1] == "K" - assert A._properties[0] == "x" - assert A._properties[1] == "q" - assert A._relationshipTypes[0] == "R" - assert A._relationshipTypes[1] == "S" - - # Have client B reconstruct the graph in a different order. - B.delete() - - # Build order: - # 1. introduce relationship-type 'R' - # 2. introduce label 'L' - # 3. introduce attribute 'x' - B.query("CREATE ()-[:S]->()") - B.query("CREATE ()-[:R]->()") - B.query("CREATE (:K)") - B.query("CREATE (:L)") - B.query("MATCH (n) SET n.q = 1") - B.query("MATCH (n) SET n.x = 1") - - # A's internal cached mapping is now out of sync - # issue a query and make sure A's cache is synced. - A.query("MATCH (n)-[e]->() RETURN n, e") - - assert len(A._labels) == 2 - assert len(A._properties) == 2 - assert len(A._relationshipTypes) == 2 - assert A._labels[0] == "K" - assert A._labels[1] == "L" - assert A._properties[0] == "q" - assert A._properties[1] == "x" - assert A._relationshipTypes[0] == "S" - assert A._relationshipTypes[1] == "R" diff --git a/tests/test_graph_utils/__init__.py b/tests/test_graph_utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_graph_utils/test_edge.py b/tests/test_graph_utils/test_edge.py deleted file mode 100644 index 42358de..0000000 --- a/tests/test_graph_utils/test_edge.py +++ /dev/null @@ -1,77 +0,0 @@ -import pytest - -from redis.commands.graph import edge, node - - -@pytest.mark.redismod -def test_init(): - - with pytest.raises(AssertionError): - edge.Edge(None, None, None) - edge.Edge(node.Node(), None, None) - edge.Edge(None, None, node.Node()) - - assert isinstance( - edge.Edge(node.Node(node_id=1), None, node.Node(node_id=2)), edge.Edge - ) - - -@pytest.mark.redismod -def test_toString(): - props_result = edge.Edge( - node.Node(), None, node.Node(), properties={"a": "a", "b": 10} - ).toString() - assert props_result == '{a:"a",b:10}' - - no_props_result = edge.Edge( - node.Node(), None, node.Node(), properties={} - ).toString() - assert no_props_result == "" - - -@pytest.mark.redismod -def test_stringify(): - john = node.Node( - alias="a", - label="person", - properties={"name": "John Doe", "age": 33, "someArray": [1, 2, 3]}, - ) - japan = node.Node(alias="b", label="country", properties={"name": "Japan"}) - edge_with_relation = edge.Edge( - john, "visited", japan, properties={"purpose": "pleasure"} - ) - assert ( - '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})' - '-[:visited{purpose:"pleasure"}]->' - '(b:country{name:"Japan"})' == str(edge_with_relation) - ) - - edge_no_relation_no_props = edge.Edge(japan, "", john) - assert ( - '(b:country{name:"Japan"})' - "-[]->" - '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})' - == str(edge_no_relation_no_props) - ) - - edge_only_props = edge.Edge(john, "", japan, properties={"a": "b", "c": 3}) - assert ( - '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})' - '-[{a:"b",c:3}]->' - '(b:country{name:"Japan"})' == str(edge_only_props) - ) - - -@pytest.mark.redismod -def test_comparision(): - node1 = node.Node(node_id=1) - node2 = node.Node(node_id=2) - node3 = node.Node(node_id=3) - - edge1 = edge.Edge(node1, None, node2) - assert edge1 == edge.Edge(node1, None, node2) - assert edge1 != edge.Edge(node1, "bla", node2) - assert edge1 != edge.Edge(node1, None, node3) - assert edge1 != edge.Edge(node3, None, node2) - assert edge1 != edge.Edge(node2, None, node1) - assert edge1 != edge.Edge(node1, None, node2, properties={"a": 10}) diff --git a/tests/test_graph_utils/test_node.py b/tests/test_graph_utils/test_node.py deleted file mode 100644 index faf8ab6..0000000 --- a/tests/test_graph_utils/test_node.py +++ /dev/null @@ -1,52 +0,0 @@ -import pytest - -from redis.commands.graph import node - - -@pytest.fixture -def fixture(): - no_args = node.Node() - no_props = node.Node(node_id=1, alias="alias", label="l") - props_only = node.Node(properties={"a": "a", "b": 10}) - no_label = node.Node(node_id=1, alias="alias", properties={"a": "a"}) - multi_label = node.Node(node_id=1, alias="alias", label=["l", "ll"]) - return no_args, no_props, props_only, no_label, multi_label - - -@pytest.mark.redismod -def test_toString(fixture): - no_args, no_props, props_only, no_label, multi_label = fixture - assert no_args.toString() == "" - assert no_props.toString() == "" - assert props_only.toString() == '{a:"a",b:10}' - assert no_label.toString() == '{a:"a"}' - assert multi_label.toString() == "" - - -@pytest.mark.redismod -def test_stringify(fixture): - no_args, no_props, props_only, no_label, multi_label = fixture - assert str(no_args) == "()" - assert str(no_props) == "(alias:l)" - assert str(props_only) == '({a:"a",b:10})' - assert str(no_label) == '(alias{a:"a"})' - assert str(multi_label) == "(alias:l:ll)" - - -@pytest.mark.redismod -def test_comparision(fixture): - no_args, no_props, props_only, no_label, multi_label = fixture - - assert node.Node() == node.Node() - assert node.Node(node_id=1) == node.Node(node_id=1) - assert node.Node(node_id=1) != node.Node(node_id=2) - assert node.Node(node_id=1, alias="a") == node.Node(node_id=1, alias="b") - assert node.Node(node_id=1, alias="a") == node.Node(node_id=1, alias="a") - assert node.Node(node_id=1, label="a") == node.Node(node_id=1, label="a") - assert node.Node(node_id=1, label="a") != node.Node(node_id=1, label="b") - assert node.Node(node_id=1, alias="a", label="l") == node.Node( - node_id=1, alias="a", label="l" - ) - assert node.Node(alias="a", label="l") != node.Node(alias="a", label="l1") - assert node.Node(properties={"a": 10}) == node.Node(properties={"a": 10}) - assert node.Node() != node.Node(properties={"a": 10}) diff --git a/tests/test_graph_utils/test_path.py b/tests/test_graph_utils/test_path.py deleted file mode 100644 index d581269..0000000 --- a/tests/test_graph_utils/test_path.py +++ /dev/null @@ -1,91 +0,0 @@ -import pytest - -from redis.commands.graph import edge, node, path - - -@pytest.mark.redismod -def test_init(): - with pytest.raises(TypeError): - path.Path(None, None) - path.Path([], None) - path.Path(None, []) - - assert isinstance(path.Path([], []), path.Path) - - -@pytest.mark.redismod -def test_new_empty_path(): - new_empty_path = path.Path.new_empty_path() - assert isinstance(new_empty_path, path.Path) - assert new_empty_path._nodes == [] - assert new_empty_path._edges == [] - - -@pytest.mark.redismod -def test_wrong_flows(): - node_1 = node.Node(node_id=1) - node_2 = node.Node(node_id=2) - node_3 = node.Node(node_id=3) - - edge_1 = edge.Edge(node_1, None, node_2) - edge_2 = edge.Edge(node_1, None, node_3) - - p = path.Path.new_empty_path() - with pytest.raises(AssertionError): - p.add_edge(edge_1) - - p.add_node(node_1) - with pytest.raises(AssertionError): - p.add_node(node_2) - - p.add_edge(edge_1) - with pytest.raises(AssertionError): - p.add_edge(edge_2) - - -@pytest.mark.redismod -def test_nodes_and_edges(): - node_1 = node.Node(node_id=1) - node_2 = node.Node(node_id=2) - edge_1 = edge.Edge(node_1, None, node_2) - - p = path.Path.new_empty_path() - assert p.nodes() == [] - p.add_node(node_1) - assert [] == p.edges() - assert 0 == p.edge_count() - assert [node_1] == p.nodes() - assert node_1 == p.get_node(0) - assert node_1 == p.first_node() - assert node_1 == p.last_node() - assert 1 == p.nodes_count() - p.add_edge(edge_1) - assert [edge_1] == p.edges() - assert 1 == p.edge_count() - assert edge_1 == p.get_relationship(0) - p.add_node(node_2) - assert [node_1, node_2] == p.nodes() - assert node_1 == p.first_node() - assert node_2 == p.last_node() - assert 2 == p.nodes_count() - - -@pytest.mark.redismod -def test_compare(): - node_1 = node.Node(node_id=1) - node_2 = node.Node(node_id=2) - edge_1 = edge.Edge(node_1, None, node_2) - - assert path.Path.new_empty_path() == path.Path.new_empty_path() - assert path.Path(nodes=[node_1, node_2], edges=[edge_1]) == path.Path( - nodes=[node_1, node_2], edges=[edge_1] - ) - assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[], edges=[]) - assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[], edges=[]) - assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[node_2], edges=[]) - assert path.Path(nodes=[node_1], edges=[edge_1]) != path.Path( - nodes=[node_1], edges=[] - ) - assert path.Path(nodes=[node_1], edges=[edge_1]) != path.Path( - nodes=[node_2], edges=[edge_1] - ) diff --git a/tests/test_json.py b/tests/test_json.py deleted file mode 100644 index 6980e67..0000000 --- a/tests/test_json.py +++ /dev/null @@ -1,1432 +0,0 @@ -import pytest - -import redis -from redis import exceptions -from redis.commands.json.decoders import decode_list, unstring -from redis.commands.json.path import Path - -from .conftest import skip_ifmodversion_lt - - -@pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient - - -@pytest.mark.redismod -def test_json_setbinarykey(client): - d = {"hello": "world", b"some": "value"} - with pytest.raises(TypeError): - client.json().set("somekey", Path.rootPath(), d) - assert client.json().set("somekey", Path.rootPath(), d, decode_keys=True) - - -@pytest.mark.redismod -def test_json_setgetdeleteforget(client): - assert client.json().set("foo", Path.rootPath(), "bar") - assert client.json().get("foo") == "bar" - assert client.json().get("baz") is None - assert client.json().delete("foo") == 1 - assert client.json().forget("foo") == 0 # second delete - assert client.exists("foo") == 0 - - -@pytest.mark.redismod -def test_jsonget(client): - client.json().set("foo", Path.rootPath(), "bar") - assert client.json().get("foo") == "bar" - - -@pytest.mark.redismod -def test_json_get_jset(client): - assert client.json().set("foo", Path.rootPath(), "bar") - assert "bar" == client.json().get("foo") - assert client.json().get("baz") is None - assert 1 == client.json().delete("foo") - assert client.exists("foo") == 0 - - -@pytest.mark.redismod -def test_nonascii_setgetdelete(client): - assert client.json().set("notascii", Path.rootPath(), "hyvää-élève") - assert "hyvää-élève" == client.json().get("notascii", no_escape=True) - assert 1 == client.json().delete("notascii") - assert client.exists("notascii") == 0 - - -@pytest.mark.redismod -def test_jsonsetexistentialmodifiersshouldsucceed(client): - obj = {"foo": "bar"} - assert client.json().set("obj", Path.rootPath(), obj) - - # Test that flags prevent updates when conditions are unmet - assert client.json().set("obj", Path("foo"), "baz", nx=True) is None - assert client.json().set("obj", Path("qaz"), "baz", xx=True) is None - - # Test that flags allow updates when conditions are met - assert client.json().set("obj", Path("foo"), "baz", xx=True) - assert client.json().set("obj", Path("qaz"), "baz", nx=True) - - # Test that flags are mutually exlusive - with pytest.raises(Exception): - client.json().set("obj", Path("foo"), "baz", nx=True, xx=True) - - -@pytest.mark.redismod -def test_mgetshouldsucceed(client): - client.json().set("1", Path.rootPath(), 1) - client.json().set("2", Path.rootPath(), 2) - assert client.json().mget(["1"], Path.rootPath()) == [1] - - assert client.json().mget([1, 2], Path.rootPath()) == [1, 2] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -def test_clear(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 1 == client.json().clear("arr", Path.rootPath()) - assert [] == client.json().get("arr") - - -@pytest.mark.redismod -def test_type(client): - client.json().set("1", Path.rootPath(), 1) - assert "integer" == client.json().type("1", Path.rootPath()) - assert "integer" == client.json().type("1") - - -@pytest.mark.redismod -def test_numincrby(client): - client.json().set("num", Path.rootPath(), 1) - assert 2 == client.json().numincrby("num", Path.rootPath(), 1) - assert 2.5 == client.json().numincrby("num", Path.rootPath(), 0.5) - assert 1.25 == client.json().numincrby("num", Path.rootPath(), -1.25) - - -@pytest.mark.redismod -def test_nummultby(client): - client.json().set("num", Path.rootPath(), 1) - - with pytest.deprecated_call(): - assert 2 == client.json().nummultby("num", Path.rootPath(), 2) - assert 5 == client.json().nummultby("num", Path.rootPath(), 2.5) - assert 2.5 == client.json().nummultby("num", Path.rootPath(), 0.5) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -def test_toggle(client): - client.json().set("bool", Path.rootPath(), False) - assert client.json().toggle("bool", Path.rootPath()) - assert client.json().toggle("bool", Path.rootPath()) is False - # check non-boolean value - client.json().set("num", Path.rootPath(), 1) - with pytest.raises(redis.exceptions.ResponseError): - client.json().toggle("num", Path.rootPath()) - - -@pytest.mark.redismod -def test_strappend(client): - client.json().set("jsonkey", Path.rootPath(), "foo") - assert 6 == client.json().strappend("jsonkey", "bar") - assert "foobar" == client.json().get("jsonkey", Path.rootPath()) - - -# @pytest.mark.redismod -# def test_debug(client): -# client.json().set("str", Path.rootPath(), "foo") -# assert 24 == client.json().debug("MEMORY", "str", Path.rootPath()) -# assert 24 == client.json().debug("MEMORY", "str") -# -# # technically help is valid -# assert isinstance(client.json().debug("HELP"), list) - - -@pytest.mark.redismod -def test_strlen(client): - client.json().set("str", Path.rootPath(), "foo") - assert 3 == client.json().strlen("str", Path.rootPath()) - client.json().strappend("str", "bar", Path.rootPath()) - assert 6 == client.json().strlen("str", Path.rootPath()) - assert 6 == client.json().strlen("str") - - -@pytest.mark.redismod -def test_arrappend(client): - client.json().set("arr", Path.rootPath(), [1]) - assert 2 == client.json().arrappend("arr", Path.rootPath(), 2) - assert 4 == client.json().arrappend("arr", Path.rootPath(), 3, 4) - assert 7 == client.json().arrappend("arr", Path.rootPath(), *[5, 6, 7]) - - -@pytest.mark.redismod -def test_arrindex(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 1 == client.json().arrindex("arr", Path.rootPath(), 1) - assert -1 == client.json().arrindex("arr", Path.rootPath(), 1, 2) - - -@pytest.mark.redismod -def test_arrinsert(client): - client.json().set("arr", Path.rootPath(), [0, 4]) - assert 5 - -client.json().arrinsert( - "arr", - Path.rootPath(), - 1, - *[ - 1, - 2, - 3, - ], - ) - assert [0, 1, 2, 3, 4] == client.json().get("arr") - - # test prepends - client.json().set("val2", Path.rootPath(), [5, 6, 7, 8, 9]) - client.json().arrinsert("val2", Path.rootPath(), 0, ["some", "thing"]) - assert client.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9] - - -@pytest.mark.redismod -def test_arrlen(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 5 == client.json().arrlen("arr", Path.rootPath()) - assert 5 == client.json().arrlen("arr") - assert client.json().arrlen("fakekey") is None - - -@pytest.mark.redismod -def test_arrpop(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 4 == client.json().arrpop("arr", Path.rootPath(), 4) - assert 3 == client.json().arrpop("arr", Path.rootPath(), -1) - assert 2 == client.json().arrpop("arr", Path.rootPath()) - assert 0 == client.json().arrpop("arr", Path.rootPath(), 0) - assert [1] == client.json().get("arr") - - # test out of bounds - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 4 == client.json().arrpop("arr", Path.rootPath(), 99) - - # none test - client.json().set("arr", Path.rootPath(), []) - assert client.json().arrpop("arr") is None - - -@pytest.mark.redismod -def test_arrtrim(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 3 == client.json().arrtrim("arr", Path.rootPath(), 1, 3) - assert [1, 2, 3] == client.json().get("arr") - - # <0 test, should be 0 equivalent - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 0 == client.json().arrtrim("arr", Path.rootPath(), -1, 3) - - # testing stop > end - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 2 == client.json().arrtrim("arr", Path.rootPath(), 3, 99) - - # start > array size and stop - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 0 == client.json().arrtrim("arr", Path.rootPath(), 9, 1) - - # all larger - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 0 == client.json().arrtrim("arr", Path.rootPath(), 9, 11) - - -@pytest.mark.redismod -def test_resp(client): - obj = {"foo": "bar", "baz": 1, "qaz": True} - client.json().set("obj", Path.rootPath(), obj) - assert "bar" == client.json().resp("obj", Path("foo")) - assert 1 == client.json().resp("obj", Path("baz")) - assert client.json().resp("obj", Path("qaz")) - assert isinstance(client.json().resp("obj"), list) - - -@pytest.mark.redismod -def test_objkeys(client): - obj = {"foo": "bar", "baz": "qaz"} - client.json().set("obj", Path.rootPath(), obj) - keys = client.json().objkeys("obj", Path.rootPath()) - keys.sort() - exp = list(obj.keys()) - exp.sort() - assert exp == keys - - client.json().set("obj", Path.rootPath(), obj) - keys = client.json().objkeys("obj") - assert keys == list(obj.keys()) - - assert client.json().objkeys("fakekey") is None - - -@pytest.mark.redismod -def test_objlen(client): - obj = {"foo": "bar", "baz": "qaz"} - client.json().set("obj", Path.rootPath(), obj) - assert len(obj) == client.json().objlen("obj", Path.rootPath()) - - client.json().set("obj", Path.rootPath(), obj) - assert len(obj) == client.json().objlen("obj") - - -@pytest.mark.redismod -def test_json_commands_in_pipeline(client): - p = client.json().pipeline() - p.set("foo", Path.rootPath(), "bar") - p.get("foo") - p.delete("foo") - assert [True, "bar", 1] == p.execute() - assert client.keys() == [] - assert client.get("foo") is None - - # now with a true, json object - client.flushdb() - p = client.json().pipeline() - d = {"hello": "world", "oh": "snap"} - with pytest.deprecated_call(): - p.jsonset("foo", Path.rootPath(), d) - p.jsonget("foo") - p.exists("notarealkey") - p.delete("foo") - assert [True, d, 0, 1] == p.execute() - assert client.keys() == [] - assert client.get("foo") is None - - -@pytest.mark.redismod -def test_json_delete_with_dollar(client): - doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} - assert client.json().set("doc1", "$", doc1) - assert client.json().delete("doc1", "$..a") == 2 - r = client.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] - - doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} - assert client.json().set("doc2", "$", doc2) - assert client.json().delete("doc2", "$..a") == 1 - res = client.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] - - doc3 = [ - { - "ciao": ["non ancora"], - "nested": [ - {"ciao": [1, "a"]}, - {"ciao": [2, "a"]}, - {"ciaoc": [3, "non", "ciao"]}, - {"ciao": [4, "a"]}, - {"e": [5, "non", "ciao"]}, - ], - } - ] - assert client.json().set("doc3", "$", doc3) - assert client.json().delete("doc3", '$.[0]["nested"]..ciao') == 3 - - doc3val = [ - [ - { - "ciao": ["non ancora"], - "nested": [ - {}, - {}, - {"ciaoc": [3, "non", "ciao"]}, - {}, - {"e": [5, "non", "ciao"]}, - ], - } - ] - ] - res = client.json().get("doc3", "$") - assert res == doc3val - - # Test default path - assert client.json().delete("doc3") == 1 - assert client.json().get("doc3", "$") is None - - client.json().delete("not_a_document", "..a") - - -@pytest.mark.redismod -def test_json_forget_with_dollar(client): - doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} - assert client.json().set("doc1", "$", doc1) - assert client.json().forget("doc1", "$..a") == 2 - r = client.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] - - doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} - assert client.json().set("doc2", "$", doc2) - assert client.json().forget("doc2", "$..a") == 1 - res = client.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] - - doc3 = [ - { - "ciao": ["non ancora"], - "nested": [ - {"ciao": [1, "a"]}, - {"ciao": [2, "a"]}, - {"ciaoc": [3, "non", "ciao"]}, - {"ciao": [4, "a"]}, - {"e": [5, "non", "ciao"]}, - ], - } - ] - assert client.json().set("doc3", "$", doc3) - assert client.json().forget("doc3", '$.[0]["nested"]..ciao') == 3 - - doc3val = [ - [ - { - "ciao": ["non ancora"], - "nested": [ - {}, - {}, - {"ciaoc": [3, "non", "ciao"]}, - {}, - {"e": [5, "non", "ciao"]}, - ], - } - ] - ] - res = client.json().get("doc3", "$") - assert res == doc3val - - # Test default path - assert client.json().forget("doc3") == 1 - assert client.json().get("doc3", "$") is None - - client.json().forget("not_a_document", "..a") - - -@pytest.mark.redismod -def test_json_mget_dollar(client): - # Test mget with multi paths - client.json().set( - "doc1", - "$", - {"a": 1, "b": 2, "nested": {"a": 3}, "c": None, "nested2": {"a": None}}, - ) - client.json().set( - "doc2", - "$", - {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}}, - ) - # Compare also to single JSON.GET - assert client.json().get("doc1", "$..a") == [1, 3, None] - assert client.json().get("doc2", "$..a") == [4, 6, [None]] - - # Test mget with single path - client.json().mget("doc1", "$..a") == [1, 3, None] - # Test mget with multi path - client.json().mget(["doc1", "doc2"], "$..a") == [[1, 3, None], [4, 6, [None]]] - - # Test missing key - client.json().mget(["doc1", "missing_doc"], "$..a") == [[1, 3, None], None] - res = client.json().mget(["missing_doc1", "missing_doc2"], "$..a") - assert res == [None, None] - - -@pytest.mark.redismod -def test_numby_commands_dollar(client): - - # Test NUMINCRBY - client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) - # Test multi - assert client.json().numincrby("doc1", "$..a", 2) == [None, 4, 7.0, None] - - assert client.json().numincrby("doc1", "$..a", 2.5) == [None, 6.5, 9.5, None] - # Test single - assert client.json().numincrby("doc1", "$.b[1].a", 2) == [11.5] - - assert client.json().numincrby("doc1", "$.b[2].a", 2) == [None] - assert client.json().numincrby("doc1", "$.b[1].a", 3.5) == [15.0] - - # Test NUMMULTBY - client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) - - # test list - with pytest.deprecated_call(): - assert client.json().nummultby("doc1", "$..a", 2) == [None, 4, 10, None] - assert client.json().nummultby("doc1", "$..a", 2.5) == [None, 10.0, 25.0, None] - - # Test single - with pytest.deprecated_call(): - assert client.json().nummultby("doc1", "$.b[1].a", 2) == [50.0] - assert client.json().nummultby("doc1", "$.b[2].a", 2) == [None] - assert client.json().nummultby("doc1", "$.b[1].a", 3) == [150.0] - - # test missing keys - with pytest.raises(exceptions.ResponseError): - client.json().numincrby("non_existing_doc", "$..a", 2) - client.json().nummultby("non_existing_doc", "$..a", 2) - - # Test legacy NUMINCRBY - client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) - client.json().numincrby("doc1", ".b[0].a", 3) == 5 - - # Test legacy NUMMULTBY - client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) - - with pytest.deprecated_call(): - client.json().nummultby("doc1", ".b[0].a", 3) == 6 - - -@pytest.mark.redismod -def test_strappend_dollar(client): - - client.json().set( - "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} - ) - # Test multi - client.json().strappend("doc1", "bar", "$..a") == [6, 8, None] - - client.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] - # Test single - client.json().strappend("doc1", "baz", "$.nested1.a") == [11] - - client.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}} - ] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().strappend("non_existing_doc", "$..a", "err") - - # Test multi - client.json().strappend("doc1", "bar", ".*.a") == 8 - client.json().get("doc1", "$") == [ - {"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] - - # Test missing path - with pytest.raises(exceptions.ResponseError): - client.json().strappend("doc1", "piu") - - -@pytest.mark.redismod -def test_strlen_dollar(client): - - # Test multi - client.json().set( - "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} - ) - assert client.json().strlen("doc1", "$..a") == [3, 5, None] - - res2 = client.json().strappend("doc1", "bar", "$..a") - res1 = client.json().strlen("doc1", "$..a") - assert res1 == res2 - - # Test single - client.json().strlen("doc1", "$.nested1.a") == [8] - client.json().strlen("doc1", "$.nested2.a") == [None] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().strlen("non_existing_doc", "$..a") - - -@pytest.mark.redismod -def test_arrappend_dollar(client): - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - # Test multi - client.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] - assert client.json().get("doc1", "$") == [ - { - "a": ["foo", "bar", "racuda"], - "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, - "nested2": {"a": 31}, - } - ] - - # Test single - assert client.json().arrappend("doc1", "$.nested1.a", "baz") == [6] - assert client.json().get("doc1", "$") == [ - { - "a": ["foo", "bar", "racuda"], - "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, - "nested2": {"a": 31}, - } - ] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().arrappend("non_existing_doc", "$..a") - - # Test legacy - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - # Test multi (all paths are updated, but return result of last path) - assert client.json().arrappend("doc1", "..a", "bar", "racuda") == 5 - - assert client.json().get("doc1", "$") == [ - { - "a": ["foo", "bar", "racuda"], - "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, - "nested2": {"a": 31}, - } - ] - # Test single - assert client.json().arrappend("doc1", ".nested1.a", "baz") == 6 - assert client.json().get("doc1", "$") == [ - { - "a": ["foo", "bar", "racuda"], - "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, - "nested2": {"a": 31}, - } - ] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().arrappend("non_existing_doc", "$..a") - - -@pytest.mark.redismod -def test_arrinsert_dollar(client): - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - # Test multi - assert client.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") == [3, 5, None] - - assert client.json().get("doc1", "$") == [ - { - "a": ["foo", "bar", "racuda"], - "nested1": {"a": ["hello", "bar", "racuda", None, "world"]}, - "nested2": {"a": 31}, - } - ] - # Test single - assert client.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] - assert client.json().get("doc1", "$") == [ - { - "a": ["foo", "bar", "racuda"], - "nested1": {"a": ["hello", "bar", "racuda", "baz", None, "world"]}, - "nested2": {"a": 31}, - } - ] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().arrappend("non_existing_doc", "$..a") - - -@pytest.mark.redismod -def test_arrlen_dollar(client): - - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - - # Test multi - assert client.json().arrlen("doc1", "$..a") == [1, 3, None] - assert client.json().arrappend("doc1", "$..a", "non", "abba", "stanza") == [ - 4, - 6, - None, - ] - - client.json().clear("doc1", "$.a") - assert client.json().arrlen("doc1", "$..a") == [0, 6, None] - # Test single - assert client.json().arrlen("doc1", "$.nested1.a") == [6] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().arrappend("non_existing_doc", "$..a") - - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - # Test multi (return result of last path) - assert client.json().arrlen("doc1", "$..a") == [1, 3, None] - assert client.json().arrappend("doc1", "..a", "non", "abba", "stanza") == 6 - - # Test single - assert client.json().arrlen("doc1", ".nested1.a") == 6 - - # Test missing key - assert client.json().arrlen("non_existing_doc", "..a") is None - - -@pytest.mark.redismod -def test_arrpop_dollar(client): - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - - # # # Test multi - assert client.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] - - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().arrpop("non_existing_doc", "..a") - - # # Test legacy - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - # Test multi (all paths are updated, but return result of last path) - client.json().arrpop("doc1", "..a", "1") is None - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] - - # # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().arrpop("non_existing_doc", "..a") - - -@pytest.mark.redismod -def test_arrtrim_dollar(client): - - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - # Test multi - assert client.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}} - ] - - assert client.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] - # Test single - assert client.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": []}, "nested2": {"a": 31}} - ] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().arrtrim("non_existing_doc", "..a", "0", 1) - - # Test legacy - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": ["hello", None, "world"]}, - "nested2": {"a": 31}, - }, - ) - - # Test multi (all paths are updated, but return result of last path) - assert client.json().arrtrim("doc1", "..a", "1", "-1") == 2 - - # Test single - assert client.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().arrtrim("non_existing_doc", "..a", 1, 1) - - -@pytest.mark.redismod -def test_objkeys_dollar(client): - client.json().set( - "doc1", - "$", - { - "nested1": {"a": {"foo": 10, "bar": 20}}, - "a": ["foo"], - "nested2": {"a": {"baz": 50}}, - }, - ) - - # Test single - assert client.json().objkeys("doc1", "$.nested1.a") == [["foo", "bar"]] - - # Test legacy - assert client.json().objkeys("doc1", ".*.a") == ["foo", "bar"] - # Test single - assert client.json().objkeys("doc1", ".nested2.a") == ["baz"] - - # Test missing key - assert client.json().objkeys("non_existing_doc", "..a") is None - - # Test non existing doc - with pytest.raises(exceptions.ResponseError): - assert client.json().objkeys("non_existing_doc", "$..a") == [] - - assert client.json().objkeys("doc1", "$..nowhere") == [] - - -@pytest.mark.redismod -def test_objlen_dollar(client): - client.json().set( - "doc1", - "$", - { - "nested1": {"a": {"foo": 10, "bar": 20}}, - "a": ["foo"], - "nested2": {"a": {"baz": 50}}, - }, - ) - # Test multi - assert client.json().objlen("doc1", "$..a") == [2, None, 1] - # Test single - assert client.json().objlen("doc1", "$.nested1.a") == [2] - - # Test missing key, and path - with pytest.raises(exceptions.ResponseError): - client.json().objlen("non_existing_doc", "$..a") - - assert client.json().objlen("doc1", "$.nowhere") == [] - - # Test legacy - assert client.json().objlen("doc1", ".*.a") == 2 - - # Test single - assert client.json().objlen("doc1", ".nested2.a") == 1 - - # Test missing key - assert client.json().objlen("non_existing_doc", "..a") is None - - # Test missing path - # with pytest.raises(exceptions.ResponseError): - client.json().objlen("doc1", ".nowhere") - - -@pytest.mark.redismod -def load_types_data(nested_key_name): - td = { - "object": {}, - "array": [], - "string": "str", - "integer": 42, - "number": 1.2, - "boolean": False, - "null": None, - } - jdata = {} - types = [] - for i, (k, v) in zip(range(1, len(td) + 1), iter(td.items())): - jdata["nested" + str(i)] = {nested_key_name: v} - types.append(k) - - return jdata, types - - -@pytest.mark.redismod -def test_type_dollar(client): - jdata, jtypes = load_types_data("a") - client.json().set("doc1", "$", jdata) - # Test multi - assert client.json().type("doc1", "$..a") == jtypes - - # Test single - assert client.json().type("doc1", "$.nested2.a") == [jtypes[1]] - - # Test missing key - assert client.json().type("non_existing_doc", "..a") is None - - -@pytest.mark.redismod -def test_clear_dollar(client): - - client.json().set( - "doc1", - "$", - { - "nested1": {"a": {"foo": 10, "bar": 20}}, - "a": ["foo"], - "nested2": {"a": "claro"}, - "nested3": {"a": {"baz": 50}}, - }, - ) - # Test multi - assert client.json().clear("doc1", "$..a") == 3 - - assert client.json().get("doc1", "$") == [ - {"nested1": {"a": {}}, "a": [], "nested2": {"a": "claro"}, "nested3": {"a": {}}} - ] - - # Test single - client.json().set( - "doc1", - "$", - { - "nested1": {"a": {"foo": 10, "bar": 20}}, - "a": ["foo"], - "nested2": {"a": "claro"}, - "nested3": {"a": {"baz": 50}}, - }, - ) - assert client.json().clear("doc1", "$.nested1.a") == 1 - assert client.json().get("doc1", "$") == [ - { - "nested1": {"a": {}}, - "a": ["foo"], - "nested2": {"a": "claro"}, - "nested3": {"a": {"baz": 50}}, - } - ] - - # Test missing path (defaults to root) - assert client.json().clear("doc1") == 1 - assert client.json().get("doc1", "$") == [{}] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().clear("non_existing_doc", "$..a") - - -@pytest.mark.redismod -def test_toggle_dollar(client): - client.json().set( - "doc1", - "$", - { - "a": ["foo"], - "nested1": {"a": False}, - "nested2": {"a": 31}, - "nested3": {"a": True}, - }, - ) - # Test multi - assert client.json().toggle("doc1", "$..a") == [None, 1, None, 0] - assert client.json().get("doc1", "$") == [ - { - "a": ["foo"], - "nested1": {"a": True}, - "nested2": {"a": 31}, - "nested3": {"a": False}, - } - ] - - # Test missing key - with pytest.raises(exceptions.ResponseError): - client.json().toggle("non_existing_doc", "$..a") - - -# @pytest.mark.redismod -# def test_debug_dollar(client): -# -# jdata, jtypes = load_types_data("a") -# -# client.json().set("doc1", "$", jdata) -# -# # Test multi -# assert client.json().debug("MEMORY", "doc1", "$..a") == [72, 24, 24, 16, 16, 1, 0] -# -# # Test single -# assert client.json().debug("MEMORY", "doc1", "$.nested2.a") == [24] -# -# # Test legacy -# assert client.json().debug("MEMORY", "doc1", "..a") == 72 -# -# # Test missing path (defaults to root) -# assert client.json().debug("MEMORY", "doc1") == 72 -# -# # Test missing key -# assert client.json().debug("MEMORY", "non_existing_doc", "$..a") == [] - - -@pytest.mark.redismod -def test_resp_dollar(client): - - data = { - "L1": { - "a": { - "A1_B1": 10, - "A1_B2": False, - "A1_B3": { - "A1_B3_C1": None, - "A1_B3_C2": [ - "A1_B3_C2_D1_1", - "A1_B3_C2_D1_2", - -19.5, - "A1_B3_C2_D1_4", - "A1_B3_C2_D1_5", - {"A1_B3_C2_D1_6_E1": True}, - ], - "A1_B3_C3": [1], - }, - "A1_B4": { - "A1_B4_C1": "foo", - }, - }, - }, - "L2": { - "a": { - "A2_B1": 20, - "A2_B2": False, - "A2_B3": { - "A2_B3_C1": None, - "A2_B3_C2": [ - "A2_B3_C2_D1_1", - "A2_B3_C2_D1_2", - -37.5, - "A2_B3_C2_D1_4", - "A2_B3_C2_D1_5", - {"A2_B3_C2_D1_6_E1": False}, - ], - "A2_B3_C3": [2], - }, - "A2_B4": { - "A2_B4_C1": "bar", - }, - }, - }, - } - client.json().set("doc1", "$", data) - # Test multi - res = client.json().resp("doc1", "$..a") - assert res == [ - [ - "{", - "A1_B1", - 10, - "A1_B2", - "false", - "A1_B3", - [ - "{", - "A1_B3_C1", - None, - "A1_B3_C2", - [ - "[", - "A1_B3_C2_D1_1", - "A1_B3_C2_D1_2", - "-19.5", - "A1_B3_C2_D1_4", - "A1_B3_C2_D1_5", - ["{", "A1_B3_C2_D1_6_E1", "true"], - ], - "A1_B3_C3", - ["[", 1], - ], - "A1_B4", - ["{", "A1_B4_C1", "foo"], - ], - [ - "{", - "A2_B1", - 20, - "A2_B2", - "false", - "A2_B3", - [ - "{", - "A2_B3_C1", - None, - "A2_B3_C2", - [ - "[", - "A2_B3_C2_D1_1", - "A2_B3_C2_D1_2", - "-37.5", - "A2_B3_C2_D1_4", - "A2_B3_C2_D1_5", - ["{", "A2_B3_C2_D1_6_E1", "false"], - ], - "A2_B3_C3", - ["[", 2], - ], - "A2_B4", - ["{", "A2_B4_C1", "bar"], - ], - ] - - # Test single - resSingle = client.json().resp("doc1", "$.L1.a") - assert resSingle == [ - [ - "{", - "A1_B1", - 10, - "A1_B2", - "false", - "A1_B3", - [ - "{", - "A1_B3_C1", - None, - "A1_B3_C2", - [ - "[", - "A1_B3_C2_D1_1", - "A1_B3_C2_D1_2", - "-19.5", - "A1_B3_C2_D1_4", - "A1_B3_C2_D1_5", - ["{", "A1_B3_C2_D1_6_E1", "true"], - ], - "A1_B3_C3", - ["[", 1], - ], - "A1_B4", - ["{", "A1_B4_C1", "foo"], - ] - ] - - # Test missing path - client.json().resp("doc1", "$.nowhere") - - # Test missing key - # with pytest.raises(exceptions.ResponseError): - client.json().resp("non_existing_doc", "$..a") - - -@pytest.mark.redismod -def test_arrindex_dollar(client): - - client.json().set( - "store", - "$", - { - "store": { - "book": [ - { - "category": "reference", - "author": "Nigel Rees", - "title": "Sayings of the Century", - "price": 8.95, - "size": [10, 20, 30, 40], - }, - { - "category": "fiction", - "author": "Evelyn Waugh", - "title": "Sword of Honour", - "price": 12.99, - "size": [50, 60, 70, 80], - }, - { - "category": "fiction", - "author": "Herman Melville", - "title": "Moby Dick", - "isbn": "0-553-21311-3", - "price": 8.99, - "size": [5, 10, 20, 30], - }, - { - "category": "fiction", - "author": "J. R. R. Tolkien", - "title": "The Lord of the Rings", - "isbn": "0-395-19395-8", - "price": 22.99, - "size": [5, 6, 7, 8], - }, - ], - "bicycle": {"color": "red", "price": 19.95}, - } - }, - ) - - assert client.json().get("store", "$.store.book[?(@.price<10)].size") == [ - [10, 20, 30, 40], - [5, 10, 20, 30], - ] - assert client.json().arrindex( - "store", "$.store.book[?(@.price<10)].size", "20" - ) == [-1, -1] - - # Test index of int scalar in multi values - client.json().set( - "test_num", - ".", - [ - {"arr": [0, 1, 3.0, 3, 2, 1, 0, 3]}, - {"nested1_found": {"arr": [5, 4, 3, 2, 1, 0, 1, 2, 3.0, 2, 4, 5]}}, - {"nested2_not_found": {"arr": [2, 4, 6]}}, - {"nested3_scalar": {"arr": "3"}}, - [ - {"nested41_not_arr": {"arr_renamed": [1, 2, 3]}}, - {"nested42_empty_arr": {"arr": []}}, - ], - ], - ) - - assert client.json().get("test_num", "$..arr") == [ - [0, 1, 3.0, 3, 2, 1, 0, 3], - [5, 4, 3, 2, 1, 0, 1, 2, 3.0, 2, 4, 5], - [2, 4, 6], - "3", - [], - ] - - assert client.json().arrindex("test_num", "$..arr", 3) == [3, 2, -1, None, -1] - - # Test index of double scalar in multi values - assert client.json().arrindex("test_num", "$..arr", 3.0) == [2, 8, -1, None, -1] - - # Test index of string scalar in multi values - client.json().set( - "test_string", - ".", - [ - {"arr": ["bazzz", "bar", 2, "baz", 2, "ba", "baz", 3]}, - { - "nested1_found": { - "arr": [None, "baz2", "buzz", 2, 1, 0, 1, "2", "baz", 2, 4, 5] - } - }, - {"nested2_not_found": {"arr": ["baz2", 4, 6]}}, - {"nested3_scalar": {"arr": "3"}}, - [ - {"nested41_arr": {"arr_renamed": [1, "baz", 3]}}, - {"nested42_empty_arr": {"arr": []}}, - ], - ], - ) - assert client.json().get("test_string", "$..arr") == [ - ["bazzz", "bar", 2, "baz", 2, "ba", "baz", 3], - [None, "baz2", "buzz", 2, 1, 0, 1, "2", "baz", 2, 4, 5], - ["baz2", 4, 6], - "3", - [], - ] - - assert client.json().arrindex("test_string", "$..arr", "baz") == [ - 3, - 8, - -1, - None, - -1, - ] - - assert client.json().arrindex("test_string", "$..arr", "baz", 2) == [ - 3, - 8, - -1, - None, - -1, - ] - assert client.json().arrindex("test_string", "$..arr", "baz", 4) == [ - 6, - 8, - -1, - None, - -1, - ] - assert client.json().arrindex("test_string", "$..arr", "baz", -5) == [ - 3, - 8, - -1, - None, - -1, - ] - assert client.json().arrindex("test_string", "$..arr", "baz", 4, 7) == [ - 6, - -1, - -1, - None, - -1, - ] - assert client.json().arrindex("test_string", "$..arr", "baz", 4, -1) == [ - 6, - 8, - -1, - None, - -1, - ] - assert client.json().arrindex("test_string", "$..arr", "baz", 4, 0) == [ - 6, - 8, - -1, - None, - -1, - ] - assert client.json().arrindex("test_string", "$..arr", "5", 7, -1) == [ - -1, - -1, - -1, - None, - -1, - ] - assert client.json().arrindex("test_string", "$..arr", "5", 7, 0) == [ - -1, - -1, - -1, - None, - -1, - ] - - # Test index of None scalar in multi values - client.json().set( - "test_None", - ".", - [ - {"arr": ["bazzz", "None", 2, None, 2, "ba", "baz", 3]}, - { - "nested1_found": { - "arr": ["zaz", "baz2", "buzz", 2, 1, 0, 1, "2", None, 2, 4, 5] - } - }, - {"nested2_not_found": {"arr": ["None", 4, 6]}}, - {"nested3_scalar": {"arr": None}}, - [ - {"nested41_arr": {"arr_renamed": [1, None, 3]}}, - {"nested42_empty_arr": {"arr": []}}, - ], - ], - ) - assert client.json().get("test_None", "$..arr") == [ - ["bazzz", "None", 2, None, 2, "ba", "baz", 3], - ["zaz", "baz2", "buzz", 2, 1, 0, 1, "2", None, 2, 4, 5], - ["None", 4, 6], - None, - [], - ] - - # Fail with none-scalar value - with pytest.raises(exceptions.ResponseError): - client.json().arrindex("test_None", "$..nested42_empty_arr.arr", {"arr": []}) - - # Do not fail with none-scalar value in legacy mode - assert ( - client.json().arrindex( - "test_None", ".[4][1].nested42_empty_arr.arr", '{"arr":[]}' - ) - == -1 - ) - - # Test legacy (path begins with dot) - # Test index of int scalar in single value - assert client.json().arrindex("test_num", ".[0].arr", 3) == 3 - assert client.json().arrindex("test_num", ".[0].arr", 9) == -1 - - with pytest.raises(exceptions.ResponseError): - client.json().arrindex("test_num", ".[0].arr_not", 3) - # Test index of string scalar in single value - assert client.json().arrindex("test_string", ".[0].arr", "baz") == 3 - assert client.json().arrindex("test_string", ".[0].arr", "faz") == -1 - # Test index of None scalar in single value - assert client.json().arrindex("test_None", ".[0].arr", "None") == 1 - assert client.json().arrindex("test_None", "..nested2_not_found.arr", "None") == 0 - - -@pytest.mark.redismod -def test_decoders_and_unstring(): - assert unstring("4") == 4 - assert unstring("45.55") == 45.55 - assert unstring("hello world") == "hello world" - - assert decode_list(b"45.55") == 45.55 - assert decode_list("45.55") == 45.55 - assert decode_list(["hello", b"world"]) == ["hello", "world"] - - -@pytest.mark.redismod -def test_custom_decoder(client): - import json - - import ujson - - cj = client.json(encoder=ujson, decoder=ujson) - assert cj.set("foo", Path.rootPath(), "bar") - assert "bar" == cj.get("foo") - assert cj.get("baz") is None - assert 1 == cj.delete("foo") - assert client.exists("foo") == 0 - assert not isinstance(cj.__encoder__, json.JSONEncoder) - assert not isinstance(cj.__decoder__, json.JSONDecoder) - - -@pytest.mark.redismod -def test_set_file(client): - import json - import tempfile - - obj = {"hello": "world"} - jsonfile = tempfile.NamedTemporaryFile(suffix=".json") - with open(jsonfile.name, "w+") as fp: - fp.write(json.dumps(obj)) - - nojsonfile = tempfile.NamedTemporaryFile() - nojsonfile.write(b"Hello World") - - assert client.json().set_file("test", Path.rootPath(), jsonfile.name) - assert client.json().get("test") == obj - with pytest.raises(json.JSONDecodeError): - client.json().set_file("test2", Path.rootPath(), nojsonfile.name) - - -@pytest.mark.redismod -def test_set_path(client): - import json - import tempfile - - root = tempfile.mkdtemp() - sub = tempfile.mkdtemp(dir=root) - jsonfile = tempfile.mktemp(suffix=".json", dir=sub) - nojsonfile = tempfile.mktemp(dir=root) - - with open(jsonfile, "w+") as fp: - fp.write(json.dumps({"hello": "world"})) - open(nojsonfile, "a+").write("hello") - - result = {jsonfile: True, nojsonfile: False} - assert client.json().set_path(Path.rootPath(), root) == result - assert client.json().get(jsonfile.rsplit(".")[0]) == {"hello": "world"} diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 23af461..9b983c9 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -597,14 +597,3 @@ class TestPubSubWorkerThread: pubsub_thread.join(timeout=1.0) assert not pubsub_thread.is_alive() - -class TestPubSubDeadlock: - @pytest.mark.timeout(30, method="thread") - def test_pubsub_deadlock(self, master_host): - pool = redis.ConnectionPool(host=master_host[0], port=master_host[1]) - r = redis.Redis(connection_pool=pool) - - for i in range(60): - p = r.pubsub() - p.subscribe("my-channel-1", "my-channel-2") - pool.reset() diff --git a/tests/test_search.py b/tests/test_search.py deleted file mode 100644 index 7d666cb..0000000 --- a/tests/test_search.py +++ /dev/null @@ -1,1457 +0,0 @@ -import bz2 -import csv -import os -import time -from io import TextIOWrapper - -import pytest - -import redis -import redis.commands.search -import redis.commands.search.aggregation as aggregations -import redis.commands.search.reducers as reducers -from redis import Redis -from redis.commands.json.path import Path -from redis.commands.search import Search -from redis.commands.search.field import GeoField, NumericField, TagField, TextField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType -from redis.commands.search.query import GeoFilter, NumericFilter, Query -from redis.commands.search.result import Result -from redis.commands.search.suggestion import Suggestion - -from .conftest import default_redismod_url, skip_ifmodversion_lt - -WILL_PLAY_TEXT = os.path.abspath( - os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") -) - -TITLES_CSV = os.path.abspath( - os.path.join(os.path.dirname(__file__), "testdata", "titles.csv") -) - - -def waitForIndex(env, idx, timeout=None): - delay = 0.1 - while True: - res = env.execute_command("ft.info", idx) - try: - res.index("indexing") - except ValueError: - break - - if int(res[res.index("indexing") + 1]) == 0: - break - - time.sleep(delay) - if timeout is not None: - timeout -= delay - if timeout <= 0: - break - - -def getClient(): - """ - Gets a client client attached to an index name which is ready to be - created - """ - rc = Redis.from_url(default_redismod_url, decode_responses=True) - return rc - - -def createIndex(client, num_docs=100, definition=None): - try: - client.create_index( - (TextField("play", weight=5.0), TextField("txt"), NumericField("chapter")), - definition=definition, - ) - except redis.ResponseError: - client.dropindex(delete_documents=True) - return createIndex(client, num_docs=num_docs, definition=definition) - - chapters = {} - bzfp = TextIOWrapper(bz2.BZ2File(WILL_PLAY_TEXT), encoding="utf8") - - r = csv.reader(bzfp, delimiter=";") - for n, line in enumerate(r): - - play, chapter, _, text = line[1], line[2], line[4], line[5] - - key = f"{play}:{chapter}".lower() - d = chapters.setdefault(key, {}) - d["play"] = play - d["txt"] = d.get("txt", "") + " " + text - d["chapter"] = int(chapter or 0) - if len(chapters) == num_docs: - break - - indexer = client.batch_indexer(chunk_size=50) - assert isinstance(indexer, Search.BatchIndexer) - assert 50 == indexer.chunk_size - - for key, doc in chapters.items(): - indexer.add_document(key, **doc) - indexer.commit() - - -# override the default module client, search requires both db=0, and text -@pytest.fixture -def modclient(): - return Redis.from_url(default_redismod_url, db=0, decode_responses=True) - - -@pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient - - -@pytest.mark.redismod -def test_client(client): - num_docs = 500 - createIndex(client.ft(), num_docs=num_docs) - waitForIndex(client, "idx") - # verify info - info = client.ft().info() - for k in [ - "index_name", - "index_options", - "attributes", - "num_docs", - "max_doc_id", - "num_terms", - "num_records", - "inverted_sz_mb", - "offset_vectors_sz_mb", - "doc_table_size_mb", - "key_table_size_mb", - "records_per_doc_avg", - "bytes_per_record_avg", - "offsets_per_term_avg", - "offset_bits_per_record_avg", - ]: - assert k in info - - assert client.ft().index_name == info["index_name"] - assert num_docs == int(info["num_docs"]) - - res = client.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc.play == "Henry IV" - assert len(doc.txt) > 0 - - # test no content - res = client.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = client.ft().search(Query("kings").no_content()).total - vtotal = client.ft().search(Query("kings").no_content().verbatim()).total - assert total > vtotal - - # test in fields - txt_total = ( - client.ft().search(Query("henry").no_content().limit_fields("txt")).total - ) - play_total = ( - client.ft().search(Query("henry").no_content().limit_fields("play")).total - ) - both_total = ( - client.ft() - .search(Query("henry").no_content().limit_fields("play", "txt")) - .total - ) - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = client.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in client.ft().search(Query("henry")).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = client.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == client.ft().search(Query("henry king")).total - assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total - assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total - assert 53 == client.ft().search(Query("henry king").slop(0)).total - assert 167 == client.ft().search(Query("henry king").slop(100)).total - - # test delete document - client.ft().add_document("doc-5ghs2", play="Death of a Salesman") - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - - assert 1 == client.ft().delete_document("doc-5ghs2") - res = client.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == client.ft().delete_document("doc-5ghs2") - - client.ft().add_document("doc-5ghs2", play="Death of a Salesman") - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - client.ft().delete_document("doc-5ghs2") - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_payloads(client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", payload="foo baz", txt="foo bar") - client.ft().add_document("doc2", txt="foo bar") - - q = Query("foo bar").with_payloads() - res = client.ft().search(q) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - assert "foo baz" == res.docs[0].payload - assert res.docs[1].payload is None - - -@pytest.mark.redismod -def test_scores(client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", txt="foo baz") - client.ft().add_document("doc2", txt="foo bar") - - q = Query("foo ~bar").with_scores() - res = client.ft().search(q) - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) - - -@pytest.mark.redismod -def test_replace(client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", txt="foo bar") - client.ft().add_document("doc2", txt="foo bar") - waitForIndex(client, "idx") - - res = client.ft().search("foo bar") - assert 2 == res.total - client.ft().add_document("doc1", replace=True, txt="this is a replaced doc") - - res = client.ft().search("foo bar") - assert 1 == res.total - assert "doc2" == res.docs[0].id - - res = client.ft().search("replaced doc") - assert 1 == res.total - assert "doc1" == res.docs[0].id - - -@pytest.mark.redismod -def test_stopwords(client): - client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"]) - client.ft().add_document("doc1", txt="foo bar") - client.ft().add_document("doc2", txt="hello world") - waitForIndex(client, "idx") - - q1 = Query("foo bar").no_content() - q2 = Query("foo bar hello world").no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total - - -@pytest.mark.redismod -def test_filters(client): - client.ft().create_index((TextField("txt"), NumericField("num"), GeoField("loc"))) - client.ft().add_document("doc1", txt="foo bar", num=3.141, loc="-0.441,51.458") - client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") - - waitForIndex(client, "idx") - # Test numerical filter - q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() - q2 = ( - Query("foo") - .add_filter(NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) - .no_content() - ) - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id - - # Test geo filter - q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() - q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id - - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res - - -@pytest.mark.redismod -def test_payloads_with_no_content(client): - client.ft().create_index((TextField("txt"),)) - client.ft().add_document("doc1", payload="foo baz", txt="foo bar") - client.ft().add_document("doc2", payload="foo baz2", txt="foo bar") - - q = Query("foo bar").with_payloads().no_content() - res = client.ft().search(q) - assert 2 == len(res.docs) - - -@pytest.mark.redismod -def test_sort_by(client): - client.ft().create_index((TextField("txt"), NumericField("num", sortable=True))) - client.ft().add_document("doc1", txt="foo bar", num=1) - client.ft().add_document("doc2", txt="foo baz", num=2) - client.ft().add_document("doc3", txt="foo qux", num=3) - - # Test sort - q1 = Query("foo").sort_by("num", asc=True).no_content() - q2 = Query("foo").sort_by("num", asc=False).no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_drop_index(): - """ - Ensure the index gets dropped by data remains by default - """ - for x in range(20): - for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: - idx = "HaveIt" - index = getClient() - index.hset("index:haveit", mapping={"name": "haveit"}) - idef = IndexDefinition(prefix=["index:"]) - index.ft(idx).create_index((TextField("name"),), definition=idef) - waitForIndex(index, idx) - index.ft(idx).dropindex(delete_documents=keep_docs[0]) - i = index.hgetall("index:haveit") - assert i == keep_docs[1] - - -@pytest.mark.redismod -def test_example(client): - # Creating the index definition and schema - client.ft().create_index((TextField("title", weight=5.0), TextField("body"))) - - # Indexing a document - client.ft().add_document( - "doc1", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", - ) - - # Searching with complex parameters: - q = Query("search engine").verbatim().no_content().paging(0, 5) - - res = client.ft().search(q) - assert res is not None - - -@pytest.mark.redismod -def test_auto_complete(client): - n = 0 - with open(TITLES_CSV) as f: - cr = csv.reader(f) - - for row in cr: - n += 1 - term, score = row[0], float(row[1]) - assert n == client.ft().sugadd("ac", Suggestion(term, score=score)) - - assert n == client.ft().suglen("ac") - ret = client.ft().sugget("ac", "bad", with_scores=True) - assert 2 == len(ret) - assert "badger" == ret[0].string - assert isinstance(ret[0].score, float) - assert 1.0 != ret[0].score - assert "badalte rishtey" == ret[1].string - assert isinstance(ret[1].score, float) - assert 1.0 != ret[1].score - - ret = client.ft().sugget("ac", "bad", fuzzy=True, num=10) - assert 10 == len(ret) - assert 1.0 == ret[0].score - strs = {x.string for x in ret} - - for sug in strs: - assert 1 == client.ft().sugdel("ac", sug) - # make sure a second delete returns 0 - for sug in strs: - assert 0 == client.ft().sugdel("ac", sug) - - # make sure they were actually deleted - ret2 = client.ft().sugget("ac", "bad", fuzzy=True, num=10) - for sug in ret2: - assert sug.string not in strs - - # Test with payload - client.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) - client.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) - client.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) - - sugs = client.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) - assert 3 == len(sugs) - for sug in sugs: - assert sug.payload - assert sug.payload.startswith("pl") - - -@pytest.mark.redismod -def test_no_index(client): - client.ft().create_index( - ( - TextField("field"), - TextField("text", no_index=True, sortable=True), - NumericField("numeric", no_index=True, sortable=True), - GeoField("geo", no_index=True, sortable=True), - TagField("tag", no_index=True, sortable=True), - ) - ) - - client.ft().add_document( - "doc1", field="aaa", text="1", numeric="1", geo="1,1", tag="1" - ) - client.ft().add_document( - "doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2" - ) - waitForIndex(client, "idx") - - res = client.ft().search(Query("@text:aa*")) - assert 0 == res.total - - res = client.ft().search(Query("@field:aa*")) - assert 2 == res.total - - res = client.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id - - res = client.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id - - res = client.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id - - res = client.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id - - res = client.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id - - # Ensure exception is raised for non-indexable, non-sortable fields - with pytest.raises(Exception): - TextField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - NumericField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - GeoField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - TagField("name", no_index=True, sortable=False) - - -@pytest.mark.redismod -def test_partial(client): - client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) - client.ft().add_document("doc1", f1="f1_val", f2="f2_val") - client.ft().add_document("doc2", f1="f1_val", f2="f2_val") - client.ft().add_document("doc1", f3="f3_val", partial=True) - client.ft().add_document("doc2", f3="f3_val", replace=True) - waitForIndex(client, "idx") - - # Search for f3 value. All documents should have it - res = client.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 values - res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - -@pytest.mark.redismod -def test_no_create(client): - client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) - client.ft().add_document("doc1", f1="f1_val", f2="f2_val") - client.ft().add_document("doc2", f1="f1_val", f2="f2_val") - client.ft().add_document("doc1", f3="f3_val", no_create=True) - client.ft().add_document("doc2", f3="f3_val", no_create=True, partial=True) - waitForIndex(client, "idx") - - # Search for f3 value. All documents should have it - res = client.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 values - res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - with pytest.raises(redis.ResponseError): - client.ft().add_document("doc3", f2="f2_val", f3="f3_val", no_create=True) - - -@pytest.mark.redismod -def test_explain(client): - client.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) - res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") - assert res - - -@pytest.mark.redismod -def test_explaincli(client): - with pytest.raises(NotImplementedError): - client.ft().explain_cli("foo") - - -@pytest.mark.redismod -def test_summarize(client): - createIndex(client.ft()) - waitForIndex(client, "idx") - - q = Query("king henry").paging(0, 1) - q.highlight(fields=("play", "txt"), tags=("", "")) - q.summarize("txt") - - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) - - q = Query("king henry").paging(0, 1).summarize().highlight() - - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_alias(): - index1 = getClient() - index2 = getClient() - - def1 = IndexDefinition(prefix=["index1:"]) - def2 = IndexDefinition(prefix=["index2:"]) - - ftindex1 = index1.ft("testAlias") - ftindex2 = index2.ft("testAlias2") - ftindex1.create_index((TextField("name"),), definition=def1) - ftindex2.create_index((TextField("name"),), definition=def2) - - index1.hset("index1:lonestar", mapping={"name": "lonestar"}) - index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - - res = ftindex1.search("*").docs[0] - assert "index1:lonestar" == res.id - - # create alias and check for results - ftindex1.aliasadd("spaceballs") - alias_client = getClient().ft("spaceballs") - res = alias_client.search("*").docs[0] - assert "index1:lonestar" == res.id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - ftindex2.aliasadd("spaceballs") - - # update alias and ensure new results - ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient().ft("spaceballs") - - res = alias_client2.search("*").docs[0] - assert "index2:yogurt" == res.id - - ftindex2.aliasdel("spaceballs") - with pytest.raises(Exception): - alias_client2.search("*").docs[0] - - -@pytest.mark.redismod -def test_alias_basic(): - # Creating a client with one index - getClient().flushdb() - index1 = getClient().ft("testAlias") - - index1.create_index((TextField("txt"),)) - index1.add_document("doc1", txt="text goes here") - - index2 = getClient().ft("testAlias2") - index2.create_index((TextField("txt"),)) - index2.add_document("doc2", txt="text goes here") - - # add the actual alias and check - index1.aliasadd("myalias") - alias_client = getClient().ft("myalias") - res = sorted(alias_client.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - index2.aliasadd("myalias") - - # update the alias and ensure we get doc2 - index2.aliasupdate("myalias") - alias_client2 = getClient().ft("myalias") - res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # delete the alias and expect an error if we try to query again - index2.aliasdel("myalias") - with pytest.raises(Exception): - _ = alias_client2.search("*").docs[0] - - -@pytest.mark.redismod -def test_tags(client): - client.ft().create_index((TextField("txt"), TagField("tags"))) - tags = "foo,foo bar,hello;world" - tags2 = "soba,ramen" - - client.ft().add_document("doc1", txt="fooz barz", tags=tags) - client.ft().add_document("doc2", txt="noodles", tags=tags2) - waitForIndex(client, "idx") - - q = Query("@tags:{foo}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{foo bar}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{foo\\ bar}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{hello\\;world}") - res = client.ft().search(q) - assert 1 == res.total - - q2 = client.ft().tagvals("tags") - assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() - - -@pytest.mark.redismod -def test_textfield_sortable_nostem(client): - # Creating the index definition with sortable and no_stem - client.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) - - # Now get the index info to confirm its contents - response = client.ft().info() - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] - - -@pytest.mark.redismod -def test_alter_schema_add(client): - # Creating the index definition and schema - client.ft().create_index(TextField("title")) - - # Using alter to add a field - client.ft().alter_schema_add(TextField("body")) - - # Indexing a document - client.ft().add_document( - "doc1", title="MyTitle", body="Some content only in the body" - ) - - # Searching with parameter only in the body (the added field) - q = Query("only in the body") - - # Ensure we find the result searching on the added body field - res = client.ft().search(q) - assert 1 == res.total - - -@pytest.mark.redismod -def test_spell_check(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - - client.ft().add_document("doc1", f1="some valid content", f2="this is sample text") - client.ft().add_document("doc2", f1="very important", f2="lorem ipsum") - waitForIndex(client, "idx") - - # test spellcheck - res = client.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = client.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = client.ft().spellcheck("vlis") - assert res == {} - res = client.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - client.ft().dict_add("dict", "lore", "lorem", "lorm") - res = client.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = client.ft().spellcheck("lorm", exclude="dict") - assert res == {} - - -@pytest.mark.redismod -def test_dict_operations(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - # Add three items - res = client.ft().dict_add("custom_dict", "item1", "item2", "item3") - assert 3 == res - - # Remove one item - res = client.ft().dict_del("custom_dict", "item2") - assert 1 == res - - # Dump dict and inspect content - res = client.ft().dict_dump("custom_dict") - assert ["item1", "item3"] == res - - # Remove rest of the items before reload - client.ft().dict_del("custom_dict", *res) - - -@pytest.mark.redismod -def test_phonetic_matcher(client): - client.ft().create_index((TextField("name"),)) - client.ft().add_document("doc1", name="Jon") - client.ft().add_document("doc2", name="John") - - res = client.ft().search(Query("Jon")) - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name - - # Drop and create index with phonetic matcher - client.flushdb() - - client.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) - client.ft().add_document("doc1", name="Jon") - client.ft().add_document("doc2", name="John") - - res = client.ft().search(Query("Jon")) - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted(d.name for d in res.docs) - - -@pytest.mark.redismod -def test_scorer(client): - client.ft().create_index((TextField("description"),)) - - client.ft().add_document( - "doc1", description="The quick brown fox jumps over the lazy dog" - ) - client.ft().add_document( - "doc2", - description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa - ) - - # default scorer is TFIDF - res = client.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.1111111111111111 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score - - -@pytest.mark.redismod -def test_get(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - - assert [None] == client.ft().get("doc1") - assert [None, None] == client.ft().get("doc2", "doc1") - - client.ft().add_document( - "doc1", f1="some valid content dd1", f2="this is sample text ff1" - ) - client.ft().add_document( - "doc2", f1="some valid content dd2", f2="this is sample text ff2" - ) - - assert [ - ["f1", "some valid content dd2", "f2", "this is sample text ff2"] - ] == client.ft().get("doc2") - assert [ - ["f1", "some valid content dd1", "f2", "this is sample text ff1"], - ["f1", "some valid content dd2", "f2", "this is sample text ff2"], - ] == client.ft().get("doc1", "doc2") - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_config(client): - assert client.ft().config_set("TIMEOUT", "100") - with pytest.raises(redis.ResponseError): - client.ft().config_set("TIMEOUT", "null") - res = client.ft().config_get("*") - assert "100" == res["TIMEOUT"] - res = client.ft().config_get("TIMEOUT") - assert "100" == res["TIMEOUT"] - - -@pytest.mark.redismod -def test_aggregations_groupby(client): - # Creating the index definition and schema - client.ft().create_index( - ( - NumericField("random_num"), - TextField("title"), - TextField("body"), - TextField("parent"), - ) - ) - - # Indexing a document - client.ft().add_document( - "search", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", - parent="redis", - random_num=10, - ) - client.ft().add_document( - "ai", - title="RedisAI", - body="RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa - parent="redis", - random_num=3, - ) - client.ft().add_document( - "json", - title="RedisJson", - body="RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa - parent="redis", - random_num=8, - ) - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.count(), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.count_distinct("@title"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.count_distinctish("@title"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.sum("@random_num"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.min("@random_num"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.max("@random_num"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.avg("@random_num"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "7" # (10+3+8)/3 - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.stddev("random_num"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.quantile("@random_num", 0.5), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "10" - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.tolist("@title"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == ["RediSearch", "RedisAI", "RedisJson"] - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.first_value("@title").alias("first"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.random_sample("@title", 2).alias("random"), - ) - - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] - - -@pytest.mark.redismod -def test_aggregations_sort_by_and_limit(client): - client.ft().create_index( - ( - TextField("t1"), - TextField("t2"), - ) - ) - - client.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) - client.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) - - # test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") - ) - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "a", "t1", "b"] - assert res.rows[1] == ["t2", "b", "t1", "a"] - - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "a"] - assert res.rows[1] == ["t1", "b"] - - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 - - # test limit - req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["t1", "b"] - - -@pytest.mark.redismod -def test_aggregations_load(client): - client.ft().create_index( - ( - TextField("t1"), - TextField("t2"), - ) - ) - - client.ft().client.hset("doc1", mapping={"t1": "hello", "t2": "world"}) - - # load t1 - req = aggregations.AggregateRequest("*").load("t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello"] - - # load t2 - req = aggregations.AggregateRequest("*").load("t2") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "world"] - - # load all - req = aggregations.AggregateRequest("*").load() - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello", "t2", "world"] - - -@pytest.mark.redismod -def test_aggregations_apply(client): - client.ft().create_index( - ( - TextField("PrimaryKey", sortable=True), - NumericField("CreatedDateTimeUTC", sortable=True), - ) - ) - - client.ft().client.hset( - "doc1", - mapping={"PrimaryKey": "9::362330", "CreatedDateTimeUTC": "637387878524969984"}, - ) - client.ft().client.hset( - "doc2", - mapping={"PrimaryKey": "9::362329", "CreatedDateTimeUTC": "637387875859270016"}, - ) - - req = aggregations.AggregateRequest("*").apply( - CreatedDateTimeUTC="@CreatedDateTimeUTC * 10" - ) - res = client.ft().aggregate(req) - assert res.rows[0] == ["CreatedDateTimeUTC", "6373878785249699840"] - assert res.rows[1] == ["CreatedDateTimeUTC", "6373878758592700416"] - - -@pytest.mark.redismod -def test_aggregations_filter(client): - client.ft().create_index( - ( - TextField("name", sortable=True), - NumericField("age", sortable=True), - ) - ) - - client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) - client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) - - req = aggregations.AggregateRequest("*").filter("@name=='foo' && @age < 20") - res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["name", "foo", "age", "19"] - - req = aggregations.AggregateRequest("*").filter("@age > 15").sort_by("@age") - res = client.ft().aggregate(req) - assert len(res.rows) == 2 - assert res.rows[0] == ["age", "19"] - assert res.rows[1] == ["age", "25"] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_index_definition(client): - """ - Create definition and test its args - """ - with pytest.raises(RuntimeError): - IndexDefinition(prefix=["hset:", "henry"], index_type="json") - - definition = IndexDefinition( - prefix=["hset:", "henry"], - filter="@f1==32", - language="English", - language_field="play", - score_field="chapter", - score=0.5, - payload_field="txt", - index_type=IndexType.JSON, - ) - - assert [ - "ON", - "JSON", - "PREFIX", - 2, - "hset:", - "henry", - "FILTER", - "@f1==32", - "LANGUAGE_FIELD", - "play", - "LANGUAGE", - "English", - "SCORE_FIELD", - "chapter", - "SCORE", - 0.5, - "PAYLOAD_FIELD", - "txt", - ] == definition.args - - createIndex(client.ft(), num_docs=500, definition=definition) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_create_client_definition(client): - """ - Create definition with no index type provided, - and use hset to test the client definition (the default is HASH). - """ - definition = IndexDefinition(prefix=["hset:", "henry"]) - createIndex(client.ft(), num_docs=500, definition=definition) - - info = client.ft().info() - assert 494 == int(info["num_docs"]) - - client.ft().client.hset("hset:1", "f1", "v1") - info = client.ft().info() - assert 495 == int(info["num_docs"]) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_create_client_definition_hash(client): - """ - Create definition with IndexType.HASH as index type (ON HASH), - and use hset to test the client definition. - """ - definition = IndexDefinition(prefix=["hset:", "henry"], index_type=IndexType.HASH) - createIndex(client.ft(), num_docs=500, definition=definition) - - info = client.ft().info() - assert 494 == int(info["num_docs"]) - - client.ft().client.hset("hset:1", "f1", "v1") - info = client.ft().info() - assert 495 == int(info["num_docs"]) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_create_client_definition_json(client): - """ - Create definition with IndexType.JSON as index type (ON JSON), - and use json client to test it. - """ - definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) - client.ft().create_index((TextField("$.name"),), definition=definition) - - client.json().set("king:1", Path.rootPath(), {"name": "henry"}) - client.json().set("king:2", Path.rootPath(), {"name": "james"}) - - res = client.ft().search("henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].payload is None - assert res.docs[0].json == '{"name":"henry"}' - assert res.total == 1 - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_fields_as_name(client): - # create index - SCHEMA = ( - TextField("$.name", sortable=True, as_name="name"), - NumericField("$.age", as_name="just_a_number"), - ) - definition = IndexDefinition(index_type=IndexType.JSON) - client.ft().create_index(SCHEMA, definition=definition) - - # insert json data - res = client.json().set("doc:1", Path.rootPath(), {"name": "Jon", "age": 25}) - assert res - - total = client.ft().search(Query("Jon").return_fields("name", "just_a_number")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "Jon" == total[0].name - assert "25" == total[0].just_a_number - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_search_return_fields(client): - res = client.json().set( - "doc:1", - Path.rootPath(), - {"t": "riceratops", "t2": "telmatosaurus", "n": 9072, "flt": 97.2}, - ) - assert res - - # create index on - definition = IndexDefinition(index_type=IndexType.JSON) - SCHEMA = ( - TextField("$.t"), - NumericField("$.flt"), - ) - client.ft().create_index(SCHEMA, definition=definition) - waitForIndex(client, "idx") - - total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "riceratops" == total[0].txt - - total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "telmatosaurus" == total[0].txt - - -@pytest.mark.redismod -def test_synupdate(client): - definition = IndexDefinition(index_type=IndexType.HASH) - client.ft().create_index( - ( - TextField("title"), - TextField("body"), - ), - definition=definition, - ) - - client.ft().synupdate("id1", True, "boy", "child", "offspring") - client.ft().add_document("doc1", title="he is a baby", body="this is a test") - - client.ft().synupdate("id1", True, "baby") - client.ft().add_document("doc2", title="he is another baby", body="another test") - - res = client.ft().search(Query("child").expander("SYNONYM")) - assert res.docs[0].id == "doc2" - assert res.docs[0].title == "he is another baby" - assert res.docs[0].body == "another test" - - -@pytest.mark.redismod -def test_syndump(client): - definition = IndexDefinition(index_type=IndexType.HASH) - client.ft().create_index( - ( - TextField("title"), - TextField("body"), - ), - definition=definition, - ) - - client.ft().synupdate("id1", False, "boy", "child", "offspring") - client.ft().synupdate("id2", False, "baby", "child") - client.ft().synupdate("id3", False, "tree", "wood") - res = client.ft().syndump() - assert res == { - "boy": ["id1"], - "tree": ["id3"], - "wood": ["id3"], - "child": ["id1", "id2"], - "baby": ["id2"], - "offspring": ["id1"], - } - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_create_json_with_alias(client): - """ - Create definition with IndexType.JSON as index type (ON JSON) with two - fields with aliases, and use json client to test it. - """ - definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) - client.ft().create_index( - (TextField("$.name", as_name="name"), NumericField("$.num", as_name="num")), - definition=definition, - ) - - client.json().set("king:1", Path.rootPath(), {"name": "henry", "num": 42}) - client.json().set("king:2", Path.rootPath(), {"name": "james", "num": 3.14}) - - res = client.ft().search("@name:henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","num":42}' - assert res.total == 1 - - res = client.ft().search("@num:[0 10]") - assert res.docs[0].id == "king:2" - assert res.docs[0].json == '{"name":"james","num":3.14}' - assert res.total == 1 - - # Tests returns an error if path contain special characters (user should - # use an alias) - with pytest.raises(Exception): - client.ft().search("@$.name:henry") - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_json_with_multipath(client): - """ - Create definition with IndexType.JSON as index type (ON JSON), - and use json client to test it. - """ - definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) - client.ft().create_index( - (TagField("$..name", as_name="name")), definition=definition - ) - - client.json().set( - "king:1", Path.rootPath(), {"name": "henry", "country": {"name": "england"}} - ) - - res = client.ft().search("@name:{henry}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 - - res = client.ft().search("@name:{england}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_json_with_jsonpath(client): - definition = IndexDefinition(index_type=IndexType.JSON) - client.ft().create_index( - ( - TextField('$["prod:name"]', as_name="name"), - TextField("$.prod:name", as_name="name_unsupported"), - ), - definition=definition, - ) - - client.json().set("doc:1", Path.rootPath(), {"prod:name": "RediSearch"}) - - # query for a supported field succeeds - res = client.ft().search(Query("@name:RediSearch")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].json == '{"prod:name":"RediSearch"}' - - # query for an unsupported field fails - res = client.ft().search("@name_unsupported:RediSearch") - assert res.total == 0 - - # return of a supported field succeeds - res = client.ft().search(Query("@name:RediSearch").return_field("name")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].name == "RediSearch" - - # return of an unsupported field fails - res = client.ft().search(Query("@name:RediSearch").return_field("name_unsupported")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - with pytest.raises(Exception): - res.docs[0].name_unsupported - - -@pytest.mark.redismod -def test_profile(client): - client.ft().create_index((TextField("t"),)) - client.ft().client.hset("1", "t", "hello") - client.ft().client.hset("2", "t", "world") - - # check using Query - q = Query("hello|world").no_content() - res, det = client.ft().profile(q) - assert det["Iterators profile"]["Counter"] == 2.0 - assert len(det["Iterators profile"]["Child iterators"]) == 2 - assert det["Iterators profile"]["Type"] == "UNION" - assert det["Parsing time"] < 0.5 - assert len(res.docs) == 2 # check also the search result - - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") - ) - res, det = client.ft().profile(req) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "WILDCARD" - assert det["Parsing time"] < 0.5 - assert len(res.rows) == 2 # check also the search result - - -@pytest.mark.redismod -def test_profile_limited(client): - client.ft().create_index((TextField("t"),)) - client.ft().client.hset("1", "t", "hello") - client.ft().client.hset("2", "t", "hell") - client.ft().client.hset("3", "t", "help") - client.ft().client.hset("4", "t", "helowa") - - q = Query("%hell% hel*") - res, det = client.ft().profile(q, limited=True) - assert ( - det["Iterators profile"]["Child iterators"][0]["Child iterators"] - == "The number of iterators in the union is 3" - ) - assert ( - det["Iterators profile"]["Child iterators"][1]["Child iterators"] - == "The number of iterators in the union is 4" - ) - assert det["Iterators profile"]["Type"] == "INTERSECT" - assert len(res.docs) == 3 # check also the search result diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py deleted file mode 100644 index 0357443..0000000 --- a/tests/test_sentinel.py +++ /dev/null @@ -1,234 +0,0 @@ -import socket - -import pytest - -import redis.sentinel -from redis import exceptions -from redis.sentinel import ( - MasterNotFoundError, - Sentinel, - SentinelConnectionPool, - SlaveNotFoundError, -) - - -@pytest.fixture(scope="module") -def master_ip(master_host): - yield socket.gethostbyname(master_host[0]) - - -class SentinelTestClient: - def __init__(self, cluster, id): - self.cluster = cluster - self.id = id - - def sentinel_masters(self): - self.cluster.connection_error_if_down(self) - self.cluster.timeout_if_down(self) - return {self.cluster.service_name: self.cluster.master} - - def sentinel_slaves(self, master_name): - self.cluster.connection_error_if_down(self) - self.cluster.timeout_if_down(self) - if master_name != self.cluster.service_name: - return [] - return self.cluster.slaves - - def execute_command(self, *args, **kwargs): - # wrapper purely to validate the calls don't explode - from redis.client import bool_ok - - return bool_ok - - -class SentinelTestCluster: - def __init__(self, servisentinel_ce_name="mymaster", ip="127.0.0.1", port=6379): - self.clients = {} - self.master = { - "ip": ip, - "port": port, - "is_master": True, - "is_sdown": False, - "is_odown": False, - "num-other-sentinels": 0, - } - self.service_name = servisentinel_ce_name - self.slaves = [] - self.nodes_down = set() - self.nodes_timeout = set() - - def connection_error_if_down(self, node): - if node.id in self.nodes_down: - raise exceptions.ConnectionError - - def timeout_if_down(self, node): - if node.id in self.nodes_timeout: - raise exceptions.TimeoutError - - def client(self, host, port, **kwargs): - return SentinelTestClient(self, (host, port)) - - -@pytest.fixture() -def cluster(request, master_ip): - def teardown(): - redis.sentinel.Redis = saved_Redis - - cluster = SentinelTestCluster(ip=master_ip) - saved_Redis = redis.sentinel.Redis - redis.sentinel.Redis = cluster.client - request.addfinalizer(teardown) - return cluster - - -@pytest.fixture() -def sentinel(request, cluster): - return Sentinel([("foo", 26379), ("bar", 26379)]) - - -@pytest.mark.onlynoncluster -def test_discover_master(sentinel, master_ip): - address = sentinel.discover_master("mymaster") - assert address == (master_ip, 6379) - - -@pytest.mark.onlynoncluster -def test_discover_master_error(sentinel): - with pytest.raises(MasterNotFoundError): - sentinel.discover_master("xxx") - - -@pytest.mark.onlynoncluster -def test_discover_master_sentinel_down(cluster, sentinel, master_ip): - # Put first sentinel 'foo' down - cluster.nodes_down.add(("foo", 26379)) - address = sentinel.discover_master("mymaster") - assert address == (master_ip, 6379) - # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ("bar", 26379) - - -@pytest.mark.onlynoncluster -def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip): - # Put first sentinel 'foo' down - cluster.nodes_timeout.add(("foo", 26379)) - address = sentinel.discover_master("mymaster") - assert address == (master_ip, 6379) - # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ("bar", 26379) - - -@pytest.mark.onlynoncluster -def test_master_min_other_sentinels(cluster, master_ip): - sentinel = Sentinel([("foo", 26379)], min_other_sentinels=1) - # min_other_sentinels - with pytest.raises(MasterNotFoundError): - sentinel.discover_master("mymaster") - cluster.master["num-other-sentinels"] = 2 - address = sentinel.discover_master("mymaster") - assert address == (master_ip, 6379) - - -@pytest.mark.onlynoncluster -def test_master_odown(cluster, sentinel): - cluster.master["is_odown"] = True - with pytest.raises(MasterNotFoundError): - sentinel.discover_master("mymaster") - - -@pytest.mark.onlynoncluster -def test_master_sdown(cluster, sentinel): - cluster.master["is_sdown"] = True - with pytest.raises(MasterNotFoundError): - sentinel.discover_master("mymaster") - - -@pytest.mark.onlynoncluster -def test_discover_slaves(cluster, sentinel): - assert sentinel.discover_slaves("mymaster") == [] - - cluster.slaves = [ - {"ip": "slave0", "port": 1234, "is_odown": False, "is_sdown": False}, - {"ip": "slave1", "port": 1234, "is_odown": False, "is_sdown": False}, - ] - assert sentinel.discover_slaves("mymaster") == [("slave0", 1234), ("slave1", 1234)] - - # slave0 -> ODOWN - cluster.slaves[0]["is_odown"] = True - assert sentinel.discover_slaves("mymaster") == [("slave1", 1234)] - - # slave1 -> SDOWN - cluster.slaves[1]["is_sdown"] = True - assert sentinel.discover_slaves("mymaster") == [] - - cluster.slaves[0]["is_odown"] = False - cluster.slaves[1]["is_sdown"] = False - - # node0 -> DOWN - cluster.nodes_down.add(("foo", 26379)) - assert sentinel.discover_slaves("mymaster") == [("slave0", 1234), ("slave1", 1234)] - cluster.nodes_down.clear() - - # node0 -> TIMEOUT - cluster.nodes_timeout.add(("foo", 26379)) - assert sentinel.discover_slaves("mymaster") == [("slave0", 1234), ("slave1", 1234)] - - -@pytest.mark.onlynoncluster -def test_master_for(cluster, sentinel, master_ip): - master = sentinel.master_for("mymaster", db=9) - assert master.ping() - assert master.connection_pool.master_address == (master_ip, 6379) - - # Use internal connection check - master = sentinel.master_for("mymaster", db=9, check_connection=True) - assert master.ping() - - -@pytest.mark.onlynoncluster -def test_slave_for(cluster, sentinel): - cluster.slaves = [ - {"ip": "127.0.0.1", "port": 6379, "is_odown": False, "is_sdown": False}, - ] - slave = sentinel.slave_for("mymaster", db=9) - assert slave.ping() - - -@pytest.mark.onlynoncluster -def test_slave_for_slave_not_found_error(cluster, sentinel): - cluster.master["is_odown"] = True - slave = sentinel.slave_for("mymaster", db=9) - with pytest.raises(SlaveNotFoundError): - slave.ping() - - -@pytest.mark.onlynoncluster -def test_slave_round_robin(cluster, sentinel, master_ip): - cluster.slaves = [ - {"ip": "slave0", "port": 6379, "is_odown": False, "is_sdown": False}, - {"ip": "slave1", "port": 6379, "is_odown": False, "is_sdown": False}, - ] - pool = SentinelConnectionPool("mymaster", sentinel) - rotator = pool.rotate_slaves() - assert next(rotator) in (("slave0", 6379), ("slave1", 6379)) - assert next(rotator) in (("slave0", 6379), ("slave1", 6379)) - # Fallback to master - assert next(rotator) == (master_ip, 6379) - with pytest.raises(SlaveNotFoundError): - next(rotator) - - -@pytest.mark.onlynoncluster -def test_ckquorum(cluster, sentinel): - assert sentinel.sentinel_ckquorum("mymaster") - - -@pytest.mark.onlynoncluster -def test_flushconfig(cluster, sentinel): - assert sentinel.sentinel_flushconfig() - - -@pytest.mark.onlynoncluster -def test_reset(cluster, sentinel): - cluster.master["is_odown"] = True - assert sentinel.sentinel_reset("mymaster") diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py deleted file mode 100644 index 8c97ab8..0000000 --- a/tests/test_timeseries.py +++ /dev/null @@ -1,514 +0,0 @@ -import time -from time import sleep - -import pytest - -from .conftest import skip_ifmodversion_lt - - -@pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient - - -@pytest.mark.redismod -def test_create(client): - assert client.ts().create(1) - assert client.ts().create(2, retention_msecs=5) - assert client.ts().create(3, labels={"Redis": "Labs"}) - assert client.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) - info = client.ts().info(4) - assert 20 == info.retention_msecs - assert "Series" == info.labels["Time"] - - # Test for a chunk size of 128 Bytes - assert client.ts().create("time-serie-1", chunk_size=128) - info = client.ts().info("time-serie-1") - assert 128, info.chunk_size - - -@pytest.mark.redismod -@skip_ifmodversion_lt("1.4.0", "timeseries") -def test_create_duplicate_policy(client): - # Test for duplicate policy - for duplicate_policy in ["block", "last", "first", "min", "max"]: - ts_name = f"time-serie-ooo-{duplicate_policy}" - assert client.ts().create(ts_name, duplicate_policy=duplicate_policy) - info = client.ts().info(ts_name) - assert duplicate_policy == info.duplicate_policy - - -@pytest.mark.redismod -def test_alter(client): - assert client.ts().create(1) - assert 0 == client.ts().info(1).retention_msecs - assert client.ts().alter(1, retention_msecs=10) - assert {} == client.ts().info(1).labels - assert 10, client.ts().info(1).retention_msecs - assert client.ts().alter(1, labels={"Time": "Series"}) - assert "Series" == client.ts().info(1).labels["Time"] - assert 10 == client.ts().info(1).retention_msecs - - -@pytest.mark.redismod -@skip_ifmodversion_lt("1.4.0", "timeseries") -def test_alter_diplicate_policy(client): - assert client.ts().create(1) - info = client.ts().info(1) - assert info.duplicate_policy is None - assert client.ts().alter(1, duplicate_policy="min") - info = client.ts().info(1) - assert "min" == info.duplicate_policy - - -@pytest.mark.redismod -def test_add(client): - assert 1 == client.ts().add(1, 1, 1) - assert 2 == client.ts().add(2, 2, 3, retention_msecs=10) - assert 3 == client.ts().add(3, 3, 2, labels={"Redis": "Labs"}) - assert 4 == client.ts().add( - 4, 4, 2, retention_msecs=10, labels={"Redis": "Labs", "Time": "Series"} - ) - assert round(time.time()) == round(float(client.ts().add(5, "*", 1)) / 1000) - - info = client.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] - - # Test for a chunk size of 128 Bytes on TS.ADD - assert client.ts().add("time-serie-1", 1, 10.0, chunk_size=128) - info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size - - -@pytest.mark.redismod -@skip_ifmodversion_lt("1.4.0", "timeseries") -def test_add_duplicate_policy(client): - - # Test for duplicate policy BLOCK - assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) - with pytest.raises(Exception): - client.ts().add("time-serie-add-ooo-block", 1, 5.0, duplicate_policy="block") - - # Test for duplicate policy LAST - assert 1 == client.ts().add("time-serie-add-ooo-last", 1, 5.0) - assert 1 == client.ts().add( - "time-serie-add-ooo-last", 1, 10.0, duplicate_policy="last" - ) - assert 10.0 == client.ts().get("time-serie-add-ooo-last")[1] - - # Test for duplicate policy FIRST - assert 1 == client.ts().add("time-serie-add-ooo-first", 1, 5.0) - assert 1 == client.ts().add( - "time-serie-add-ooo-first", 1, 10.0, duplicate_policy="first" - ) - assert 5.0 == client.ts().get("time-serie-add-ooo-first")[1] - - # Test for duplicate policy MAX - assert 1 == client.ts().add("time-serie-add-ooo-max", 1, 5.0) - assert 1 == client.ts().add( - "time-serie-add-ooo-max", 1, 10.0, duplicate_policy="max" - ) - assert 10.0 == client.ts().get("time-serie-add-ooo-max")[1] - - # Test for duplicate policy MIN - assert 1 == client.ts().add("time-serie-add-ooo-min", 1, 5.0) - assert 1 == client.ts().add( - "time-serie-add-ooo-min", 1, 10.0, duplicate_policy="min" - ) - assert 5.0 == client.ts().get("time-serie-add-ooo-min")[1] - - -@pytest.mark.redismod -def test_madd(client): - client.ts().create("a") - assert [1, 2, 3] == client.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)]) - - -@pytest.mark.redismod -def test_incrby_decrby(client): - for _ in range(100): - assert client.ts().incrby(1, 1) - sleep(0.001) - assert 100 == client.ts().get(1)[1] - for _ in range(100): - assert client.ts().decrby(1, 1) - sleep(0.001) - assert 0 == client.ts().get(1)[1] - - assert client.ts().incrby(2, 1.5, timestamp=5) - assert (5, 1.5) == client.ts().get(2) - assert client.ts().incrby(2, 2.25, timestamp=7) - assert (7, 3.75) == client.ts().get(2) - assert client.ts().decrby(2, 1.5, timestamp=15) - assert (15, 2.25) == client.ts().get(2) - - # Test for a chunk size of 128 Bytes on TS.INCRBY - assert client.ts().incrby("time-serie-1", 10, chunk_size=128) - info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size - - # Test for a chunk size of 128 Bytes on TS.DECRBY - assert client.ts().decrby("time-serie-2", 10, chunk_size=128) - info = client.ts().info("time-serie-2") - assert 128 == info.chunk_size - - -@pytest.mark.redismod -def test_create_and_delete_rule(client): - # test rule creation - time = 100 - client.ts().create(1) - client.ts().create(2) - client.ts().createrule(1, 2, "avg", 100) - for i in range(50): - client.ts().add(1, time + i * 2, 1) - client.ts().add(1, time + i * 2 + 1, 2) - client.ts().add(1, time * 2, 1.5) - assert round(client.ts().get(2)[1], 5) == 1.5 - info = client.ts().info(1) - assert info.rules[0][1] == 100 - - # test rule deletion - client.ts().deleterule(1, 2) - info = client.ts().info(1) - assert not info.rules - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def test_del_range(client): - try: - client.ts().delete("test", 0, 100) - except Exception as e: - assert e.__str__() != "" - - for i in range(100): - client.ts().add(1, i, i % 7) - assert 22 == client.ts().delete(1, 0, 21) - assert [] == client.ts().range(1, 0, 21) - assert [(22, 1.0)] == client.ts().range(1, 22, 22) - - -@pytest.mark.redismod -def test_range(client): - for i in range(100): - client.ts().add(1, i, i % 7) - assert 100 == len(client.ts().range(1, 0, 200)) - for i in range(100): - client.ts().add(1, i + 200, i % 7) - assert 200 == len(client.ts().range(1, 0, 500)) - # last sample isn't returned - assert 20 == len( - client.ts().range(1, 0, 500, aggregation_type="avg", bucket_size_msec=10) - ) - assert 10 == len(client.ts().range(1, 0, 500, count=10)) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def test_range_advanced(client): - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(1, i + 200, i % 7) - - assert 2 == len( - client.ts().range( - 1, - 0, - 500, - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - ) - assert [(0, 10.0), (10, 1.0)] == client.ts().range( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" - ) - assert [(-5, 5.0), (5, 6.0)] == client.ts().range( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 - ) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def test_rev_range(client): - for i in range(100): - client.ts().add(1, i, i % 7) - assert 100 == len(client.ts().range(1, 0, 200)) - for i in range(100): - client.ts().add(1, i + 200, i % 7) - assert 200 == len(client.ts().range(1, 0, 500)) - # first sample isn't returned - assert 20 == len( - client.ts().revrange(1, 0, 500, aggregation_type="avg", bucket_size_msec=10) - ) - assert 10 == len(client.ts().revrange(1, 0, 500, count=10)) - assert 2 == len( - client.ts().revrange( - 1, - 0, - 500, - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - ) - assert [(10, 1.0), (0, 10.0)] == client.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" - ) - assert [(1, 10.0), (-9, 1.0)] == client.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 - ) - - -@pytest.mark.redismod -def testMultiRange(client): - client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(2, i, i % 11) - - res = client.ts().mrange(0, 200, filters=["Test=This"]) - assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) - - res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) - - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - - # test withlabels - assert {} == res[0]["1"][0] - res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def test_multi_range_advanced(client): - client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(2, i, i % 11) - - # test with selected labels - res = client.ts().mrange(0, 200, filters=["Test=This"], select_labels=["team"]) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] - - # test with filterby - res = client.ts().mrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] - - # test groupby - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="sum") - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] - - # test align - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(-5, 5.0), (5, 6.0)] == res[0]["1"][1] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def test_multi_reverse_range(client): - client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(2, i, i % 11) - - res = client.ts().mrange(0, 200, filters=["Test=This"]) - assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) - - res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) - - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrevrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] - - # test withlabels - res = client.ts().mrevrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - - # test with selected labels - res = client.ts().mrevrange(0, 200, filters=["Test=This"], select_labels=["team"]) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] - - # test filterby - res = client.ts().mrevrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] - - # test groupby - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="team", reduce="min" - ) - assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] - - # test align - res = client.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] - res = client.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=1, - ) - assert [(1, 10.0), (-9, 1.0)] == res[0]["1"][1] - - -@pytest.mark.redismod -def test_get(client): - name = "test" - client.ts().create(name) - assert client.ts().get(name) is None - client.ts().add(name, 2, 3) - assert 2 == client.ts().get(name)[0] - client.ts().add(name, 3, 4) - assert 4 == client.ts().get(name)[1] - - -@pytest.mark.redismod -def test_mget(client): - client.ts().create(1, labels={"Test": "This"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That"}) - act_res = client.ts().mget(["Test=This"]) - exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res - client.ts().add(1, "*", 15) - client.ts().add(2, "*", 25) - res = client.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] - res = client.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] - - # test with_labels - assert {} == res[0]["2"][0] - res = client.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] - - -@pytest.mark.redismod -def test_info(client): - client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) - info = client.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" - - -@pytest.mark.redismod -@skip_ifmodversion_lt("1.4.0", "timeseries") -def testInfoDuplicatePolicy(client): - client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) - info = client.ts().info(1) - assert info.duplicate_policy is None - - client.ts().create("time-serie-2", duplicate_policy="min") - info = client.ts().info("time-serie-2") - assert "min" == info.duplicate_policy - - -@pytest.mark.redismod -def test_query_index(client): - client.ts().create(1, labels={"Test": "This"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That"}) - assert 2 == len(client.ts().queryindex(["Test=This"])) - assert 1 == len(client.ts().queryindex(["Taste=That"])) - assert [2] == client.ts().queryindex(["Taste=That"]) - - -@pytest.mark.redismod -def test_pipeline(client): - pipeline = client.ts().pipeline() - pipeline.create("with_pipeline") - for i in range(100): - pipeline.add("with_pipeline", i, 1.1 * i) - pipeline.execute() - - info = client.ts().info("with_pipeline") - assert info.lastTimeStamp == 99 - assert info.total_samples == 100 - assert client.ts().get("with_pipeline")[1] == 99 * 1.1 - - -@pytest.mark.redismod -def test_uncompressed(client): - client.ts().create("compressed") - client.ts().create("uncompressed", uncompressed=True) - compressed_info = client.ts().info("compressed") - uncompressed_info = client.ts().info("uncompressed") - assert compressed_info.memory_usage != uncompressed_info.memory_usage -- 2.34.1