diff --git a/torchdrug/data/molecule.py b/torchdrug/data/molecule.py index dcf5f39..8d2504e 100644 --- a/torchdrug/data/molecule.py +++ b/torchdrug/data/molecule.py @@ -31,6 +31,7 @@ class Molecule(Graph): atom_map (array_likeb optional): atom mappings of shape :math:`(|V|,)` bond_stereo (array_like, optional): bond stereochem of shape :math:`(|E|,)` stereo_atoms (array_like, optional): ids of stereo atoms of shape :math:`(|E|,)` + conformer (array_like, optional): list of conformer of shape :math`(|V|, 3)`. """ bond2id = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3} @@ -44,19 +45,20 @@ class Molecule(Graph): def __init__(self, edge_list=None, atom_type=None, bond_type=None, formal_charge=None, explicit_hs=None, chiral_tag=None, radical_electrons=None, atom_map=None, bond_stereo=None, stereo_atoms=None, - **kwargs): + atom_conformer=None, **kwargs): if "num_relation" not in kwargs: kwargs["num_relation"] = len(self.bond2id) super(Molecule, self).__init__(edge_list=edge_list, **kwargs) atom_type, bond_type = self._standarize_atom_bond(atom_type, bond_type) - formal_charge = self._standarize_attribute(formal_charge, self.num_node) - explicit_hs = self._standarize_attribute(explicit_hs, self.num_node) - chiral_tag = self._standarize_attribute(chiral_tag, self.num_node) - radical_electrons = self._standarize_attribute(radical_electrons, self.num_node) - atom_map = self._standarize_attribute(atom_map, self.num_node) - bond_stereo = self._standarize_attribute(bond_stereo, self.num_edge) - stereo_atoms = self._standarize_attribute(stereo_atoms, (self.num_edge, 2)) + formal_charge = self._standarize_attribute(formal_charge, self.num_node, torch.long) + explicit_hs = self._standarize_attribute(explicit_hs, self.num_node, torch.long) + chiral_tag = self._standarize_attribute(chiral_tag, self.num_node, torch.long) + radical_electrons = self._standarize_attribute(radical_electrons, self.num_node, torch.long) + atom_map = self._standarize_attribute(atom_map, self.num_node, torch.long) + bond_stereo = self._standarize_attribute(bond_stereo, self.num_edge, torch.long) + stereo_atoms = self._standarize_attribute(stereo_atoms, (self.num_edge, 2), torch.long) + atom_conformer = self._standarize_attribute(atom_conformer, (self.num_node, 3), torch.float) with self.node(): self.atom_type = atom_type @@ -65,6 +67,7 @@ def __init__(self, edge_list=None, atom_type=None, bond_type=None, formal_charge self.chiral_tag = chiral_tag self.radical_electrons = radical_electrons self.atom_map = atom_map + self.atom_conformer = atom_conformer with self.edge(): self.bond_type = bond_type @@ -81,13 +84,13 @@ def _standarize_atom_bond(self, atom_type, bond_type): bond_type = torch.as_tensor(bond_type, dtype=torch.long, device=self.device) return atom_type, bond_type - def _standarize_attribute(self, attribute, size): + def _standarize_attribute(self, attribute, size, data_type): if attribute is not None: - attribute = torch.as_tensor(attribute, dtype=torch.long, device=self.device) + attribute = torch.as_tensor(attribute, dtype=data_type, device=self.device) else: if isinstance(size, torch.Tensor): size = size.tolist() - attribute = torch.zeros(size, dtype=torch.long, device=self.device) + attribute = torch.zeros(size, dtype=data_type, device=self.device) return attribute @classmethod @@ -111,7 +114,7 @@ def _maybe_num_node(self, edge_list): @classmethod def from_smiles(cls, smiles, node_feature="default", edge_feature="default", graph_feature=None, - with_hydrogen=False, kekulize=False): + with_hydrogen=False, kekulize=False, conformer=False): """ Create a molecule from a SMILES string. @@ -131,11 +134,11 @@ def from_smiles(cls, smiles, node_feature="default", edge_feature="default", gra if mol is None: raise ValueError("Invalid SMILES `%s`" % smiles) - return cls.from_molecule(mol, node_feature, edge_feature, graph_feature, with_hydrogen, kekulize) + return cls.from_molecule(mol, node_feature, edge_feature, graph_feature, with_hydrogen, kekulize, conformer) @classmethod def from_molecule(cls, mol, node_feature="default", edge_feature="default", graph_feature=None, - with_hydrogen=False, kekulize=False): + with_hydrogen=False, kekulize=False, conformer=False): """ Create a molecule from a RDKit object. @@ -150,6 +153,9 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap Note this only affects the relation in ``edge_list``. For ``bond_type``, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored. + conformer (bool, optional): generate molecule 3D conformer. + MMFF is used for 3D conformer. + By default, conformer is not used. """ if mol is None: mol = cls.empty_mol @@ -158,6 +164,15 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap mol = Chem.AddHs(mol) if kekulize: Chem.Kekulize(mol) + if conformer: + mol = Chem.AddHs(mol) + Chem.AllChem.EmbedMolecule(mol) + conf = mol.GetConformer() + conf_list = Chem.AllChem.MMFFOptimizeMoleculeConfs(mol, maxIters=2000) + mol = Chem.RemoveHs(mol) + conf_list = torch.tensor(conf_list) + opt_conf = torch.argmin(conf_list[:,1]) + conformer = mol.GetConformer(int(opt_conf)) node_feature = cls._standarize_option(node_feature) edge_feature = cls._standarize_option(edge_feature) @@ -169,6 +184,7 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap chiral_tag = [] radical_electrons = [] atom_map = [] + atom_conformer = [] _node_feature = [] atoms = [mol.GetAtomWithIdx(i) for i in range(mol.GetNumAtoms())] + [cls.dummy_atom] for atom in atoms: @@ -178,6 +194,10 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap chiral_tag.append(atom.GetChiralTag()) radical_electrons.append(atom.GetNumRadicalElectrons()) atom_map.append(atom.GetAtomMapNum()) + if conformer: + position = conformer.GetAtomPosition(atom.GetIdx()) + atom_conformer.append([position.x, position.y, position.z]) + feature = [] for name in node_feature: func = R.get("features.atom.%s" % name) @@ -189,6 +209,8 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap explicit_hs = torch.tensor(explicit_hs)[:-1] chiral_tag = torch.tensor(chiral_tag)[:-1] radical_electrons = torch.tensor(radical_electrons)[:-1] + if len(atom_conformer) > 0: + atom_conformer = torch.tensor(atom_conformer)[:-1] if len(node_feature) > 0: _node_feature = torch.tensor(_node_feature)[:-1] else: @@ -247,7 +269,7 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap chiral_tag=chiral_tag, radical_electrons=radical_electrons, atom_map=atom_map, bond_stereo=bond_stereo, stereo_atoms=stereo_atoms, node_feature=_node_feature, edge_feature=_edge_feature, graph_feature=_graph_feature, - num_node=mol.GetNumAtoms(), num_relation=num_relation) + num_node=mol.GetNumAtoms(), num_relation=num_relation, atom_conformer=atom_conformer) def to_smiles(self, isomeric=True, atom_map=True, canonical=False): """