# Source code for indra_cogex.client.neo4j_client

"""Neo4j client module."""

import inspect
import logging
from functools import lru_cache, wraps
from itertools import count
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
import json

import neo4j.graph
from indra.config import get_config
from indra.databases import identifiers
from indra.ontology.standardize import get_standard_agent
from indra.statements import Agent
from neo4j import GraphDatabase, Transaction, unit_of_work

from indra_cogex.representation import Node, Relation, norm_id, \
    triple_query, triple_parameter_query

# Public names exported by a star-import of this module.
__all__ = ["Neo4jClient", "autoclient"]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class Neo4jClient:
    """A client to communicate with an INDRA CogEx neo4j instance

    Parameters
    ----------
    url :
        The bolt URL to the neo4j instance to override INDRA_NEO4J_URL
        set as an environment variable or set in the INDRA config file.
    auth :
        A tuple consisting of the user name and password for the neo4j
        instance to override INDRA_NEO4J_USER and INDRA_NEO4J_PASSWORD set
        as environment variables or set in the INDRA config file.
    """

    #: The session
    session: Optional[neo4j.Session]

    def __init__(
        self,
        url: Optional[str] = None,
        auth: Optional[Tuple[str, str]] = None,
    ):
        """Initialize the Neo4j client."""
        self.driver = None
        self.session = None
        if not url:
            INDRA_NEO4J_URL = get_config("INDRA_NEO4J_URL")
            if INDRA_NEO4J_URL:
                url = INDRA_NEO4J_URL
                logger.info("Using configured URL for INDRA neo4j connection")
            else:
                logger.info("INDRA_NEO4J_URL not configured")
        if not auth:
            INDRA_NEO4J_USER = get_config("INDRA_NEO4J_USER")
            INDRA_NEO4J_PASSWORD = get_config("INDRA_NEO4J_PASSWORD")
            if INDRA_NEO4J_USER and INDRA_NEO4J_PASSWORD:
                auth = (INDRA_NEO4J_USER, INDRA_NEO4J_PASSWORD)
                logger.info("Using configured credentials for INDRA neo4j connection")
            else:
                logger.info("INDRA_NEO4J_USER and INDRA_NEO4J_PASSWORD not configured")
        # Set max_connection_lifetime to something smaller than the timeouts
        # on the server or on the way to the server. See
        # https://github.com/neo4j/neo4j-python-driver/issues/316#issuecomment-564020680
        self.driver = GraphDatabase.driver(
            url,
            auth=auth,
            max_connection_lifetime=3 * 60,
        )

    def __del__(self):
        # Safely shut down the driver as a Neo4jClient object is garbage collected
        # https://neo4j.com/docs/api/python-driver/current/api.html#driver-object-lifetime
        # getattr guards against instances whose __init__ never ran (or
        # failed before the attribute was assigned).
        driver = getattr(self, "driver", None)
        if driver is not None:
            driver.close()

    def create_tx(
        self,
        query: str,
        query_params: Optional[Mapping[str, Any]] = None,
    ):
        """Run a transaction which writes to the neo4j instance.

        Parameters
        ----------
        query :
            The query string to be executed.
        query_params :
            Parameters associated with the query. Each key becomes a
            ``$key`` placeholder value in the query. Keys must not be named
            ``tx`` or ``query`` since they are forwarded as keyword
            arguments to ``do_cypher_tx``.

        Returns
        -------
        :
            The value returned by ``do_cypher_tx`` for the query.
        """
        # The parameters have to be unpacked here: do_cypher_tx collects its
        # keyword arguments into the parameter dict, so passing the mapping
        # as a single ``query_params=`` keyword would nest it one level deep
        # and none of the ``$name`` placeholders would resolve.
        params = dict(query_params) if query_params else {}
        with self.driver.session() as session:
            return session.write_transaction(do_cypher_tx, query, **params)

    def query_dict(self, query: str, **query_params) -> Dict:
        """Run a read-only query that generates a dictionary.

        Each returned row is used as a key-value pair.
        """
        return dict(self.query_tx(query, **query_params))

    def query_dict_value_json(self, query: str, **query_params) -> Dict:
        """Run a read-only query that generates a dictionary with JSON values.

        Each returned row is used as a key-value pair where the value is a
        JSON string that is deserialized with :func:`json.loads`.
        """
        return {
            key: json.loads(j)
            for key, j in self.query_tx(query, **query_params)
        }

    def query_tx(
        self, query: str, squeeze: bool = False, **query_params
    ) -> Union[List[List[Any]], None]:
        """Run a read-only query and return the results.

        Parameters
        ----------
        query :
            The query string to be executed.
        squeeze :
            If true, unpacks the 0-indexed element in each value returned.
            Useful when only returning one value per row of the results.
        query_params :
            kwargs to pass to query

        Returns
        -------
        values :
            A list of results where each result is a list of one or more
            objects (typically neo4j nodes or relations).
        """
        # For documentation on the session and transaction classes see
        # https://neo4j.com/docs/api/python-driver/4.4/api.html#session-construction
        # and
        # https://neo4j.com/docs/api/python-driver/4.4/api.html#explicit-transactions
        # Documentation on transaction functions are here:
        # https://neo4j.com/docs/python-manual/4.4/session-api/#python-driver-simple-transaction-fn
        with self.driver.session() as session:
            # do_cypher_tx is ultimately called as
            # `transaction_function(tx, *args, **kwargs)` in the neo4j code,
            # where *args and **kwargs are passed through unchanged, meaning
            # do_cypher_tx can expect query and **query_params
            values = session.read_transaction(
                do_cypher_tx, query, **query_params
            )

        if squeeze:
            values = [value[0] for value in values]
        return values

    def query_nodes(self, query: str, **query_params) -> List[Node]:
        """Run a read-only query for nodes.

        Parameters
        ----------
        query :
            The query string to be executed.
        query_params :
            Query parameters to pass to cypher

        Returns
        -------
        values :
            A list of :class:`Node` instances corresponding
            to the results of the query
        """
        return [
            self.neo4j_to_node(res)
            for res in self.query_tx(query, squeeze=True, **query_params)
        ]

    def query_relations(self, query: str, **query_params) -> List[Relation]:
        """Run a read-only query for relations.

        Parameters
        ----------
        query :
            The query string to be executed. Must have a ``RETURN`` with a
            single element ``p`` where in the ``MATCH`` part of the query
            it has something like ``p=(h)-[r]->(t)``.
        query_params :
            Query parameters to pass to query transaction function that will
            fill out the placeholders in the cypher query

        Returns
        -------
        values :
            A list of :class:`Relation` instances corresponding
            to the results of the query
        """
        return [
            self.neo4j_to_relation(res)
            for res in self.query_tx(query, squeeze=True, **query_params)
        ]

    def get_session(self, renew: Optional[bool] = False) -> neo4j.Session:
        """Return an existing session or create one if needed.

        Parameters
        ----------
        renew :
            If True, a new session is created and any session previously
            held by the client is closed first. Default: False

        Returns
        -------
        session
            A neo4j session.
        """
        if self.session is None or renew:
            # Close the session being replaced so it is not leaked; after
            # this point the client no longer holds a reference to it.
            self.close_session()
            self.session = self.driver.session()
        return self.session

    def close_session(self):
        """Close the session if it exists and forget it.

        The reference is cleared so that a later :meth:`get_session` call
        opens a fresh session instead of returning the closed one.
        """
        session, self.session = self.session, None
        if session is not None:
            session.close()

    def has_relation(
        self,
        source: Tuple[str, str],
        target: Tuple[str, str],
        relation: str,
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> bool:
        """Return True if there is a relation between the source and the target.

        Parameters
        ----------
        source :
            Source namespace and identifier.
        target :
            Target namespace and identifier.
        relation :
            Relation type.
        source_type :
            A constraint on the source type
        target_type :
            A constraint on the target type

        Returns
        -------
        related :
            True if there is a relation of the given type, otherwise False.
        """
        res = self.get_relations(
            source,
            target,
            relation,
            limit=1,
            source_type=source_type,
            target_type=target_type,
        )
        return bool(res)

    def get_relations(
        self,
        source: Optional[Tuple[str, str]] = None,
        target: Optional[Tuple[str, str]] = None,
        relation: Optional[str] = None,
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
        limit: Optional[int] = None,
        bidirectional: Optional[bool] = False,
    ) -> List[Relation]:
        """Return relations based on source, target and type constraints.

        This is a generic function for getting relations, all of its parameters
        are optional, though at least a source or a target needs to be provided.

        Parameters
        ----------
        source :
            Source namespace and ID.
        target :
            Target namespace and ID.
        relation :
            Relation type.
        source_type :
            A constraint on the source type
        target_type :
            A constraint on the target type
        limit :
            A limit on the number of relations returned. A falsy value
            (None or 0) means no limit.
        bidirectional :
            If True, return both directions of relationships between the
            source and target.

        Returns
        -------
        rels :
            A list of relations matching the constraints.

        Raises
        ------
        ValueError
            If neither a source nor a target is given.
        """
        if not source and not target:
            raise ValueError("source or target should be specified")
        source = norm_id(*source) if source else None
        target = norm_id(*target) if target else None
        match = triple_query(
            source_id=source,
            source_type=source_type,
            relation_type=relation,
            target_id=target,
            target_type=target_type,
            relation_direction=("both" if bidirectional else "right"),
        )
        # The limit is spliced into the query text, so coerce it to an int
        # to make sure nothing but a number can end up there.
        limit_clause = "LIMIT %d" % int(limit) if limit else ""
        query = """
            MATCH p=%s
            RETURN DISTINCT p
            %s
        """ % (
            match,
            limit_clause,
        )
        return self.query_relations(query)

    def get_source_relations(
        self,
        target: Tuple[str, str],
        relation: Optional[str] = None,
        target_type: Optional[str] = None,
        source_type: Optional[str] = None,
    ) -> List[Relation]:
        """Get relations that connect sources to the given target.

        Parameters
        ----------
        target :
            Target namespace and identifier.
        relation :
            Relation type.
        target_type :
            A constraint on the target node type.
        source_type :
            A constraint on the source node type.

        Returns
        -------
        rels :
            A list of relations matching the constraints.
        """
        return self.get_relations(
            source=None,
            target=target,
            relation=relation,
            target_type=target_type,
            source_type=source_type,
        )

    def get_target_relations(
        self,
        source: Tuple[str, str],
        relation: Optional[str] = None,
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> List[Relation]:
        """Get relations that connect targets from the given source.

        Parameters
        ----------
        source :
            Source namespace and identifier.
        relation :
            Relation type.
        source_type :
            A constraint on the source node type.
        target_type :
            A constraint on the target node type.

        Returns
        -------
        rels :
            A list of relations matching the constraints.
        """
        return self.get_relations(
            source=source,
            target=None,
            relation=relation,
            source_type=source_type,
            target_type=target_type,
        )

    def get_target_relations_for_sources(
        self,
        sources: Iterable[Tuple[str, str]],
        relation: Optional[str] = None,
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> Mapping[Tuple[str, str], List[Relation]]:
        """Get the outgoing relations of several sources with one query.

        Parameters
        ----------
        sources :
            Source namespaces and identifiers.
        relation :
            Relation type.
        source_type :
            A constraint on the source node type.
        target_type :
            A constraint on the target node type.

        Returns
        -------
        rels :
            A mapping from the (namespace, identifier) of the source of each
            returned relation to the list of relations with that source.
            Sources without any matching relation have no key.
        """
        from collections import defaultdict

        match = triple_query(
            source_name="s",
            source_type=source_type,
            relation_type=relation,
            target_type=target_type,
        )
        sources = [norm_id(*source) for source in sources]
        query = """
            MATCH p=%s
            WHERE s.id IN $sources
            RETURN p
        """ % match
        rels = defaultdict(list)
        for res in self.query_tx(query, squeeze=True, sources=sources):
            rel = self.neo4j_to_relation(res)
            rels[(rel.source_ns, rel.source_id)].append(rel)
        return rels

    def get_source_relations_for_targets(
        self,
        targets: Iterable[Tuple[str, str]],
        relation: Optional[str] = None,
        target_type: Optional[str] = None,
        source_type: Optional[str] = None,
    ) -> Mapping[Tuple[str, str], List[Relation]]:
        """Get the incoming relations of several targets with one query.

        Parameters
        ----------
        targets :
            Target namespaces and identifiers.
        relation :
            Relation type.
        target_type :
            A constraint on the target node type.
        source_type :
            A constraint on the source node type.

        Returns
        -------
        rels :
            A mapping from the (namespace, identifier) of the target of each
            returned relation to the list of relations with that target.
            Targets without any matching relation have no key.
        """
        from collections import defaultdict

        match = triple_query(
            source_type=source_type,
            relation_type=relation,
            target_name="t",
            target_type=target_type,
        )
        targets = [norm_id(*target) for target in targets]
        query = """
            MATCH p=%s
            WHERE t.id IN $targets
            RETURN p
        """ % match
        rels = defaultdict(list)
        for res in self.query_tx(query, squeeze=True, targets=targets):
            rel = self.neo4j_to_relation(res)
            rels[(rel.target_ns, rel.target_id)].append(rel)
        return rels

    def get_all_relations(
        self,
        node: Tuple[str, str],
        relation: Optional[str] = None,
        node_type: Optional[str] = None,
        other_type: Optional[str] = None,
    ) -> List[Relation]:
        """Get relations that connect sources and targets with the given node.

        Parameters
        ----------
        node :
            Node namespace and identifier.
        relation :
            Relation type.
        node_type :
            Type constraint on the queried node itself
        other_type :
            Type constraint on the other node in the relation

        Returns
        -------
        rels :
            A list of relations matching the constraints.
        """
        return self.get_relations(
            source=node,
            relation=relation,
            source_type=node_type,
            target_type=other_type,
            bidirectional=True,
        )

    @staticmethod
    def get_property_from_relations(relations: List[Relation], prop: str) -> Set[str]:
        """Return the set of property values on given relations.

        Parameters
        ----------
        relations :
            The relations, each of which may or may not contain a value for
            the given property.
        prop :
            The key/name of the property to look for on each relation.

        Returns
        -------
        props
            A set of the values of the given property on the given list
            of relations.
        """
        return {rel.data[prop] for rel in relations if prop in rel.data}

    def get_sources(
        self,
        target: Tuple[str, str],
        relation: Optional[str] = None,
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> List[Node]:
        """Return the nodes related to the target via a given relation type.

        Parameters
        ----------
        target :
            The target node's ID.
        relation :
            The relation label to constrain to when finding sources.
        source_type :
            A constraint on the source type
        target_type :
            A constraint on the target type

        Returns
        -------
        sources
            A list of source nodes.
        """
        return self.get_common_sources(
            [target],
            relation,
            source_type=source_type,
            target_type=target_type,
        )

    def get_common_sources(
        self,
        targets: List[Tuple[str, str]],
        relation: Optional[str],
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> List[Node]:
        """Return the common source nodes related to all the given targets
        via a given relation type.

        Parameters
        ----------
        targets :
            The target nodes' IDs.
        relation :
            The relation label to constrain to when finding sources.
        source_type :
            A constraint on the source type
        target_type :
            A constraint on the target type

        Returns
        -------
        sources
            A list of source nodes.
        """
        parts = []
        query_params = {}
        # One pattern per target, all sharing the source variable ``s``, each
        # with its own numbered query parameter holding the target's ID.
        for idx, target in enumerate(targets):
            prop_param = f"target_id{idx}"
            parts.append(
                triple_parameter_query(
                    source_name="s",
                    source_type=source_type,
                    relation_type=relation,
                    target_prop_name="id",
                    target_prop_param=prop_param,
                    target_type=target_type,
                )
            )
            query_params[prop_param] = norm_id(*target)
        query = """
            MATCH %s
            RETURN DISTINCT s
        """ % ",".join(parts)
        return self.query_nodes(query, **query_params)

    def get_targets(
        self,
        source: Tuple[str, str],
        relation: Optional[str] = None,
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> List[Node]:
        """Return the nodes related to the source via a given relation type.

        Parameters
        ----------
        source :
            Source namespace and identifier.
        relation :
            The relation label to constrain to when finding targets.
        source_type :
            A constraint on the source type
        target_type :
            A constraint on the target type

        Returns
        -------
        targets
            A list of target nodes.
        """
        return self.get_common_targets(
            [source],
            relation,
            source_type=source_type,
            target_type=target_type,
        )

    def get_common_targets(
        self,
        sources: List[Tuple[str, str]],
        relation: Optional[str],
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> List[Node]:
        """Return the common target nodes related to all the given sources
        via a given relation type.

        Parameters
        ----------
        sources :
            Source namespaces and identifiers.
        relation :
            The relation label to constrain to when finding targets.
        source_type :
            A constraint on the source type
        target_type :
            A constraint on the target type

        Returns
        -------
        targets
            A list of target nodes.
        """
        parts = []
        query_params = {}
        # One pattern per source, all sharing the target variable ``t``, each
        # with its own numbered query parameter holding the source's ID.
        for idx, source in enumerate(sources):
            prop_param = f"source_id{idx}"
            parts.append(
                triple_parameter_query(
                    source_prop_name="id",
                    source_prop_param=prop_param,
                    source_type=source_type,
                    relation_type=relation,
                    target_name="t",
                    target_type=target_type,
                )
            )
            query_params[prop_param] = norm_id(*source)
        query = """
            MATCH %s
            RETURN DISTINCT t
        """ % ",".join(parts)
        return self.query_nodes(query, **query_params)

    def get_target_agents(
        self,
        source: Tuple[str, str],
        relation: str,
        source_type: Optional[str] = None,
    ) -> List[Agent]:
        """Return the nodes related to the source via a given relation type as INDRA Agents.

        Parameters
        ----------
        source :
            Source namespace and identifier.
        relation :
            The relation label to constrain to when finding targets.
        source_type :
            A constraint on the source type

        Returns
        -------
        targets
            A list of target nodes as INDRA Agents.
        """
        targets = self.get_targets(source, relation, source_type=source_type)
        return [self.node_to_agent(target) for target in targets]

    def get_source_agents(self, target: Tuple[str, str], relation: str) -> List[Agent]:
        """Return the nodes related to the target via a given relation type as INDRA Agents.

        Both the source and the target are constrained to the ``BioEntity``
        node type.

        Parameters
        ----------
        target :
            Target namespace and identifier.
        relation :
            The relation label to constrain to when finding sources.

        Returns
        -------
        sources
            A list of source nodes as INDRA Agents.
        """
        sources = self.get_sources(
            target,
            relation,
            source_type="BioEntity",
            target_type="BioEntity",
        )
        return [self.node_to_agent(source) for source in sources]

    def get_predecessors(
        self,
        target: Tuple[str, str],
        relations: Iterable[str],
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> List[Node]:
        """Return the nodes that precede the given node via the given relation types.

        Parameters
        ----------
        target :
            The target node's ID.
        relations :
            The relation labels to constrain to when finding predecessors.
        source_type :
            A constraint on the source type
        target_type :
            A constraint on the target type

        Returns
        -------
        predecessors
            A list of predecessor nodes.
        """
        match = triple_parameter_query(
            source_name="s",
            source_type=source_type,
            # Variable-length pattern: one or more hops over any of the
            # given relation types.
            relation_type="%s*1.." % "|".join(relations),
            target_prop_param="target",
            target_prop_name="id",
            target_type=target_type,
        )
        query = """
            MATCH %s
            RETURN DISTINCT s
        """ % match
        return self.query_nodes(query, target=norm_id(*target))

    def get_successors(
        self,
        source: Tuple[str, str],
        relations: Iterable[str],
        source_type: Optional[str] = None,
        target_type: Optional[str] = None,
    ) -> List[Node]:
        """Return the nodes that succeed the given node via the given relation types.

        Parameters
        ----------
        source :
            The source node's ID.
        relations :
            The relation labels to constrain to when finding successors.
        source_type :
            A constraint on the source type
        target_type :
            A constraint on the target type

        Returns
        -------
        successors
            A list of successor nodes.
        """
        match = triple_parameter_query(
            source_prop_param="source",
            source_prop_name="id",
            source_type=source_type,
            # Variable-length pattern: one or more hops over any of the
            # given relation types.
            relation_type="%s*1.." % "|".join(relations),
            target_name="t",
            target_type=target_type,
        )
        query = """
            MATCH %s
            RETURN DISTINCT t
        """ % match
        return self.query_nodes(query, source=norm_id(*source))

    @staticmethod
    def neo4j_to_node(neo4j_node: neo4j.graph.Node) -> Node:
        """Return a Node from a neo4j internal node.

        Parameters
        ----------
        neo4j_node :
            A neo4j internal node using its internal data structure and
            identifier scheme.

        Returns
        -------
        node :
            A Node object with the INDRA standard identifier scheme.
        """
        props = dict(neo4j_node)
        # The "id" property carries the prefixed identifier; everything else
        # is kept as the node's data.
        node_id = props.pop("id")
        db_ns, db_id = process_identifier(node_id)
        return Node(db_ns, db_id, neo4j_node.labels, props)

    @classmethod
    def neo4j_to_relation(cls, neo4j_path: neo4j.graph.Path) -> Relation:
        """Return a Relation from a neo4j internal single-relation path.

        Parameters
        ----------
        neo4j_path :
            A neo4j internal single-edge path using its internal data
            structure and identifier scheme.

        Returns
        -------
        relation :
            A Relation object with the INDRA standard identifier scheme.
        """
        return cls.neo4j_to_relations(neo4j_path)[0]

    @staticmethod
    def neo4j_to_relations(neo4j_path: neo4j.graph.Path) -> List[Relation]:
        """Return a list of Relations from a neo4j internal multi-relation path.

        Parameters
        ----------
        neo4j_path :
            A neo4j internal path with one or more edges using its internal
            data structure and identifier scheme.

        Returns
        -------
        :
            A list of Relation objects with the INDRA standard identifier
            scheme, one per relationship in the path.
        """
        relations = []
        for neo4j_relation in neo4j_path.relationships:
            rel_type = neo4j_relation.type
            props = dict(neo4j_relation)
            source_ns, source_id = process_identifier(neo4j_relation.start_node["id"])
            target_ns, target_id = process_identifier(neo4j_relation.end_node["id"])
            rel = Relation(source_ns, source_id, target_ns, target_id, rel_type, props)
            relations.append(rel)
        return relations

    @staticmethod
    def node_to_agent(node: Node) -> Agent:
        """Return an INDRA Agent from a Node.

        Parameters
        ----------
        node :
            A Node object.

        Returns
        -------
        agent :
            An INDRA Agent with standardized name and expanded/standardized
            db_refs.
        """
        # Fall back to "namespace:identifier" when the node has no name.
        name = node.data.get("name")
        if not name:
            name = f"{node.db_ns}:{node.db_id}"
        return get_standard_agent(name, {node.db_ns: node.db_id})

    def delete_all(self):
        """Delete everything in the neo4j database."""
        query = """MATCH(n) DETACH DELETE n"""
        return self.create_tx(query)

    def create_nodes(self, nodes: List[Node]):
        """Create a set of new graph nodes.

        Does nothing if the list of nodes is empty.
        """
        if not nodes:
            # A bare CREATE without any pattern is not a valid query.
            return
        nodes_str = ",\n".join([str(n) for n in nodes])
        query = """CREATE %s""" % nodes_str
        return self.create_tx(query)

    def add_nodes(self, nodes: List[Node]):
        """Merge a set of graph nodes (create or update).

        The property names that are set are taken from the data of the
        first node only. Does nothing if the list of nodes is empty.
        """
        if not nodes:
            return
        prop_str = ",\n".join(["n.%s = node.%s" % (k, k) for k in nodes[0].data])
        # An empty SET clause is a syntax error, so only emit it when the
        # first node actually has properties.
        set_clause = "SET %s" % prop_str if prop_str else ""
        # NOTE(review): confirm the YIELD column name against the installed
        # APOC version's signature of apoc.create.addLabels.
        query = """
            UNWIND $nodes AS node
            MERGE (n {id: node.id})
            %s
            WITH n, node
            CALL apoc.create.addLabels(n, node.labels)
            YIELD n
        """ % set_clause
        return self.create_tx(
            query,
            query_params={
                # list() so that set-like label collections can be sent as
                # a query parameter.
                "nodes": [
                    dict(**n.to_json()["data"], labels=list(n.labels))
                    for n in nodes
                ]
            },
        )

    def add_relations(self, relations: List[Relation]):
        """Merge a set of graph relations (create or update).

        The relation type and the property names that are set are taken
        from the first relation only. Does nothing if the list of relations
        is empty.
        """
        if not relations:
            return None
        labels_str = relations[0].rel_type
        prop_str = ",\n".join(
            ["rel.%s = relation.%s" % (k, k) for k in relations[0].data]
        )
        # An empty SET clause is a syntax error, so only emit it when the
        # first relation actually has properties.
        set_clause = "SET %s" % prop_str if prop_str else ""
        query = """
            UNWIND $relations AS relation
            MATCH (e1 {id: relation.source_id}), (e2 {id: relation.target_id})
            MERGE (e1)-[rel:%s]->(e2)
            %s
        """ % (
            labels_str,
            set_clause,
        )
        rel_params = [
            dict(source_id=rel.source_id, target_id=rel.target_id, **rel.data)
            for rel in relations
        ]
        return self.create_tx(query, query_params={"relations": rel_params})

    def add_node(self, node: Node):
        """Merge a single node into the graph.

        The node is matched on its labels and its normalized identifier and
        its data is then written onto it as properties. The identifier and
        the properties are sent as query parameters, so values keep their
        type and may contain quotes.
        """
        # Labels cannot be parameterized in Cypher, so they are spliced in
        # as ``:Label1:Label2``; a single label given as a plain string is
        # accepted as well.
        labels = [node.labels] if isinstance(node.labels, str) else node.labels
        labels_str = "".join(":%s" % label for label in labels)
        query = f"MERGE (n{labels_str} {{id: $id}}) SET n += $props"
        return self.create_tx(
            query,
            query_params={
                "id": norm_id(node.db_ns, node.db_id),
                "props": dict(node.data),
            },
        )

    def create_single_property_node_index(
        self, index_name: str, label: str, property_name: str, exist_ok: bool = False
    ):
        """Create a single property node index.

        Reference:
        https://neo4j.com/docs/cypher-manual/4.4/indexes-for-search-performance/#administration-indexes-create-a-single-property-b-tree-index-only-if-it-does-not-already-exist

        Parameters
        ----------
        index_name :
            The name of the index.
        label :
            The label of the node.
        property_name :
            The property name to index.
        exist_ok :
            If True, ignore the indexes that already exist. If False, raise
            error if index already exists. Default: False.
        """
        logger.info(
            f"Creating index '{index_name}' for label '{label}' on property "
            f"'{property_name}'. Index is created in background and may not "
            f"be available immediately."
        )
        if_not = " IF NOT EXISTS" if exist_ok else ""
        create_query = (
            f"CREATE INDEX {index_name}{if_not} FOR (n:{label}) ON (n.{property_name})"
        )
        self.create_tx(create_query)

    def create_single_property_relationship_index(
        self, index_name: str, rel_type: str, property_name: str
    ):
        """Create a single property relationship index.

        NOTE: Relationship indexes can only be created once, and there is no
        IF NOT EXISTS option to silently ignore if the index already exists.

        Reference:
        https://neo4j.com/docs/cypher-manual/4.4/indexes-for-search-performance/#administration-indexes-create-a-single-property-b-tree-index-for-relationships

        Parameters
        ----------
        index_name :
            The name of the index.
        rel_type :
            The relationship type to index a property on
        property_name :
            The property name to index.
        """
        logger.info(
            f"Creating index '{index_name}' for relationship type '{rel_type}' on "
            f"property '{property_name}'. Index is created in background and may not "
            f"be available immediately."
        )
        create_query = (
            f"CREATE INDEX {index_name} FOR ()-[r:{rel_type}]-() ON (r.{property_name})"
        )
        self.create_tx(create_query)
def process_identifier(identifier: str) -> Tuple[str, str]:
    """Convert a graph-internal identifier into an INDRA namespace and ID.

    Parameters
    ----------
    identifier :
        The prefixed identifier string of a neo4j graph node, with the
        prefix and the ID separated by the first colon.

    Returns
    -------
    :
        A tuple of the INDRA-standard namespace and identifier
        corresponding to the input identifier.
    """
    prefix, local_id = identifier.split(":", maxsplit=1)
    std_ns, std_id = identifiers.get_ns_id_from_identifiers(prefix, local_id)
    if std_ns:
        return std_ns, identifiers.ensure_prefix_if_needed(std_ns, std_id)
    # This is a corner case where the prefix is not in the registry
    # and in those cases we just use the upper case version of the prefix
    # in the graph to revert it to the INDRA-compatible key.
    return prefix.upper(), local_id
def autoclient(*, cache: bool = False, maxsize: Optional[int] = 128):
    """Wrap a function that takes a client for easier usage.

    When the wrapped function is called without a ``client`` (or with
    ``client=None``), a new :class:`Neo4jClient` is created for the call and
    its session is closed afterwards, also if the function raises.

    Arguments
    ---------
    cache :
        Should the result be cached using :func:`functools.lru_cache`? Is
        False by default.
    maxsize :
        If cache is True, this is the value passed to the ``maxsize`` argument
        of :func:`functools.lru_cache`. Set to None for unlimited caching, but
        beware that this can potentially use a lot of memory and isn't a good
        idea for queries that can take a lot of different kinds of input over
        time.

    Returns
    -------
    :
        A decorator object that will wrap the function

    Raises
    ------
    ValueError
        Raised by the returned decorator if the decorated function has no
        ``client`` parameter or if that parameter is not keyword-only.

    Examples
    --------
    Not appropriate for caching (i.e., many possible inputs, especially in
    a web app scenario):

    .. code-block:: python

        @autoclient()
        def get_tissues_for_gene(gene: Tuple[str, str], *, client: Neo4jClient):
            return client.get_targets(
                gene,
                relation="expressed_in",
                source_type="BioEntity",
                target_type="BioEntity",
            )

    Appropriate for caching (e.g., doesn't take inputs at all):

    .. code-block:: python

        @autoclient(cache=True, maxsize=1)
        def get_node_count(*, client: Neo4jClient) -> Counter:
            return Counter(
                {
                    label[0]: client.query_tx(f"MATCH (n:{label[0]}) RETURN count(*)")[0][0]
                    for label in client.query_tx("call db.labels();")
                }
            )
    """

    def _decorator(func):
        signature = inspect.signature(func)
        client_param = signature.parameters.get("client")
        if client_param is None:
            raise ValueError(
                "the autoclient decorator can't be applied to a function that"
                " doesn't take a neo4j client."
            )
        if client_param.kind != inspect.Parameter.KEYWORD_ONLY:
            raise ValueError(
                "the autoclient decorator can't be applied to a function whose"
                " `client` argument isn't keyword-only"
            )

        @wraps(func)
        def _wrapped(*args, **kwargs):
            # A caller-supplied client is used as is and left open.
            if kwargs.get("client") is not None:
                return func(*args, **kwargs)
            # Otherwise create a client for this call only and make sure its
            # session is closed even when the wrapped function raises.
            client = kwargs["client"] = Neo4jClient()
            try:
                return func(*args, **kwargs)
            finally:
                client.close_session()

        if cache:
            _wrapped = lru_cache(maxsize=maxsize)(_wrapped)

        return _wrapped

    return _decorator
# Follows example here:
# https://neo4j.com/docs/python-manual/4.4/session-api/#python-driver-simple-transaction-fn
# and from the docstring of neo4j.Session.read_transaction
@unit_of_work()
def do_cypher_tx(tx: Transaction, query: str, **query_params) -> List[List]:
    """Run a query in the given transaction and collect its rows.

    Parameters
    ----------
    tx :
        The transaction to run the query in.
    query :
        The query string to be executed.
    query_params :
        Keyword arguments that are handed to the query as its parameters.

    Returns
    -------
    :
        A list with the values of each record of the result.
    """
    # 'parameters' and '**kwparameters' of tx.run are ultimately merged at
    # query run-time
    rows = []
    for rec in tx.run(query, parameters=query_params):
        rows.append(rec.values())
    return rows