diff --git a/inference.py b/inference.py index 9fa8f24e..f6edde21 100644 --- a/inference.py +++ b/inference.py @@ -341,15 +341,6 @@ def inference_monomer_model(args): print("running in monomer mode...") config = model_config(args.model_name) - template_featurizer = templates.TemplateHitFeaturizer( - mmcif_dir=args.template_mmcif_dir, - max_template_date=args.max_template_date, - max_hits=config.data.predict.max_templates, - kalign_binary_path=args.kalign_binary_path, - release_dates_path=args.release_dates_path, - obsolete_pdbs_path=args.obsolete_pdbs_path - ) - use_small_bfd = args.preset == 'reduced_dbs' # (args.bfd_database_path is None) if use_small_bfd: assert args.bfd_database_path is not None @@ -357,8 +348,6 @@ def inference_monomer_model(args): assert args.bfd_database_path is not None assert args.uniref30_database_path is not None - data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,) - output_dir_base = args.output_dir random_seed = args.data_random_seed @@ -423,11 +412,31 @@ def inference_monomer_model(args): use_small_bfd=use_small_bfd, no_cpus=args.cpus, ) + t = time.perf_counter() alignment_runner.run(fasta_path, local_alignment_dir) + print(f"Alignment data time: {time.perf_counter() - t}") - feature_dict = data_processor.process_fasta(fasta_path=fasta_path, - alignment_dir=local_alignment_dir) - + features_output_path = os.path.join(local_alignment_dir, 'features.pkl') + if os.path.exists(features_output_path): + feature_dict = pickle.load(open(features_output_path, 'rb')) + + else: + template_featurizer = templates.TemplateHitFeaturizer( + mmcif_dir=args.template_mmcif_dir, + max_template_date=args.max_template_date, + max_hits=config.data.predict.max_templates, + kalign_binary_path=args.kalign_binary_path, + release_dates_path=args.release_dates_path, + obsolete_pdbs_path=args.obsolete_pdbs_path + ) + + data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,) + + feature_dict = data_processor.process_fasta(fasta_path=fasta_path, + alignment_dir=local_alignment_dir) + with open(features_output_path, 'wb') as f: + pickle.dump(feature_dict, f, protocol=4) + # Remove temporary FASTA file os.remove(fasta_path)