# Source code for prolif.plotting.network

"""
Plot a Ligand Interaction Network --- :mod:`prolif.plotting.network`
====================================================================

.. versionadded:: 0.3.2

.. versionchanged:: 2.0.0
    Replaced ``LigNetwork.from_ifp`` with ``LigNetwork.from_fingerprint`` which works
    without requiring a dataframe with atom indices.

.. autoclass:: LigNetwork
   :members:

"""

import json
import re
import warnings
from collections import defaultdict
from copy import deepcopy
from html import escape
from pathlib import Path

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdDepictor

from prolif.exceptions import RunRequiredError
from prolif.plotting.utils import grouped_interaction_colors
from prolif.residue import ResidueId
from prolif.utils import requires

# IPython is an optional dependency: ``HTML`` is only used by
# ``LigNetwork.display`` and ``LigNetwork.show``, which are both guarded by
# ``@requires("IPython.display")``.
try:
    from IPython.display import HTML
except ModuleNotFoundError:
    # without IPython the class can still be used through ``LigNetwork.save``
    pass
else:
    # only relevant when IPython is importable: ignore warnings whose message
    # starts with the text below (the diagram is embedded in an ``<iframe>``
    # wrapped in an ``HTML`` object)
    warnings.filterwarnings(
        "ignore", "Consider using IPython.display.IFrame instead"  # pragma: no cover
    )


class LigNetwork:
    """Creates a ligand interaction diagram

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with a 4-level index (ligand, protein, interaction, atoms)
        and ``weight`` and ``distance`` columns for values
    lig_mol : rdkit.Chem.rdChem.Mol
        Ligand molecule
    use_coordinates : bool
        If ``True``, uses the coordinates of the molecule directly, otherwise
        generates 2D coordinates from scratch. See also
        ``flatten_coordinates``.
    flatten_coordinates : bool
        If this is ``True`` and ``use_coordinates=True``, generates 2D
        coordinates that are constrained to fit the 3D conformation of the
        ligand as best as possible.
    kekulize : bool
        Kekulize the ligand
    molsize : int
        Multiply the coordinates by this number to create a bigger and more
        readable depiction
    rotation : int
        Rotate the structure on the XY plane (angle in degrees)
    carbon : float
        Size of the carbon atom dots on the depiction. Use `0` to hide the
        carbon dots

    Attributes
    ----------
    COLORS : dict
        Dictionary of colors used in the diagram. Subdivided in several
        dictionaries:

        - "interactions": mapping between interactions types and colors
        - "atoms": mapping between atom symbol and colors
        - "residues": mapping between residues types and colors

    RESIDUE_TYPES : dict
        Mapping between residue names (3 letter code) and types. The types
        are then used to define how each residue should be colored.

    Notes
    -----
    You can customize the diagram by tweaking :attr:`LigNetwork.COLORS` and
    :attr:`LigNetwork.RESIDUE_TYPES` by adding or modifying the dictionaries
    inplace.

    .. versionchanged:: 2.0.0
        Replaced ``LigNetwork.from_ifp`` with ``LigNetwork.from_fingerprint``
        which works without requiring a dataframe with atom indices. Replaced
        ``match3D`` parameter with ``use_coordinates`` and
        ``flatten_coordinates`` to give users more control and allow them to
        provide their own 2D coordinates. Added support for displaying
        peptides as the "ligand". Changed the default color for VanDerWaals.
    """

    COLORS = {
        "interactions": {**grouped_interaction_colors},
        "atoms": {
            "C": "black",
            "N": "blue",
            "O": "red",
            "S": "#dece1b",
            "P": "orange",
            "F": "lime",
            "Cl": "lime",
            "Br": "lime",
            "I": "lime",
        },
        "residues": {
            "Aliphatic": "#59e382",
            "Aromatic": "#b559e3",
            "Acidic": "#e35959",
            "Basic": "#5979e3",
            "Polar": "#59bee3",
            "Sulfur": "#e3ce59",
        },
    }
    RESIDUE_TYPES = {
        "ALA": "Aliphatic",
        "GLY": "Aliphatic",
        "ILE": "Aliphatic",
        "LEU": "Aliphatic",
        "PRO": "Aliphatic",
        "VAL": "Aliphatic",
        "PHE": "Aromatic",
        "TRP": "Aromatic",
        "TYR": "Aromatic",
        "ASP": "Acidic",
        "GLU": "Acidic",
        "ARG": "Basic",
        "HIS": "Basic",
        "HID": "Basic",
        "HIE": "Basic",
        "HIP": "Basic",
        "HSD": "Basic",
        "HSE": "Basic",
        "HSP": "Basic",
        "LYS": "Basic",
        "SER": "Polar",
        "THR": "Polar",
        "ASN": "Polar",
        "GLN": "Polar",
        "CYS": "Sulfur",
        "CYM": "Sulfur",
        "CYX": "Sulfur",
        "MET": "Sulfur",
    }
    # names of the index levels / value columns of the interaction dataframe
    _INDEX_LEVELS = ["ligand", "protein", "interaction", "atoms"]
    _VALUE_COLUMNS = ["weight", "distance"]
    # interactions drawn from the centroid of the ligand atoms involved rather
    # than from a single atom
    _LIG_PI_INTERACTIONS = ["EdgeToFace", "FaceToFace", "PiStacking", "PiCation"]
    _DISPLAYED_ATOM = {  # index 0 in indices tuple by default
        "HBDonor": 1,
        "XBDonor": 1,
    }
    _JS_TEMPLATE = """
        var ifp, legend, nodes, edges, legend_buttons;
        function drawGraph(_id, nodes, edges, options) {
            var container = document.getElementById(_id);
            nodes = new vis.DataSet(nodes);
            edges = new vis.DataSet(edges);
            var data = {nodes: nodes, edges: edges};
            var network = new vis.Network(container, data, options);
            network.on("stabilizationIterationsDone", function () {
                network.setOptions( { physics: false } );
            });
            return network;
        }
        nodes = %(nodes)s;
        edges = %(edges)s;
        ifp = drawGraph('%(div_id)s', nodes, edges, %(options)s);
    """
    _HTML_TEMPLATE = """
        <html>
        <head>
        <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/dist/vis-network.min.js"></script>
        <link href="https://unpkg.com/vis-network@9.0.4/dist/dist/vis-network.min.css" rel="stylesheet" type="text/css" />
        <style type="text/css">
            body {
                padding: 0;
                margin: 0;
                background: #fff;
            }
            .legend-btn.residues.disabled {
                background: #b4b4b4 !important;
                color: #555 !important;
            }
            .legend-btn.interactions.disabled {
                border-color: #b4b4b4 !important;
                color: #555 !important;
            }
        </style>
        </head>
        <body>
        <div id="mynetwork"></div>
        <div id="networklegend"></div>
        <script type="text/javascript">
            %(js)s
        </script>
        </body>
        </html>
    """

    def __init__(
        self,
        df,
        lig_mol,
        use_coordinates=False,
        flatten_coordinates=True,
        kekulize=False,
        molsize=35,
        rotation=0,
        carbon=0.16,
    ):
        self.df = df
        # every ligand atom index that takes part in at least one interaction
        self._interacting_atoms = {
            atom for atoms in df.index.get_level_values("atoms") for atom in atoms
        }
        # work on a copy so that the caller's molecule is left untouched
        mol = deepcopy(lig_mol)
        if kekulize:
            Chem.Kekulize(mol)
        if use_coordinates:
            if flatten_coordinates:
                rdDepictor.GenerateDepictionMatching3DStructure(mol, lig_mol)
        else:
            rdDepictor.Compute2DCoords(mol, clearConfs=True)
        xyz = mol.GetConformer().GetPositions()
        if rotation:
            # rotate the XY coordinates around their centroid, Z is untouched
            theta = np.radians(rotation)
            c, s = np.cos(theta), np.sin(theta)
            R = np.array([[c, s], [-s, c]])
            xy, z = xyz[:, :2], xyz[:, 2:3]
            center = xy.mean(axis=0)
            xy = ((xy - center) @ R.T) + center
            xyz = np.concatenate([xy, z], axis=1)
        if carbon:
            self._carbon = {
                "label": " ",
                "shape": "dot",
                "color": self.COLORS["atoms"]["C"],
                "size": molsize * carbon,
            }
        else:
            # invisible node: carbon dots are hidden
            self._carbon = {"label": " ", "shape": "text"}
        self.xyz = molsize * xyz
        self.mol = mol
        self._multiplier = molsize
        self.options = {}
        self._max_interaction_width = 6
        self._avoidOverlap = 0.8
        self._springConstant = 0.1
        self._bond_color = "black"
        self._default_atom_color = "grey"
        self._default_residue_color = "#dbdbdb"
        self._default_interaction_color = "#dbdbdb"
        # regroup interactions of the same color; names are sorted so that the
        # merged "A/B" legend labels are identical from one run to the next
        # (iterating a set of strings is not order-stable between processes)
        by_color = defaultdict(list)
        for interaction in sorted(df.index.get_level_values("interaction").unique()):
            color = self.COLORS["interactions"].get(
                interaction, self._default_interaction_color
            )
            by_color[color].append(interaction)
        self._interaction_types = {
            interaction: "/".join(interaction_group)
            for interaction_group in by_color.values()
            for interaction in interaction_group
        }

    @classmethod
    def from_fingerprint(
        cls,
        fp,
        ligand_mol,
        kind="aggregate",
        frame=0,
        display_all=False,
        threshold=0.3,
        **kwargs,
    ):
        """Helper method to create a ligand interaction diagram from a
        :class:`~prolif.fingerprint.Fingerprint` object.

        Notes
        -----
        Two kinds of diagrams can be rendered: either for a designated frame
        or by aggregating the results on the whole IFP and optionally
        discarding interactions that occur less frequently than a threshold.
        In the latter case (aggregate), only the group of atoms most
        frequently involved in each interaction is used to draw the edge.

        Parameters
        ----------
        fp : prolif.fingerprint.Fingerprint
            The fingerprint object already executed using one of the ``run``
            or ``run_from_iterable`` methods.
        ligand_mol : rdkit.Chem.rdChem.Mol
            Ligand molecule
        kind : str
            One of ``"aggregate"`` or ``"frame"``
        frame : int
            Frame number (see :attr:`~prolif.fingerprint.Fingerprint.ifp`).
            Only applicable for ``kind="frame"``
        display_all : bool
            Display all occurrences for a given pair of residues and
            interaction, or only the shortest one. Only applicable for
            ``kind="frame"``. Not relevant if ``count=False`` in the
            ``Fingerprint`` object.
        threshold : float
            Frequency threshold, between 0 and 1. Only applicable for
            ``kind="aggregate"``
        kwargs : object
            Other arguments passed to the :class:`LigNetwork` class

        Raises
        ------
        RunRequiredError
            If ``fp`` has no ``ifp`` attribute (analysis not run yet)
        ValueError
            If ``kind`` is neither ``"aggregate"`` nor ``"frame"``

        .. versionchanged:: 2.0.0
            Added the ``display_all`` parameter.
        """
        if not hasattr(fp, "ifp"):
            raise RunRequiredError(
                "Please run the fingerprint analysis before attempting to display"
                " results."
            )
        if kind == "frame":
            df = cls._make_frame_df_from_fp(fp, frame=frame, display_all=display_all)
            return cls(df, ligand_mol, **kwargs)
        if kind == "aggregate":
            df = cls._make_agg_df_from_fp(fp, threshold=threshold)
            return cls(df, ligand_mol, **kwargs)
        raise ValueError(f'{kind!r} must be "aggregate" or "frame"')

    @staticmethod
    def _get_records(ifp, all_metadata):
        """Flatten the IFP of a single frame into a list of records

        Parameters
        ----------
        ifp : dict
            Maps ``(ligand_resid, protein_resid)`` to a dict of interaction
            name -> tuple of metadata dicts. Each metadata dict must have a
            ``["parent_indices"]["ligand"]`` entry and may have a
            ``"distance"`` entry.
        all_metadata : bool
            If ``True``, one record per metadata entry, otherwise only the
            entry with the shortest distance is kept for each interaction.

        Returns
        -------
        records : list[dict]
            Dicts with the ``ligand``, ``protein``, ``interaction``, ``atoms``
            and ``distance`` keys (``distance`` is ``0`` when not available).
        """
        records = []
        for (lig_resid, prot_resid), int_data in ifp.items():
            for int_name, metadata_tuple in int_data.items():
                entry = {
                    "ligand": str(lig_resid),
                    "protein": str(prot_resid),
                    "interaction": int_name,
                }
                if all_metadata:
                    records.extend(
                        {
                            **entry,
                            "atoms": metadata["parent_indices"]["ligand"],
                            "distance": metadata.get("distance", 0),
                        }
                        for metadata in metadata_tuple
                    )
                else:
                    # extract interaction with shortest distance. Entries
                    # without a distance sort last: using NaN as the fallback
                    # key would make the result depend on the tuple order,
                    # since every comparison with NaN is False
                    metadata = min(
                        metadata_tuple,
                        key=lambda m: m.get("distance", float("inf")),
                    )
                    entry["atoms"] = metadata["parent_indices"]["ligand"]
                    entry["distance"] = metadata.get("distance", 0)
                    records.append(entry)
        return records

    @classmethod
    def _make_empty_df(cls):
        """Return an interaction dataframe with the expected layout but no row"""
        index = pd.MultiIndex.from_arrays(
            [[] for _ in cls._INDEX_LEVELS], names=cls._INDEX_LEVELS
        )
        return pd.DataFrame(index=index, columns=cls._VALUE_COLUMNS, dtype=float)

    @classmethod
    def _make_agg_df_from_fp(cls, fp, threshold=0.3):
        """Aggregate the interactions of every frame of ``fp.ifp``

        The ``weight`` column holds the fraction of frames where the
        interaction is seen and ``distance`` the average distance. Residue
        pair/interaction groups seen in less than ``threshold`` of the frames
        are dropped, and only the most frequent group of ligand atoms is kept
        for each remaining interaction.
        """
        data = []
        for ifp in fp.ifp.values():
            data.extend(cls._get_records(ifp, all_metadata=False))
        if not data:
            # no interaction at all: grouping on a column-less dataframe would
            # raise a KeyError
            return cls._make_empty_df()
        residue_pair_interaction = cls._INDEX_LEVELS[:3]
        df = pd.DataFrame(data)
        # add weight for each atoms, and average distance
        df["weight"] = 1
        df = df.groupby(cls._INDEX_LEVELS).agg(
            weight=("weight", "sum"), distance=("distance", "mean")
        )
        df["weight"] = df["weight"] / len(fp.ifp)
        # merge different ligand atoms of the same residue/interaction group
        # before applying the threshold
        df = df.join(
            df.groupby(level=residue_pair_interaction).agg(
                weight_total=("weight", "sum")
            ),
        )
        # threshold and keep most occurring ligand atom
        df = (
            df.loc[df["weight_total"] >= threshold]
            .drop(columns="weight_total")
            .sort_values("weight", ascending=False)
            .groupby(level=residue_pair_interaction)
            .head(1)
            .sort_index()
        )
        return df

    @classmethod
    def _make_frame_df_from_fp(cls, fp, frame=0, display_all=False):
        """Build the interaction dataframe for a single frame of ``fp.ifp``

        Every interaction gets a ``weight`` of 1.
        """
        ifp = fp.ifp[frame]
        data = cls._get_records(ifp, all_metadata=display_all)
        if not data:
            # no interaction in this frame: ``set_index`` on a column-less
            # dataframe would raise a KeyError
            return cls._make_empty_df()
        df = pd.DataFrame(data)
        df["weight"] = 1
        df = df.set_index(cls._INDEX_LEVELS).reindex(columns=cls._VALUE_COLUMNS)
        return df

    def _make_carbon(self):
        """Return a fresh copy of the carbon node template"""
        return deepcopy(self._carbon)

    def _make_lig_node(self, atom):
        """Prepare ligand atoms

        Hydrogens that are not involved in any interaction are not drawn:
        their index is appended to ``self.exclude`` instead.
        """
        idx = atom.GetIdx()
        elem = atom.GetSymbol()
        if elem == "H" and idx not in self._interacting_atoms:
            self.exclude.append(idx)
            return
        charge = atom.GetFormalCharge()
        if charge != 0:
            # e.g. "N+", "O-", "Fe2+", "O2-": the magnitude must be unsigned,
            # otherwise anions with several charges get a double sign ("O-2-")
            magnitude = "" if abs(charge) == 1 else str(abs(charge))
            sign = "+" if charge > 0 else "-"
            label = f"{elem}{magnitude}{sign}"
            shape = "ellipse"
        else:
            label = elem
            shape = "circle"
        if elem == "C":
            node = self._make_carbon()
        else:
            node = {
                "label": label,
                "shape": shape,
                "color": "white",
                "font": {
                    "color": self.COLORS["atoms"].get(elem, self._default_atom_color)
                },
            }
        node.update(
            {
                "id": idx,
                "x": float(self.xyz[idx, 0]),
                "y": float(self.xyz[idx, 1]),
                "fixed": True,
                "group": "ligand",
                "borderWidth": 0,
            }
        )
        self.nodes[idx] = node

    def _make_lig_edge(self, bond):
        """Prepare ligand bonds

        Bonds to a hidden atom (see ``self.exclude``) are skipped.
        """
        idx = [bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]
        if any(i in self.exclude for i in idx):
            return
        btype = bond.GetBondTypeAsDouble()
        if btype == 1:
            self.edges.append(
                {
                    "from": idx[0],
                    "to": idx[1],
                    "color": self._bond_color,
                    "physics": False,
                    "group": "ligand",
                    "width": 4,
                }
            )
        else:
            self._make_non_single_bond(idx, btype)

    def _make_non_single_bond(self, ids, btype, bdist=0.06, dash=(10,)):
        """Prepare double, triple and aromatic bonds

        Two lines parallel to the bond are drawn on each side of it, anchored
        on invisible nodes. A triple bond gets an additional central line, and
        any bond order other than 2 or 3 is drawn with dashed lines.

        Parameters
        ----------
        ids : list[int]
            Indices of the two bonded atoms
        btype : float
            Bond order as returned by ``GetBondTypeAsDouble``
        bdist : float
            Offset of the parallel lines, scaled by ``molsize`` and by the
            bond order rounded up
        dash : sequence of int
            Dash pattern used for bonds that are neither double nor triple
        """
        xyz = self.xyz[ids]
        d = xyz[1, :2] - xyz[0, :2]
        length = np.sqrt((d**2).sum())
        # unit vector along the bond, and the one perpendicular to it
        u = d / length
        p = np.array([-u[1], u[0]])
        nodes = []
        dist = bdist * self._multiplier * np.ceil(btype)
        solid = btype in [2, 3]
        for perp in (p, -p):
            for point in xyz:
                xy = point[:2] + perp * dist
                # unique id derived from the position of the anchor
                _id = hash(xy.tobytes())
                nodes.append(_id)
                self.nodes[_id] = {
                    "id": _id,
                    "x": float(xy[0]),
                    "y": float(xy[1]),
                    "shape": "text",
                    "label": " ",
                    "fixed": True,
                    "physics": False,
                }
        l1, l2, r1, r2 = nodes
        self.edges.extend(
            {
                "from": start,
                "to": end,
                "color": self._bond_color,
                "physics": False,
                # each edge gets its own list rather than a shared default
                "dashes": False if solid else list(dash),
                "group": "ligand",
                "width": 4,
            }
            for start, end in ((l1, l2), (r1, r2))
        )
        if btype == 3:
            self.edges.append(
                {
                    "from": ids[0],
                    "to": ids[1],
                    "color": self._bond_color,
                    "physics": False,
                    "group": "ligand",
                    "width": 4,
                }
            )

    def _make_interactions(self, mass=2):
        """Prepare lig-prot interactions

        Adds one node per protein residue and one edge per row of ``self.df``.

        Parameters
        ----------
        mass : int
            Value of the ``mass`` option of the residue nodes
        """
        restypes = {}
        for prot_res in self.df.index.get_level_values("protein").unique():
            resname = ResidueId.from_string(prot_res).name
            restype = self.RESIDUE_TYPES.get(resname)
            restypes[prot_res] = restype
            color = self.COLORS["residues"].get(restype, self._default_residue_color)
            node = {
                "id": prot_res,
                "label": prot_res,
                "color": color,
                "shape": "box",
                "borderWidth": 0,
                "physics": True,
                "mass": mass,
                "group": "protein",
                # NOTE(review): this is None for residues missing from
                # RESIDUE_TYPES while the legend labels them "Unknown"
                "residue_type": restype,
            }
            self.nodes[prot_res] = node
        for (lig_res, prot_res, interaction, lig_indices), (
            weight,
            distance,
        ) in self.df.iterrows():
            if interaction in self._LIG_PI_INTERACTIONS:
                # the edge starts from an invisible node placed at the
                # centroid of the ligand atoms involved
                centroid = self._get_ring_centroid(lig_indices)
                origin = f"centroid({lig_res}, {prot_res}, {interaction})"
                self.nodes[origin] = {
                    "id": origin,
                    "x": float(centroid[0]),
                    "y": float(centroid[1]),
                    "shape": "text",
                    "label": " ",
                    "fixed": True,
                    "physics": False,
                    "group": "ligand",
                }
            else:
                i = self._DISPLAYED_ATOM.get(interaction, 0)
                origin = lig_indices[i]
            edge = {
                "from": origin,
                "to": prot_res,
                "title": f"{interaction}: {distance:.2f}Å",
                "interaction_type": self._interaction_types.get(
                    interaction, interaction
                ),
                "color": self.COLORS["interactions"].get(
                    interaction, self._default_interaction_color
                ),
                "smooth": {"type": "cubicBezier", "roundness": 0.2},
                "dashes": [10],
                # rows come out of ``iterrows`` as numpy scalars, and an
                # all-integer dataframe yields ``numpy.int64`` which the json
                # module cannot serialize
                "width": float(weight) * self._max_interaction_width,
                "group": "interaction",
            }
            self.edges.append(edge)

    def _get_ring_centroid(self, indices):
        """Find ring centroid coordinates using the indices of the ring atoms"""
        return self.xyz[list(indices)].mean(axis=0)

    def _patch_hydrogens(self):
        """Patch hydrogens on heteroatoms

        Hydrogen atoms that aren't part of any interaction have been hidden at
        this stage, but they should be added to the label of the heteroatom for
        clarity
        """
        to_patch = defaultdict(int)
        for idx in self.exclude:
            h = self.mol.GetAtomWithIdx(idx)
            atom = h.GetNeighbors()[0]
            if atom.GetSymbol() != "C":
                to_patch[atom.GetIdx()] += 1
        for idx, nH in to_patch.items():
            node = self.nodes[idx]
            h_str = "H" if nH == 1 else f"H{nH}"
            # insert the hydrogens between the element and the charge
            label = re.sub(r"(\w+)(.*)", rf"\1{h_str}\2", node["label"])
            node["label"] = label
            node["shape"] = "ellipse"

    def _make_graph_data(self):
        """Prepares the nodes and edges"""
        self.exclude = []
        self.nodes = {}
        self.edges = []
        # show residues
        self._make_interactions()
        # show ligand
        for atom in self.mol.GetAtoms():
            self._make_lig_node(atom)
        for bond in self.mol.GetBonds():
            self._make_lig_edge(bond)
        self._patch_hydrogens()
        self.nodes = list(self.nodes.values())

    def _get_js(self, width="100%", height="500px", div_id="mynetwork", fontsize=20):
        """Returns the JavaScript code to draw the network

        Rebuilds the nodes and edges, and stores ``width`` and ``height`` on
        the instance (``height`` is then increased by the legend when both are
        expressed in pixels).

        Parameters
        ----------
        width : str
            Width of the diagram
        height : str
            Height of the diagram
        div_id : str
            Id of the HTML element hosting the diagram
        fontsize : int
            Font size of the node labels
        """
        self.width = width
        self.height = height
        self._make_graph_data()
        options = {
            "width": width,
            "height": height,
            "nodes": {
                "font": {"size": fontsize},
            },
            "physics": {
                "barnesHut": {
                    "avoidOverlap": self._avoidOverlap,
                    "springConstant": self._springConstant,
                }
            },
        }
        # user-defined options take precedence over the defaults
        options.update(self.options)
        js = self._JS_TEMPLATE % dict(
            div_id=div_id,
            nodes=json.dumps(self.nodes),
            edges=json.dumps(self.edges),
            options=json.dumps(options),
        )
        js += self._get_legend()
        return js

    def _get_html(self, **kwargs):
        """Returns the HTML code to draw the network"""
        return self._HTML_TEMPLATE % dict(js=self._get_js(**kwargs))

    def _get_legend(self, height="90px"):
        """Returns the JavaScript code creating the clickable legend

        One button is created per residue type and per group of interactions
        sharing a color. When both are expressed in pixels, ``height`` is added
        to ``self.height``.
        """
        available = {}
        buttons = []
        map_color_restype = {c: t for t, c in self.COLORS["residues"].items()}
        map_color_interactions = {
            self.COLORS["interactions"].get(i, self._default_interaction_color): t
            for i, t in self._interaction_types.items()
        }
        # residues
        for node in self.nodes:
            if node.get("group", "") == "protein":
                color = node["color"]
                available[color] = map_color_restype.get(color, "Unknown")
        available = dict(sorted(available.items(), key=lambda item: item[1]))
        for i, (color, restype) in enumerate(available.items()):
            buttons.append(
                {"index": i, "label": restype, "color": color, "group": "residues"}
            )
        # interactions
        available.clear()
        for edge in self.edges:
            if edge.get("group", "") == "interaction":
                color = edge["color"]
                available[color] = map_color_interactions[color]
        available = dict(sorted(available.items(), key=lambda item: item[1]))
        for i, (color, interaction) in enumerate(available.items()):
            buttons.append(
                {
                    "index": i,
                    "label": interaction,
                    "color": color,
                    "group": "interactions",
                }
            )
        # JS code
        if all("px" in h for h in [self.height, height]):
            h1 = int(re.findall(r"(\d+)\w+", self.height)[0])
            h2 = int(re.findall(r"(\d+)\w+", height)[0])
            self.height = f"{h1+h2}px"
        return """
        legend_buttons = %(buttons)s;
        legend = document.getElementById('%(div_id)s');
        var div_residues = document.createElement('div');
        var div_interactions = document.createElement('div');
        var disabled = [];
        var legend_callback = function() {
            this.classList.toggle("disabled");
            var hide = this.classList.contains("disabled");
            var show = !hide;
            var btn_label = this.innerHTML;
            if (hide) {
                disabled.push(btn_label);
            } else {
                disabled = disabled.filter(x => x !== btn_label);
            }
            var node_update = [],
                edge_update = [];
            // click on residue type
            if (this.classList.contains("residues")) {
                nodes.forEach((node) => {
                    // find nodes corresponding to this type
                    if (node.residue_type === btn_label) {
                        // if hiding this type and residue isn't already hidden
                        if (hide && !node.hidden) {
                            node.hidden = true;
                            node_update.push(node);
                        // if showing this type and residue isn't already visible
                        } else if (show && node.hidden) {
                            // display if there's at least one of its edge that isn't hidden
                            num_edges_active = edges.filter(x => x.to === node.id)
                                                    .map(x => Boolean(x.hidden))
                                                    .filter(x => !x)
                                                    .length;
                            if (num_edges_active > 0) {
                                node.hidden = false;
                                node_update.push(node);
                            }
                        }
                    }
                });
                ifp.body.data.nodes.update(node_update);
            // click on interaction type
            } else {
                edges.forEach((edge) => {
                    // find edges corresponding to this type
                    if (edge.interaction_type === btn_label) {
                        edge.hidden = !edge.hidden;
                        edge_update.push(edge);
                        // number of active edges for the corresponding residue
                        var num_edges_active = edges.filter(x => x.to === edge.to)
                                                    .map(x => Boolean(x.hidden))
                                                    .filter(x => !x)
                                                    .length;
                        // find corresponding residue
                        var ix = nodes.findIndex(x => x.id === edge.to);
                        // only change visibility if residue_type not being hidden
                        if (!(disabled.includes(nodes[ix].residue_type))) {
                            // hide if no edge being shown for this residue
                            if (hide && (num_edges_active === 0)) {
                                nodes[ix].hidden = true;
                                node_update.push(nodes[ix]);
                            // show if edges are being shown
                            } else if (show && (num_edges_active > 0)) {
                                nodes[ix].hidden = false;
                                node_update.push(nodes[ix]);
                            }
                        }
                    }
                });
                ifp.body.data.nodes.update(node_update);
                ifp.body.data.edges.update(edge_update);
            }
        };
        legend_buttons.forEach(function(v,i) {
            if (v.group === "residues") {
                var div = div_residues;
                var border = "none";
                var color = v.color;
            } else {
                var div = div_interactions;
                var border = "3px dashed " + v.color;
                var color = "white";
            }
            var button = div.appendChild(document.createElement('button'));
            button.classList.add("legend-btn", v.group);
            button.innerHTML = v.label;
            Object.assign(button.style, {
                "cursor": "pointer",
                "background-color": color,
                "border": border,
                "border-radius": "5px",
                "padding": "5px",
                "margin": "5px",
                "font": "14px 'Arial', sans-serif",
            });
            button.onclick = legend_callback;
        });
        legend.appendChild(div_residues);
        legend.appendChild(div_interactions);
        """ % dict(
            div_id="networklegend", buttons=json.dumps(buttons)
        )

    @requires("IPython.display")
    def display(self, **kwargs):
        """Prepare and display the network

        Parameters
        ----------
        kwargs : object
            ``width``, ``height``, ``div_id`` and ``fontsize`` of the diagram

        Returns
        -------
        IPython.display.HTML
            An ``<iframe>`` embedding the escaped HTML document
        """
        html = self._get_html(**kwargs)
        iframe = (
            '<iframe width="{width}" height="{height}" frameborder="0" '
            'srcdoc="{doc}"></iframe>'
        )
        return HTML(
            iframe.format(width=self.width, height=self.height, doc=escape(html))
        )

    @requires("IPython.display")
    def show(self, filename, **kwargs):
        """Save the network as HTML and display the resulting file

        Parameters
        ----------
        filename : str
            Name of the output file, also used as the source of the returned
            ``<iframe>``
        kwargs : object
            ``width``, ``height``, ``div_id`` and ``fontsize`` of the diagram
        """
        html = self._get_html(**kwargs)
        with open(filename, "w", encoding="utf-8") as f:
            f.write(html)
        iframe = (
            '<iframe width="{width}" height="{height}" frameborder="0" '
            'src="{filename}"></iframe>'
        )
        return HTML(
            iframe.format(width=self.width, height=self.height, filename=filename)
        )

    def save(self, fp, **kwargs):
        """Save the network to an HTML file

        Parameters
        ----------
        fp : str or file-like object
            Name of the output file, or file-like object
        kwargs : object
            ``width``, ``height``, ``div_id`` and ``fontsize`` of the diagram

        Raises
        ------
        TypeError
            If ``fp`` is neither a path nor an object with a ``write`` method
        """
        html = self._get_html(**kwargs)
        if isinstance(fp, (str, Path)):
            with open(fp, "w", encoding="utf-8") as f:
                f.write(html)
        elif hasattr(fp, "write") and callable(fp.write):
            fp.write(html)
        else:
            # returning silently here would let the caller believe that the
            # diagram was saved
            raise TypeError(
                "fp must be a path or a file-like object with a write method,"
                f" got {type(fp).__name__}"
            )