#
from typing import (Optional, Tuple, Dict)
import pyrosetta
import warnings
class AtomicInteractions:
    """
    Per-atom pair energies between one target residue and its close-contact neighbours.

    For every non-virtual atom of the target residue and every non-virtual atom of each
    neighbouring residue (found with ``CloseContactResidueSelector``), the etable pair energies
    ``fa_atr``, ``fa_rep``, ``fa_sol`` and ``fa_elec`` are computed and stored in
    ``interactions``, keyed by ``(target_atomname, other_residue_index, other_atomname)``.

    >>> ai = AtomicInteractions(pose, 1)  # noqa - pose residue index 1
    >>> print(ai.describe_best())
    >>> ai.interactions[(' N  ', 2, ' N  ')]

    Atom pairs spanning a residue connection (the two connecting atoms, and each connecting atom
    paired with the atoms bonded to its counterpart) are removed. No other bonded-neighbour
    correction is applied, intra-residue pairs are skipped and only the four terms above are
    considered, so ``total`` will generally differ from ``expected_total``.

    >>> ai.total, ai.expected_total
    """
    # Term names, zipped in this order with the values returned by etable_atom_pair_energies.
    score_types = ['fa_atr', 'fa_rep', 'fa_sol', 'fa_elec']
    # An interaction/term is reported only if its absolute value reaches this cutoff.
    term_relevance_cutoff = 0.5

    def __init__(self,
                 pose: pyrosetta.Pose,
                 target_idx: int,
                 threshold: float = 3,
                 scorefxn: Optional[pyrosetta.ScoreFunction] = None,
                 weighted: bool = True,
                 halved: bool = False):
        """
        :param pose: pose containing the target residue.
        :param target_idx: pose index (1-based) of the target residue.
        :param threshold: value passed to ``CloseContactResidueSelector.threshold`` to pick
            the neighbouring residues.
        :param scorefxn: score function; defaults to ``pyrosetta.get_fa_scorefxn()``.
        :param weighted: multiply each term by the score function's weight for it;
            if False every weight is 1.
        :param halved: divide every pair energy by 2
            (presumably to avoid double counting between the two residues — TODO confirm).
        :raises ValueError: if ``target_idx`` is not within ``1..pose.total_residue()``.
        """
        # --- input
        self.pose = pose
        self.target_idx = target_idx
        # Validate before touching the residue: an explicit check is not stripped by ``-O``
        # and runs before the out-of-range access (the old assert ran after it).
        if not 1 <= target_idx <= pose.total_residue():
            raise ValueError(f'{target_idx} not in pose')
        self.target_residue = self.pose.residue(self.target_idx)
        self.threshold = threshold
        self.scorefxn = pyrosetta.get_fa_scorefxn() if scorefxn is None else scorefxn  # noqa
        self.weighted = weighted
        if weighted:
            stm = pyrosetta.rosetta.core.scoring.ScoreTypeManager()
            self.weights = {name: self.scorefxn.get_weight(stm.score_type_from_name(name))
                            for name in self.score_types}
        else:
            self.weights = {name: 1. for name in self.score_types}
        self.halved = halved
        # --- derived
        self.neighbors = self._get_neighbors()  # ResidueVector
        self.interactions = self._get_interactions()  # Dict[Tuple(str, int, str), Dict[str, float]]

    def _get_interactions(self) -> Dict[Tuple[str, int, str], Dict[str, float]]:
        """
        Compute the (weighted, optionally halved) pair energies of every non-virtual target atom
        against every non-virtual atom of each neighbouring residue other than the target itself.

        Pairs listed by ``_get_connections`` (spanning a residue connection) are removed.

        :returns: ``{(target_atomname, other_residue_index, other_atomname): {score_type: value}}``
        """
        pair_energies = pyrosetta.toolbox.atom_pair_energy.etable_atom_pair_energies
        ratio = 2 if self.halved else 1
        # Nested form first (target atom -> residue -> other atom) so bonded pairs are easy to drop.
        interactions = {}
        for i in range(1, self.target_residue.natoms() + 1):
            if self.target_residue.atom_type(i).element() == 'X':
                continue  # virtual atom
            iname = self.target_residue.atom_name(i)
            interactions[iname] = {}
            for r in self.neighbors:  # noqa
                if r == self.target_idx:
                    continue  # intra-residue pairs are not considered
                other = self.pose.residue(r)
                interactions[iname][r] = {}
                for o in range(1, other.natoms() + 1):
                    if other.atom_type(o).element() == 'X':
                        continue  # virtual atom
                    score = pair_energies(self.target_residue, i, other, o, self.scorefxn)
                    interactions[iname][r][other.atom_name(o)] = {
                        st: s * self.weights[st] / ratio for st, s in zip(self.score_types, score)}
        # Drop pairs that span a residue connection, e.g.
        # [(' C1 ', 59, ' SG '), (' H1 ', 59, ' SG '), (' C1 ', 59, ' CB '), ...]
        for atomname, other_residue_index, other_atomname in self._get_connections():
            # the target atom may be virtual (absent) or the partner may not be a neighbour
            other_atoms = interactions.get(atomname, {}).get(other_residue_index, {})
            other_atoms.pop(other_atomname, None)  # absent if virtual or already removed
        # reshape to flat keys
        return {(target_atomname, other_resi, other_atomname): terms
                for target_atomname, per_residue in interactions.items()
                for other_resi, per_atom in per_residue.items()
                for other_atomname, terms in per_atom.items()}

    def _get_neighbors(self) -> pyrosetta.rosetta.core.select.residue_selector.ResidueVector:
        """Return the indices of residues selected by ``get_cc_selector`` on the pose."""
        cc_sele = self.get_cc_selector()
        return pyrosetta.rosetta.core.select.residue_selector.ResidueVector(cc_sele.apply(self.pose))

    def get_cc_selector(self):
        """Close-contact selector centred on the target residue with ``self.threshold``."""
        cc_sele = pyrosetta.rosetta.core.select.residue_selector.CloseContactResidueSelector()
        cc_sele.central_residue_group_selector(self.get_target_selector())
        cc_sele.threshold(self.threshold)
        return cc_sele

    def get_target_selector(self):
        """Residue-index selector for the target residue."""
        return pyrosetta.rosetta.core.select.residue_selector.ResidueIndexSelector(self.target_idx)

    def _get_connections(self):
        """
        List the atom pairs spanning each residue connection of the target residue.

        Per connected partner, three kinds of pair are produced: the two connecting atoms;
        each target atom bonded to the connecting atom with the partner's connecting atom;
        and the target connecting atom with each partner atom bonded to the partner's
        connecting atom.

        :returns: list of ``(target_atomname, other_residue_index, other_atomname)``;
            empty when the target has no connected partner (previously ``None`` → crash,
            and only the first connection was ever processed).
        """
        residue = self.pose.residue(self.target_idx)
        pairs = []
        for conn_id in range(1, residue.n_current_residue_connections() + 1):
            other_residue_index = residue.residue_connection_partner(conn_id)
            if other_residue_index == 0:
                continue  # presumably an open (unconnected) connection — nothing to correct
            atomno = residue.residue_connect_atom_index(conn_id)
            atomname = residue.atom_name(atomno)
            other_residue = self.pose.residue(other_residue_index)
            other_atomname, other_atomno = self._get_other_connecting_atom(conn_id)
            pairs.append((atomname, other_residue_index, other_atomname))
            pairs.extend((residue.atom_name(no), other_residue_index, other_atomname)
                         for no in residue.bonded_neighbor(atomno))
            pairs.extend((atomname, other_residue_index, other_residue.atom_name(no))
                         for no in other_residue.bonded_neighbor(other_atomno))
        return pairs

    def _get_other_connecting_atom(self, conn_id) -> Tuple[str, int]:
        """
        Find the partner residue's connecting atom for the target's connection ``conn_id``.

        :returns: ``(atom name, atom index)`` on the partner residue.
        :raises ValueError: if no connection of the partner points back at the target.
        """
        other_residue_index = self.target_residue.residue_connection_partner(conn_id)
        other_residue = self.pose.residue(other_residue_index)
        for other_conn_id in range(1, other_residue.n_current_residue_connections() + 1):
            if other_residue.residue_connection_partner(other_conn_id) == self.target_idx:
                other_atom_index = other_residue.residue_connect_atom_index(other_conn_id)
                return other_residue.atom_name(other_atom_index), other_atom_index
        # NOTE(review): if the two residues share several connections, the first match is used.
        raise ValueError('No connection found?!')

    @property
    def best_interactions(self) -> Dict[Tuple[str, int, str], Dict[str, float]]:
        """
        Interactions whose strongest term (absolute value) reaches ``term_relevance_cutoff``,
        ordered strongest first (dicts keep insertion order).
        """
        strongest = {key: max(abs(value) for value in terms.values())
                     for key, terms in self.interactions.items()}
        ranked = sorted(strongest, key=strongest.get, reverse=True)
        return {key: self.interactions[key] for key in ranked
                if strongest[key] >= self.term_relevance_cutoff}

    def describe_atom(self, residue, atomname):
        """
        Comma-separated description of an atom: atom type name, element, then whichever of
        heavyatom / polar hydrogen / H-acceptor / H-donor / aromatic apply.
        """
        atomno = residue.atom_index(atomname)
        atomtype = residue.atom_type(atomno)
        verdict = {'heavyatom': atomtype.is_heavyatom(),
                   'polar hydrogen': atomtype.is_polar_hydrogen(),
                   'H-acceptor': atomtype.is_acceptor(),
                   'H-donor': atomtype.is_donor(),
                   'aromatic': atomtype.is_aromatic()}
        return ', '.join([atomtype.atom_type_name(), atomtype.element()] +
                         [k for k, v in verdict.items() if v])

    def describe_interaction(self, target_atomname, other_resi, other_atomname):
        """
        One-line description of a stored atom pair, listing only the terms whose absolute value
        reaches ``term_relevance_cutoff``, strongest first.
        """
        other_residue = self.pose.residue(other_resi)
        target_text = self.describe_atom(self.target_residue, target_atomname)
        other_text = self.describe_atom(other_residue, other_atomname)
        all_terms = self.interactions[target_atomname, other_resi, other_atomname]
        # keep terms with |value| >= term_relevance_cutoff, sorted by decreasing |value|
        terms = sorted((t for t in all_terms.items() if abs(t[1]) >= self.term_relevance_cutoff),
                       key=lambda t: -abs(t[1]))
        term_text = ' + '.join([f'{score_type} ({value:.1f} kcal/mol)' for score_type, value in terms])
        return f'{self.target_idx}.{target_atomname.strip()} ({target_text}) - ' + \
               f'{other_resi}.{other_atomname.strip()} ({other_text}) ' + \
               term_text

    def describe_best(self):
        """Newline-joined ``describe_interaction`` lines for ``best_interactions``."""
        return '\n'.join([self.describe_interaction(*k) for k in self.best_interactions])

    @property
    def expected_total(self):
        """
        Score the pose with ``scorefxn`` (side effect on the pose's energies) and return the
        score function's sub-score for the target residue, for comparison with ``total``.
        """
        self.scorefxn(self.pose)
        return self.scorefxn.get_sub_score(self.pose, self.get_target_selector().apply(self.pose))

    @property
    def total(self):
        """Sum of every stored term value over all atom pairs."""
        return sum(sum(terms.values()) for terms in self.interactions.values())
class NeighbourInteractions(AtomicInteractions):
    """Deprecated former name of ``AtomicInteractions``; behaves identically."""

    def __init__(self, *args, **kwargs):
        """Emit a DeprecationWarning about the rename, then defer to ``AtomicInteractions``."""
        # stacklevel=2 attributes the warning to the caller using the old name, not to this line
        warnings.warn('`NeighbourInteractions` renamed to `AtomicInteractions`', DeprecationWarning,
                      stacklevel=2)
        super().__init__(*args, **kwargs)