Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import torch
import simtk.openmm as openmm

from rhofold.data.balstn import BLASTN
from rhofold.rhofold import RhoFold
Expand Down Expand Up @@ -103,13 +104,20 @@ def main(config):
path=unrelaxed_model, chain_id=None,
confidence=output['plddt'][0].data.cpu().numpy(),
logger=logger)
# Check if CUDA is available in OpenMM
cuda_available = False
for i in range(openmm.Platform.getNumPlatforms()):
if openmm.Platform.getPlatform(i).getName().lower() == "cuda":
cuda_available = True
break

# Amber relaxation
if config.relax_steps is not None:
relax_steps = int(config.relax_steps)
if relax_steps > 0:
with timing(f'Amber Relaxation : {relax_steps} iterations', logger=logger):
amber_relax = AmberRelaxation(max_iterations=relax_steps, logger=logger)
use_gpu = cuda_available and config.device.startswith("cuda:")
amber_relax = AmberRelaxation(max_iterations=relax_steps, logger=logger, use_gpu=use_gpu)
relaxed_model = f'{config.output_dir}/relaxed_{relax_steps}_model.pdb'
amber_relax.process(unrelaxed_model, relaxed_model)

Expand Down