Source code for indra_cogex.sources.processor

# -*- coding: utf-8 -*-

"""Base classes for processors."""

import csv
import gzip
import json
import logging
import pickle
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import ClassVar, Iterable, List, Tuple, Optional, Mapping, Any

import click
import pystow
from more_click import verbose_option
from tqdm import tqdm

from indra.statements.validate import assert_valid_db_refs, assert_valid_evidence
from indra.statements import Evidence

from indra_cogex.representation import Node, Relation, norm_id

__all__ = [
    "Processor",
]

logger = logging.getLogger(__name__)

# To allow importing files from arbitrary locations, see
#  https://stackoverflow.com/questions/36922843/neo4j-3-x-load-csv-absolute-file-path
# Find the neo4j config and comment out this line: dbms.directories.import=import
# (e.g., /usr/local/Cellar/neo4j/4.1.3/libexec/conf/neo4j.conf on a Homebrew install)
# Data is stored in /usr/local/var/neo4j/data/databases


# See https://neo4j.com/docs/operations-manual/4.4/tools/neo4j-admin/neo4j-admin-import/
# especially sections 4, 5 and 6
NEO4J_DATA_TYPES = (
    "int",
    "long",
    "float",
    "double",
    "boolean",
    "byte",
    "short",
    "char",
    "string",
    "point",
    "date",
    "localtime",
    "time",
    "localdatetime",
    "datetime",
    "duration",
    # Used in node files
    "ID",
    "LABEL",
    # Used in relationship files
    "START_ID",
    "END_ID",
    "TYPE",
)
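
# Illustrative header strings using these types (hypothetical property names,
# not taken from any particular processor):
#   "stmt_hash:long", "belief:float", "source_counts:string[]", "name:string"
# validate_headers() below accepts exactly these type suffixes, with an optional
# trailing "[]" for array properties.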


class Processor(ABC):
    """A processor creates nodes and iterables to upload to Neo4j."""

    name: ClassVar[str]
    module: ClassVar[pystow.Module]
    directory: ClassVar[Path]
    nodes_path: ClassVar[Path]
    nodes_indra_path: ClassVar[Path]
    edges_path: ClassVar[Path]
    importable = True
    node_types: ClassVar[Iterable[str]]

    def __init_subclass__(cls, **kwargs):
        """Initialize the class attributes."""
        cls.module = pystow.module("indra", "cogex", cls.name)
        cls.directory = cls.module.base
        # These are nodes directly in the neo4j encoding
        cls.nodes_path = cls.module.join(name="nodes.tsv.gz")
        # These are nodes in the original INDRA-oriented representation
        # needed for assembly
        cls.nodes_indra_path = cls.module.join(name="nodes.pkl")
        cls.edges_path = cls.module.join(name="edges.tsv.gz")
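
    # For a subclass with ``name = "example"`` (hypothetical), pystow would by
    # default place these files under ~/.data/indra/cogex/example/, e.g.
    # nodes.tsv.gz, nodes.pkl, and edges.tsv.gz; the base directory can be
    # redirected with the PYSTOW_HOME environment variable.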

    @abstractmethod
    def get_nodes(self) -> Iterable[Node]:
        """Iterate over the nodes to upload."""

    @abstractmethod
    def get_relations(self) -> Iterable[Relation]:
        """Iterate over the relations to upload."""

    @classmethod
    def get_cli(cls) -> click.Command:
        """Get the CLI for this processor."""

        @click.command()
        @verbose_option
        def _main():
            click.secho(f"Building {cls.name}", fg="green", bold=True)
            processor = cls()
            processor.dump()

        return _main

    @classmethod
    def cli(cls) -> None:
        """Run the CLI for this processor."""
        cls.get_cli()()
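
    # Sketch of how a concrete processor's CLI is typically wired up (assuming a
    # hypothetical subclass ``MyProcessor`` defined elsewhere):
    #
    #     if __name__ == "__main__":
    #         MyProcessor.cli()
    #
    # which instantiates the processor and dumps its nodes and edges to the
    # paths set up in __init_subclass__ above.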

    def dump(self) -> Tuple[Mapping[str, Path], Mapping[str, List[Node]], Path]:
        """Dump the contents of this processor to gzipped TSV files ready for
        use with ``neo4j-admin import``."""
        node_paths, nodes = self._dump_nodes()
        edges_path = self._dump_edges()
        return node_paths, nodes, edges_path
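
    # A minimal sketch of consuming these dumps with neo4j-admin (database name
    # and file paths are placeholders; see the neo4j-admin import docs linked above):
    #
    #     neo4j-admin import --database=indra --delimiter='\t' \
    #         --nodes=nodes.tsv.gz --relationships=edges.tsv.gz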

    @classmethod
    def _get_node_paths(cls, node_type: str) -> Tuple[Path, Path, Path]:
        # If the processor returns multiple types of nodes, add node_type to the file name
        if len(cls.node_types) > 1:
            return (
                cls.module.join(name=f"nodes_{node_type}.tsv.gz"),
                cls.module.join(name=f"nodes_{node_type}.pkl"),
                cls.module.join(name=f"nodes_{node_type}_sample.tsv"),
            )
        return (
            cls.nodes_path,
            cls.nodes_indra_path,
            cls.module.join(name="nodes_sample.tsv"),
        )

    def _dump_nodes(self) -> Tuple[Mapping[str, Path], Mapping[str, List[Node]]]:
        paths_by_type = {}
        nodes_by_type = defaultdict(list)
        # Get all the nodes
        nodes = tqdm(
            self.get_nodes(),
            desc="Node generation",
            unit_scale=True,
            unit="node",
        )
        # Map the nodes to their types
        for node in nodes:
            nodes_by_type[node.labels[0]].append(node)
        # Get the paths for each type of node and dump the nodes
        for node_type in nodes_by_type:
            nodes_path, nodes_indra_path, sample_path = self._get_node_paths(node_type)
            nodes = sorted(nodes_by_type[node_type], key=lambda x: (x.db_ns, x.db_id))
            with open(nodes_indra_path, "wb") as fh:
                pickle.dump(nodes, fh)
            self._dump_nodes_to_path(nodes, nodes_path, sample_path)
            paths_by_type[node_type] = nodes_path
        return paths_by_type, dict(nodes_by_type)

    def _dump_nodes_to_path(self, nodes, nodes_path, sample_path=None, write_mode="wt"):
        logger.info(f"Dumping into {nodes_path}...")
        nodes = list(validate_nodes(nodes))
        metadata = sorted(set(key for node in nodes for key in node.data))
        try:
            validate_headers(metadata)
        except TypeError as e:
            logger.error(f"Bad node data type in header for {self.name}")
            raise e
        node_rows = (
            (
                norm_id(node.db_ns, node.db_id),
                ";".join(node.labels),
                *[node.data.get(key, "") for key in metadata],
            )
            for node in tqdm(nodes, desc="Node serialization", unit_scale=True)
        )

        header = "id:ID", ":LABEL", *metadata
        with gzip.open(nodes_path, mode=write_mode) as node_file:
            node_writer = csv.writer(node_file, delimiter="\t")  # type: ignore
            # Only add header when writing to a new file
            if write_mode == "wt":
                node_writer.writerow(header)
            if sample_path:
                with sample_path.open("w") as node_sample_file:
                    node_sample_writer = csv.writer(node_sample_file, delimiter="\t")
                    node_sample_writer.writerow(header)
                    for _, node_row in zip(range(10), node_rows):
                        node_sample_writer.writerow(node_row)
                        node_writer.writerow(node_row)
            # Write remaining nodes
            node_writer.writerows(node_rows)
        return nodes_path

    def _dump_edges(self) -> Path:
        sample_path = self.module.join(name="edges_sample.tsv")
        logger.info(f"Dumping into {self.edges_path}...")
        rels = self.get_relations()
        return self._dump_edges_to_path(rels, self.edges_path, sample_path)

    def _dump_edges_to_path(self, rels, edges_path, sample_path=None, write_mode="wt"):
        logger.info(f"Dumping into {edges_path}...")
        rels = validate_relations(rels)
        rels = sorted(
            rels, key=lambda r: (r.source_ns, r.source_id, r.target_ns, r.target_id)
        )
        metadata = sorted(set(key for rel in rels for key in rel.data))
        try:
            validate_headers(metadata)
        except TypeError as e:
            logger.error(f"Bad edge data type in header for {self.name}")
            raise e
        edge_rows = (
            (
                norm_id(rel.source_ns, rel.source_id),
                norm_id(rel.target_ns, rel.target_id),
                rel.rel_type,
                *[rel.data.get(key) for key in metadata],
            )
            for rel in tqdm(rels, desc="Edges", unit_scale=True)
        )

        with gzip.open(edges_path, mode=write_mode) as edge_file:
            edge_writer = csv.writer(edge_file, delimiter="\t")  # type: ignore
            header = ":START_ID", ":END_ID", ":TYPE", *metadata
            # Only add header when writing to a new file
            if write_mode == "wt":
                edge_writer.writerow(header)
            if sample_path:
                with sample_path.open("w") as edge_sample_file:
                    edge_sample_writer = csv.writer(edge_sample_file, delimiter="\t")
                    edge_sample_writer.writerow(header)
                    for _, edge_row in zip(range(10), edge_rows):
                        edge_sample_writer.writerow(edge_row)
                        edge_writer.writerow(edge_row)
            # Write remaining edges
            edge_writer.writerows(edge_rows)
        return edges_path


def assert_valid_node(
    db_ns: str, db_id: str, data: Optional[Mapping[str, Any]] = None
) -> None:
    if db_ns == "indra_evidence":
        if data and data.get("evidence"):
            ev = Evidence._from_json(json.loads(data["evidence"]))
            assert_valid_evidence(ev)
    else:
        assert_valid_db_refs({db_ns: db_id})


def validate_nodes(nodes: Iterable[Node]) -> Iterable[Node]:
    for idx, node in enumerate(nodes):
        try:
            assert_valid_node(node.db_ns, node.db_id, node.data)
            yield node
        except Exception as e:
            logger.info(f"{idx}: {node} - {e}")
            continue


def validate_relations(relations: Iterable[Relation]) -> Iterable[Relation]:
    for idx, rel in enumerate(relations):
        try:
            assert_valid_node(rel.source_ns, rel.source_id)
            assert_valid_node(rel.target_ns, rel.target_id)
            yield rel
        except Exception as e:
            logger.info(f"{idx}: {rel} - {e}")
            continue


def validate_headers(headers: Iterable[str]) -> None:
    """Check for data types in the headers."""
    for header in headers:
        # If : is in the header and there is something after it, check
        # whether it is a valid data type
        if ":" in header and header.split(":")[1]:
            dtype = header.split(":")[1]
            # Strip trailing '[]' for array types
            if dtype.endswith("[]"):
                dtype = dtype[:-2]
            if dtype not in NEO4J_DATA_TYPES:
                raise TypeError(
                    f"Invalid header data type '{dtype}' for header {header}"
                )
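
# Illustrative behavior of validate_headers, using hypothetical headers:
#   validate_headers(["name", "count:int", "synonyms:string[]"])  # passes silently
#   validate_headers(["count:number"])  # raises TypeError: invalid data type 'number'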