Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 81 additions & 10 deletions baseline/indexer/indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "flood_fill.cc"
#include "gemmi/symmetry.hpp"
#include "peaks_to_rlvs.cc"
#include "refine_crystal.cc"
#include "scan_static_predictor.cc"
#include "score_crystals.cc"
#include "xyz_to_rlp.cc"
Expand Down Expand Up @@ -68,6 +69,12 @@ int main(int argc, char **argv) {
.help("The maximum number of candidate lattices to refine during indexing")
.default_value<size_t>(50)
.scan<'u', size_t>();
parser.add_argument("--macro-cycles")
.help(
"The number of macrocycles of refinement to run after the initial indexing. "
"Set to zero for no post-indexing refinement.")
.default_value<size_t>(5)
.scan<'u', size_t>();
parser.add_argument("--test")
.help("Enable additional output for testing")
.default_value<bool>(false)
Expand Down Expand Up @@ -115,6 +122,7 @@ int main(int argc, char **argv) {
std::string filename = parser.get<std::string>("refl");
double max_cell = parser.get<float>("max-cell");
size_t max_refine = parser.get<size_t>("max-refine");
size_t macro_cycles = parser.get<size_t>("macro-cycles");

// Parse the experiment list (a json file) and load the models.
// Will be moved to dx2.
Expand Down Expand Up @@ -161,25 +169,39 @@ int main(int argc, char **argv) {
// coordinates on the detector into reciprocal space.
xyz_to_rlp_results results = xyz_to_rlp(xyzobs_px, panel, beam, scan, gonio);
logger.info("Number of reflections: {}", results.rlp.extent(0));
uint32_t n_points = parser.get<uint32_t>("fft-npoints");

// Determine the d_min limit of the data, will be used in macrocycles of refinement.
double d_min_data;
std::vector<double> d_values(results.rlp.extent(0), 0);
for (int i = 0; i < results.rlp.extent(0); ++i) {
d_values[i] = 1.0 / Eigen::Map<Vector3d>(&results.rlp(i, 0)).norm();
}
d_min_data = *std::min_element(d_values.begin(), d_values.end());
logger.debug("dmin of highest resolution spot: {:.5f}", d_min_data);

// If a resolution limit was not specified, determine from the highest resolution spot.
double d_min;
if (parser.is_used("dmin")) {
d_min = parser.get<float>("dmin");
} else {
std::vector<double> d_values(results.rlp.extent(0), 0);
for (int i = 0; i < results.rlp.extent(0); ++i) {
d_values[i] = 1.0 / Eigen::Map<Vector3d>(&results.rlp(i, 0)).norm();
}
d_min = *std::min_element(d_values.begin(), d_values.end());
logger.info("Setting dmin based on highest resolution spot: {:.5f}", d_min);
/* rough calculation of suitable d_min based on max cell
see also Campbell, J. (1998). J. Appl. Cryst., 31(3), 407-413.
fft_cell should be greater than twice max_cell, so say:
fft_cell = 2.5 * max_cell
then:
fft_cell = n_points * d_min/2
2.5 * max_cell = n_points * d_min/2
a little bit of rearrangement:
d_min = 5 * max_cell/n_points. */
d_min = 5.0 * max_cell / n_points;
d_min = std::max(d_min, d_min_data);
logger.info("Setting dmin to {:.5f}", d_min);
}

// b_iso is an isotropic b-factor used to weight the points when doing the fft.
// i.e. high resolution (weaker) spots are downweighted by the expected
// intensity fall-off as as function of resolution.
double b_iso = -4.0 * std::pow(d_min, 2) * log(0.05);
uint32_t n_points = parser.get<uint32_t>("fft-npoints");
logger.info("Setting b_iso = {:.3f}", b_iso);

// Create an array to store the fft result. This is a 3D grid of points, typically 256^3.
Expand Down Expand Up @@ -374,6 +396,55 @@ int main(int argc, char **argv) {
expt.beam().set_s0(best_beam.get_s0());
expt.detector().update(best_panel.get_d_matrix());

// Now do macrocycles of refinement
if (macro_cycles) {
std::vector<double> d_steps;
double d_min_step = (d_min - d_min_data) / macro_cycles;
logger.info("Performing {} macro cycles with a dmin step of {:.3f}",
macro_cycles,
d_min_step);
for (int i = 0; i < macro_cycles; ++i) {
d_steps.push_back(d_min - (i + 1) * d_min_step);
}
for (int i = 0; i < macro_cycles; ++i) {
double d_min = d_steps[i];
logger.info("Performing macro cycle {} with d_min={:.3f}", i + 1, d_min);
results = xyz_to_rlp(
xyzobs_px, expt.detector().panels()[0], expt.beam(), scan, gonio);

// Make a selection on dmin and rotation angle like dials
std::vector<bool> selection(results.rlp.extent(0), true);
double osc_trim_limit = scan.get_oscillation()[0] + 360.0;
for (int i = 0; i < results.rlp.extent(0); ++i) {
Eigen::Map<Vector3d> rlp_i(&results.rlp(i, 0));
if (1.0 / rlp_i.norm() <= d_min) {
selection[i] = false;
} else if (results.xyzobs_mm(i, 2) * RAD2DEG > osc_trim_limit) {
selection[i] = false;
}
}
ReflectionTable reflections;
reflections.add_column(std::string("flags"), flags.size(), 1, flags);
reflections.add_column(std::string("xyzobs.mm.value"),
results.xyzobs_mm.extent(0),
3,
results.xyzobs_mm_data);
reflections.add_column(
std::string("s1"), results.s1.extent(0), 3, results.s1_data);
reflections.add_column(
std::string("rlp"), results.rlp.extent(0), 3, results.rlp_data);
reflections.add_column(std::string("entering"), enterings);
const ReflectionTable filtered = reflections.select(selection);

refine_crystal(expt.crystal(),
filtered,
gonio,
expt.beam(),
expt.detector().panels()[0],
scan_width);
}
}

// Save the indexed experiment list.
json elist_out = expt.to_json();
std::string efile_name = "indexed.expt";
Expand Down Expand Up @@ -511,7 +582,7 @@ int main(int argc, char **argv) {
flags_ = strong_reflections.column<std::size_t>("flags");
flag_span = flags_.value();
} else {
auto flag_span = flags_.value();
flag_span = flags_.value();
for (int i = 0; i < assign_results.miller_indices.extent(0); ++i) {
if ((assign_results.miller_indices(i, 0)) != 0
| (assign_results.miller_indices(i, 1) != 0)
Expand All @@ -530,7 +601,7 @@ int main(int argc, char **argv) {
strong_reflections);
// reset the predicted flags as these are observed not predicted
for (int i = 0; i < flag_span.extent(0); ++i) {
flag_span(i, 0) = flag_span(i, 0) & ~predicted_value;
flag_span(i, 0) &= ~predicted_value;
}

// Save the indexed reflection table.
Expand Down
4 changes: 4 additions & 0 deletions baseline/indexer/refine_candidate.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

#include <dx2/beam.hpp>
#include <dx2/crystal.hpp>
#include <dx2/detector.hpp>
Expand Down Expand Up @@ -88,5 +90,7 @@ double refine_indexing_candidate(Crystal &crystal,
// updated during the refinement
crystal.set_A_matrix(target.orientation_parameterisation().get_state()
* target.cell_parameterisation().get_state());
panel.update(target.detector_parameterisation().get_state());
beam.set_s0(target.beam_parameterisation().get_state());
return xyrmsd;
}
58 changes: 58 additions & 0 deletions baseline/indexer/refine_crystal.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <dx2/beam.hpp>
#include <dx2/crystal.hpp>
#include <dx2/detector.hpp>
#include <dx2/experiment.hpp>
#include <dx2/goniometer.hpp>
#include <dx2/reflection.hpp>
#include <vector>

#include "assign_indices.cc"
#include "ffs_logger.hpp"
#include "refine_candidate.cc"
#include "reflection_filter.cc"

void refine_crystal(Crystal &crystal,
ReflectionTable const &obs,
Goniometer gonio,
MonochromaticBeam &beam,
Panel &panel,
double scan_width) {
std::vector<int> miller_indices_data;
int count;
auto preassign = std::chrono::system_clock::now();

// First assign miller indices to the data using the crystal model.
auto rlp_ = obs.column<double>("rlp");
const mdspan_type<double> &rlp = rlp_.value();
auto xyzobs_mm_ = obs.column<double>("xyzobs.mm.value");
const mdspan_type<double> &xyzobs_mm = xyzobs_mm_.value();
assign_indices_results results =
assign_indices_global(crystal.get_A_matrix(), rlp, xyzobs_mm);
auto t2 = std::chrono::system_clock::now();
std::chrono::duration<double> elapsed_time = t2 - preassign;
logger.debug("Time for assigning indices: {:.5f}s", elapsed_time.count());

logger.info("Indexed {}/{} reflections",
results.number_indexed,
results.miller_indices.extent(0));

auto t3 = std::chrono::system_clock::now();
std::chrono::duration<double> elapsed_time1 = t3 - t2;
logger.debug("Time for correct: {:.5f}s", elapsed_time1.count());

// Perform filtering of the data prior to candidate refinement.
ReflectionTable sel_obs = reflection_filter_preevaluation(
obs, results.miller_indices, gonio, crystal, beam, panel, scan_width, 100);
auto postfilter = std::chrono::system_clock::now();
std::chrono::duration<double> elapsed_timefilter = postfilter - t3;
logger.debug("Time for reflection_filter: {:.5f}s", elapsed_timefilter.count());

auto t4 = std::chrono::system_clock::now();
double xyrmsd = refine_indexing_candidate(crystal, gonio, beam, panel, sel_obs);
std::chrono::duration<double> elapsed_time_refine =
std::chrono::system_clock::now() - t4;
auto flags_ = sel_obs.column<std::size_t>("flags");
int n_refl = flags_.value().extent(0);
logger.debug("Time for refinement: {:.5f}s", elapsed_time_refine.count());
logger.info("rmsd_xy {:.5f} on {} reflections", xyrmsd, n_refl);
}
1 change: 1 addition & 0 deletions baseline/indexer/reflection_filter.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once
#include <Eigen/Dense>
#include <dx2/beam.hpp>
#include <dx2/crystal.hpp>
Expand Down
1 change: 1 addition & 0 deletions baseline/refiner/target.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once

#include <cmath>
#include <dx2/beam.hpp>
Expand Down
2 changes: 1 addition & 1 deletion dx2
21 changes: 21 additions & 0 deletions tests/test_baseline_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import subprocess
from pathlib import Path

import h5py
import numpy as np


def test_baseline_indexer(tmp_path, dials_data):
indexer_path: str | Path | None = os.getenv("INDEXER")
Expand Down Expand Up @@ -194,6 +197,16 @@ def test_baseline_indexer(tmp_path, dials_data):
}
assert candidate_crystals == expected_crystals_output

assert (tmp_path / "indexed.refl").is_file()
assert (tmp_path / "indexed.expt").is_file()
with h5py.File(tmp_path / "indexed.refl") as f:
flags = f["/dials/processing/group_0/flags"]
assert len(flags) == 703
n_indexed = np.sum(np.array(flags, dtype=int) == 36)
n_unindexed = np.sum(np.array(flags, dtype=int) == 32)
assert n_indexed == 55
assert n_unindexed == 648


def test_baseline_indexer_c2sum(tmp_path, dials_data):
indexer_path: str | Path | None = os.getenv("INDEXER")
Expand Down Expand Up @@ -367,3 +380,11 @@ def test_baseline_indexer_c2sum(tmp_path, dials_data):
},
}
assert candidate_crystals == expected_crystals_output

with h5py.File(tmp_path / "indexed.refl") as f:
flags = f["/dials/processing/group_0/flags"]
assert len(flags) == 107999
n_indexed = np.sum(np.array(flags, dtype=int) == 36)
n_unindexed = np.sum(np.array(flags, dtype=int) == 32)
assert n_indexed == 107265
assert n_unindexed == 734