# Source code for pyrosetta_help.alphafold.constraints

# Public API: the names exported on a star-import of this module.
__all__ = ['add_pae_constraints',
           'add_interchain_pae_constraints',
           'add_stretch_constraint',
           'make_pae_constraint',]

from typing import (Optional, List)

import numpy as np
import pyrosetta
from ..common_ops.distances import measure_distance_matrix


def add_pae_constraints(pose: pyrosetta.Pose,
                        errors: np.ndarray,
                        cutoff: float = 12,
                        tolerance: Optional[float] = None,
                        weight: float = 1,
                        adjecency_threshold=5) -> None:
    """
    Constrain CA-CA distances for residue pairs whose predicted error is low.

    ``errors`` is the AlphaFold2 error matrix (NB. a reshaped version of what AF2
    returns), indexed ``[residue1, residue2]`` with zero-based indices.
    For every entry below ``cutoff`` a constraint built by ``make_pae_constraint``
    is added to the pose: its minimum is the current CA-CA distance and its
    standard deviation is the error times ``weight``.

    Only entries of the lower triangle (row index > column index) are used, so
    each residue pair is constrained once, and pairs closer in sequence than
    ``adjecency_threshold`` are ignored.

    To find out how many constraints were added:

    >>> len(pose.constraint_set().get_all_constraints())

    :param pose: pose to which the constraints are added (modified in place)
    :param errors: 2D error matrix, zero-indexed by residue
    :param cutoff: only entries strictly below this error are constrained
    :param tolerance: if None a Harmonic function is used,
        otherwise this is the tolerance of a FlatHarmonic function
    :param weight: multiplier applied to the error to give the standard deviation
    :param adjecency_threshold: minimum sequence separation of a constrained pair
    :return: None
    """
    confident_pairs = np.argwhere(errors < cutoff)
    for row, col in confident_pairs:
        # Upper triangle and diagonal are dropped so each pair is added once;
        # sequence neighbours are dropped as they are trivially close.
        if row <= col or abs(row - col) < adjecency_threshold:
            continue
        constraint = make_pae_constraint(pose=pose,
                                         residue1_pose_idx=row + 1,  # matrix is 0-indexed, pose is 1-indexed
                                         residue2_pose_idx=col + 1,
                                         error=errors[row, col],
                                         weight=weight,
                                         tolerance=tolerance)
        pose.add_constraint(constraint)
def make_pae_constraint(pose,
                        residue1_pose_idx: int,  # one indexed.
                        residue2_pose_idx: int,  # one indexed.
                        error: float,
                        tolerance: Optional[float] = None,
                        weight: float = 1):
    """
    Build (but do not add) a CA-CA constraint between two residues, based on
    the PAE error from AlphaFold2 (the colourful heatmap in EBI-AF2).

    The function is centred on the current CA-CA distance in ``pose`` and its
    standard deviation is ``error * weight``.

    :param pose: pose the two residues belong to
    :param residue1_pose_idx: pose index (one-indexed) of the first residue
    :param residue2_pose_idx: pose index (one-indexed) of the second residue
    :param error: error value for this residue pair
    :param tolerance: if falsy (None or 0) a HarmonicFunc is used,
        otherwise a FlatHarmonicFunc with this tolerance
    :param weight: multiplier applied to ``error`` to give the standard deviation
    :return: the AtomPairConstraint between the two CA atoms
    """
    func_ns = pyrosetta.rosetta.core.scoring.func
    constraints_ns = pyrosetta.rosetta.core.scoring.constraints

    def ca_atom_id(residue, pose_idx):
        # AtomID of the alpha carbon of ``residue`` sitting at ``pose_idx``
        return pyrosetta.AtomID(atomno_in=residue.atom_index('CA'), rsd_in=pose_idx)

    first = pose.residue(residue1_pose_idx)
    first_ca = ca_atom_id(first, residue1_pose_idx)
    second = pose.residue(residue2_pose_idx)
    second_ca = ca_atom_id(second, residue2_pose_idx)
    # current CA-CA distance: used as the minimum of the function
    separation = (first.xyz(first_ca.atomno()) - second.xyz(second_ca.atomno())).norm()
    spread = error * weight
    if tolerance:
        restraint = func_ns.FlatHarmonicFunc(x0_in=separation, sd_in=spread, tol_in=tolerance)
    else:
        restraint = func_ns.HarmonicFunc(x0_in=separation, sd_in=spread)
    return constraints_ns.AtomPairConstraint(first_ca, second_ca, restraint)
def add_interchain_pae_constraints(pose, errors, cutoff=15):
    """
    Add constraints between residues that are close to each other
    but are in different 'chains' (sensu PyRosetta FoldTree).

    Residue pairs are selected from the matrix returned by
    ``measure_distance_matrix(pose)``: every pair from two different chains
    whose value is below ``cutoff`` gets a FlatHarmonic constraint
    (tolerance of 3) built by ``make_pae_constraint``, whose standard deviation
    is the matching entry of ``errors``.

    All chains of the pose are handled, not only the first two.

    NB. the mask is iterated as is: if the distance matrix is symmetric, each
    pair is visited in both orientations, once with ``errors[i, j]`` and once
    with ``errors[j, i]``.

    :param pose: pose to which the constraints are added (modified in place)
    :param errors: 2D error matrix, zero-indexed by residue
    :param cutoff: threshold applied to the distance matrix
        (presumably in Ångström -- confirm against ``measure_distance_matrix``)
    :return: None
    """
    xdistances = measure_distance_matrix(pose)
    # Blank out every intra-chain block, so only inter-chain pairs can pass the cutoff.
    for c in range(1, pose.num_chains() + 1):
        # chain_begin/chain_end are one-indexed and inclusive -> zero-indexed half-open slice
        begin, end = pose.chain_begin(c) - 1, pose.chain_end(c)
        xdistances[begin:end, begin:end] = np.nan
    # comparisons against NaN are False: silence the warning some numpy versions emit
    with np.errstate(invalid='ignore'):
        mask = xdistances < cutoff
    for r1_idx, r2_idx in np.argwhere(mask):
        apc = make_pae_constraint(pose=pose,
                                  residue1_pose_idx=r1_idx + 1,  # pose indices are one-indexed
                                  residue2_pose_idx=r2_idx + 1,
                                  error=errors[r1_idx, r2_idx],
                                  tolerance=3)  # flatharmonic
        pose.add_constraint(apc)
def add_stretch_constraint(pose: pyrosetta.Pose,
                           weight: float = 5,
                           slope_in: float = -0.05,
                           residue_index_A: int = 1,
                           residue_index_B: int = -1,
                           distance: Optional[float] = None,
                           sigmoid: bool = True
                           ) -> pyrosetta.rosetta.core.scoring.constraints.AtomPairConstraint:
    """
    Add a constraint to "stretch out" the model, because ``slope_in`` is negative.
    The weight needs to be negative for ``sigmoid=False`` or it will be attractive.

    The constraint acts on the distance between the CA atoms of the two residues
    and is both added to the pose and returned.

    :param pose: Pose to add constraint to (modified in place)
    :param weight: strength of the constraint (max of 0.5 for ``SigmoidFunc``)
    :param slope_in: negative number to stretch
    :param residue_index_A: pose index (one-indexed) of the first residue
    :param residue_index_B: pose index of the second residue; "-1" means the last residue
    :param distance: if omitted, the midpoint of Sigmoid will be the current distance.
        With ``sigmoid=False`` the weight is divided by it, so it must not be zero.
    :param sigmoid: use sigmoid or identity/linear (bad idea)
    :return: the AtomPairConstraint that was added
    :raises AssertionError: if either residue is not a protein residue
    """
    # get current length
    if residue_index_B == -1:
        residue_index_B = pose.total_residue()
    residue_A = pose.residue(residue_index_A)
    residue_B = pose.residue(residue_index_B)
    # ``is_protein`` is a method: it has to be called, as the bare bound method is always truthy.
    assert residue_A.is_protein(), f'residue idx {residue_index_A} is not an AA'
    assert residue_B.is_protein(), f'residue idx {residue_index_B} is not an AA'
    first_ca = pyrosetta.AtomID(atomno_in=residue_A.atom_index('CA'),
                                rsd_in=residue_index_A)
    last_ca = pyrosetta.AtomID(atomno_in=residue_B.atom_index('CA'),
                               rsd_in=residue_index_B)
    if distance is None:
        first_ca_xyz = residue_A.xyz(first_ca.atomno())
        last_ca_xyz = residue_B.xyz(last_ca.atomno())
        distance = (first_ca_xyz - last_ca_xyz).norm()
    # make & add con
    sf = pyrosetta.rosetta.core.scoring.func
    AtomPairConstraint = pyrosetta.rosetta.core.scoring.constraints.AtomPairConstraint  # noqa
    if sigmoid:
        fun = sf.ScalarWeightedFunc(weight, sf.SigmoidFunc(x0_in=distance, slope_in=slope_in))
    else:
        # linear function of the distance, scaled by the reference distance
        fun = sf.ScalarWeightedFunc(weight / distance, sf.IdentityFunc())
    con = AtomPairConstraint(first_ca, last_ca, fun)
    pose.add_constraint(con)
    return con