Update in preparation for v0.2.
pichuan committed Jan 15, 2022
1 parent 4877d69 commit 0e92898
Showing 28 changed files with 304 additions and 156 deletions.
12 changes: 9 additions & 3 deletions Dockerfile
@@ -1,9 +1,11 @@
# Build with:
# sudo docker build -t deepconsensus .
# For GPU:
# sudo docker build --build-arg build_gpu=true --build-arg FROM_IMAGE=nvidia/cuda:11.3.0-cudnn8-runtime -t deepconsensus_gpu .


# <internal>
# https://blog.realkinetic.com/building-minimal-docker-containers-for-python-applications-37d0272c52f3

ARG FROM_IMAGE=continuumio/miniconda3

FROM continuumio/miniconda3 as conda_setup
RUN conda config --add channels defaults && \
@@ -29,9 +31,13 @@ RUN conda create -n bio \
bioconda::samtools=1.10 \
bioconda::pyfastx=0.8.4 \
&& conda clean -a
RUN wget https://github.com/PacificBiosciences/align-clr-to-ccs/releases/download/0.0.3/actc && \
RUN wget https://github.com/PacificBiosciences/align-clr-to-ccs/releases/download/0.1.0/actc && \
chmod +x actc && \
mv actc /opt/conda/bin/

FROM ${FROM_IMAGE} as builder
COPY --from=conda_setup /opt/conda /opt/conda

ENV PATH=/opt/conda/envs/bio/bin:/opt/conda/bin:"${PATH}"
ENV LD_LIBRARY_PATH=/opt/conda/envs/bio/lib:/opt/mytools/lib/x86_64-linux-gnu:"${LD_LIBRARY_PATH}"

33 changes: 27 additions & 6 deletions README.md
@@ -9,8 +9,34 @@ Biosciences (PacBio) Circular Consensus Sequencing (CCS) data.

### From pip package

If you're on a GPU machine:

```bash
pip install deepconsensus[gpu]==0.2.0
# To make sure the `deepconsensus` CLI works, set the PATH:
export PATH="/home/${USER}/.local/bin:${PATH}"
```

If you're on a CPU machine:

```bash
pip install deepconsensus==0.1.0
pip install deepconsensus[cpu]==0.2.0
# To make sure the `deepconsensus` CLI works, set the PATH:
export PATH="/home/${USER}/.local/bin:${PATH}"
```
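
As an extra check beyond the CLI, the package should also be importable from Python once installed (this assumes only the pip install above):

```python
# Quick post-install sanity check.
import deepconsensus
```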

### From Docker image

For GPU:

```bash
sudo docker pull google/deepconsensus:0.2.0-gpu
```

For CPU:

```bash
sudo docker pull google/deepconsensus:0.2.0
```

### From source
Expand Down Expand Up @@ -46,11 +72,6 @@ place of the CCS reads for downstream analyses.
See the [quick start](https://github.com/google/deepconsensus/blob/main/docs/quick_start.md)
for an example of inputs and outputs.

NOTE: This initial release of DeepConsensus (v0.1) is not yet optimized for
speed, and only runs on CPUs. We anticipate this version to be too slow for many
uses. We are now prioritizing speed improvements, which we anticipate can
achieve acceptable runtimes.

## How to cite

If you are using DeepConsensus in your work, please cite:
6 changes: 3 additions & 3 deletions deepconsensus/inference/quick_inference.py
@@ -221,13 +221,13 @@ def _process_input_helper(
tf.TensorSpec(
shape=(options.example_height, model_params.max_length,
model_params.num_channels),
dtype=tf.int64),
dtype=dc_constants.TF_DATA_TYPE),
'subreads/num_passes':
tf.TensorSpec(shape=(), dtype=tf.int64),
tf.TensorSpec(shape=(), dtype=tf.int32),
'name':
tf.TensorSpec(shape=(), dtype=tf.string),
'window_pos':
tf.TensorSpec(shape=(), dtype=tf.int64),
tf.TensorSpec(shape=(), dtype=tf.int32),
})
dataset = dataset.map(map_func=_process_input_helper)
dataset = dataset.batch(batch_size=options.batch_size, drop_remainder=False)
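
These dtype fixes only work because the generator feeding the dataset yields tensors that match the declared signature. A minimal, self-contained sketch of that pattern (shapes, names, and values here are illustrative assumptions, not the actual DeepConsensus signature):

```python
import tensorflow as tf

def gen():
  # Yields one dict per window; dtypes must match the output_signature below,
  # otherwise tf.data raises an error when the element is produced.
  yield {
      'subreads': tf.zeros([85, 120, 7], dtype=tf.float32),
      'subreads/num_passes': tf.constant(20, dtype=tf.int32),
      'name': tf.constant('m64011_181218_235052/8389356/ccs'),
      'window_pos': tf.constant(0, dtype=tf.int32),
  }

dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'subreads': tf.TensorSpec(shape=(85, 120, 7), dtype=tf.float32),
        'subreads/num_passes': tf.TensorSpec(shape=(), dtype=tf.int32),
        'name': tf.TensorSpec(shape=(), dtype=tf.string),
        'window_pos': tf.TensorSpec(shape=(), dtype=tf.int32),
    })
dataset = dataset.batch(batch_size=2, drop_remainder=False)
```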
2 changes: 1 addition & 1 deletion deepconsensus/inference/quick_inference_test.py
@@ -74,7 +74,7 @@ def test_end_to_end(self, subreads, fasta, expected_lengths):
output_lengths.append(len(record.sequence))
count += 1
self.assertEqual(count, 2)
# <internal>
# TODO: Figure out why lengths are not deterministic.
# Not deterministic, might be due to the test model used since other runs
# with the release model have been deterministic so far.
# self.assertEqual(expected_lengths, output_lengths)
4 changes: 2 additions & 2 deletions deepconsensus/models/data_providers_test.py
@@ -47,11 +47,11 @@ def get_test_dataset(inference: bool) -> Tuple[str, Dict[str, Any]]:
if inference:
dataset_path = 'human_1m/tf_examples/inference/*.tfrecord.gz'
summary_json = 'human_1m/tf_examples/summary/summary.inference.json'
size_key = 'n_inference_examples'
size_key = 'n_examples_inference'
else:
dataset_path = 'human_1m/tf_examples/train/*.tfrecord.gz'
summary_json = 'human_1m/tf_examples/summary/summary.training.json'
size_key = 'n_train_examples'
size_key = 'n_examples_train'
file_pattern = test_utils.deepconsensus_testdata(dataset_path)
summary_json_path = test_utils.deepconsensus_testdata(summary_json)
summary = json.load(tf.io.gfile.GFile(summary_json_path))
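
For reference, a tiny sketch of how the renamed keys are consumed (the JSON contents below are made up for illustration; only the key names come from the code above):

```python
import json

summary_json = ('{"n_examples_train": 253, "n_examples_eval": 253, '
                '"n_examples_inference": 253}')
summary = json.loads(summary_json)

size_key = 'n_examples_train'  # previously 'n_train_examples'
print(summary[size_key])       # -> 253
```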
8 changes: 4 additions & 4 deletions deepconsensus/models/losses_and_metrics.py
@@ -160,8 +160,8 @@ def xentropy_ins_cost_fn(y_pred: tf.Tensor, eps=1e-7) -> tf.Tensor:
class AlignmentLoss(tf.keras.losses.Loss):
r"""Implements a differentiable alignment loss for DeepConsensus.
#<internal>
#<internal>
#TODO: support for from_logits argument.
#TODO: support for annealing schedules (depending on DC API?).
Attributes:
subs_cost_fn: A (batched) function $\Delta^{B \times L_1} \times \Delta^{B
@@ -447,7 +447,7 @@ def banded_alignment(self, subs_costs, ins_costs, del_cost, seq_lens, inf,
input_band = self.weave_band(values, inf)
subs_band = self.weave_band(subs_costs, inf)
ins_costs_pad = tf.pad(ins_costs, [[0, 0], [1, 0]], constant_values=0.)
# <internal>
# TODO: uphere
insert_expand = tf.tile(
ins_costs_pad[:, tf.newaxis, :], multiples=[1, len_1 + 1, 1])
insert_band = self.weave_band(insert_expand, inf)
@@ -492,7 +492,7 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
dtype = y_pred.dtype
# Defines an appropriate large positive float to represent "infinity".
# inf = tf.dtypes.float16.max if dtype == tf.dtypes.float16 else 1e9
inf = tf.convert_to_tensor(1e9, dtype) # <internal>
inf = tf.convert_to_tensor(1e9, dtype) # TODO: float16 support?

# Removes internal gaps, computes length excl. pad and converts to one-hot.
y_true, seq_lens = self.preprocess_y_true(y_true)
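
A quick standalone illustration of why the float16 TODO above matters (not part of the loss itself): 1e9 overflows to infinity in half precision, which is why a dtype-dependent "infinity" such as `tf.dtypes.float16.max` is the eventual goal.

```python
import tensorflow as tf

# 1e9 is far above the float16 maximum (65504), so casting overflows to inf,
# which can propagate NaNs through the alignment dynamic program.
print(tf.cast(1e9, tf.float16))   # tf.Tensor(inf, shape=(), dtype=float16)
print(tf.dtypes.float16.max)      # 65504.0
```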
2 changes: 1 addition & 1 deletion deepconsensus/models/losses_and_metrics_test.py
@@ -364,7 +364,7 @@ class AlignmentLossTest(parameterized.TestCase):
loss_reg=None,
expected_loss=64.472, # 4*log(eps), with eps = 1e-7
width=None),
# <internal>
# TODO: include test cases for soft alignment.
dict(
testcase_name='with band, identical sequences',
sequences=(['TTAGGC', 'AGCTGG'], ['TTAGGC', 'AGCTGG']),
6 changes: 3 additions & 3 deletions deepconsensus/models/model_configs.py
@@ -107,7 +107,7 @@ def _set_base_transformer_hparams(params):
params.add_pos_encoding = True
params.use_relative_pos_enc = True
# Num heads should be divisible by hidden size. This value should be tuned for
# the production setting. <internal>
# the production setting. TODO: update this parameter after
# tuning.
params.num_heads = 2
params.layer_norm = False
@@ -166,8 +166,8 @@ def _set_test_data_hparams(params):
params.test_path = params.train_path
params.inference_path = os.path.join(
curr_dir, '../testdata/human_1m/tf_examples/inference/*')
params.n_train_examples = 253
params.n_eval_examples = 253
params.n_examples_train = 253
params.n_examples_eval = 253
params.max_passes = 20

# The test dataset uniquely sets these model-level parameters because the test
7 changes: 3 additions & 4 deletions deepconsensus/models/model_inference.py
@@ -76,7 +76,7 @@ def run_inference(out_dir: str, params: ml_collections.ConfigDict,
random.seed(params.seed)
tf.random.set_seed(params.seed)

# <internal>
# TODO: multiple GPUs don't provide a speedup for inference. We
# may need to explicitly distribute data to the workers to see speedup.
strategy = tf.distribute.MirroredStrategy()

@@ -100,11 +100,10 @@

def main(unused_args=None):
if not FLAGS.params:
params = model_utils.read_params_from_json(
checkpoint_path=FLAGS.checkpoint_path)
params = model_utils.read_params_from_json(checkpoint_path=FLAGS.checkpoint)
else:
params = FLAGS.params
run_inference(FLAGS.out_dir, params, FLAGS.checkpoint_path, FLAGS.tpu,
run_inference(FLAGS.out_dir, params, FLAGS.checkpoint, FLAGS.tpu,
FLAGS.tpu_topology, FLAGS.limit)


4 changes: 2 additions & 2 deletions deepconsensus/models/model_train_custom_loop.py
@@ -109,8 +109,8 @@ def get_datasets(
def get_step_counts(params: ml_collections.ConfigDict) -> Tuple[int, int]:
"""Returns the steps for training and evaluation."""
if params.limit <= 0:
steps_per_epoch = params.n_train_examples // params.batch_size
steps_per_eval = params.n_eval_examples // params.batch_size
steps_per_epoch = params.n_examples_train // params.batch_size
steps_per_eval = params.n_examples_eval // params.batch_size
else:
# When `params.limit` is set, use it to determine epoch size.
steps_per_epoch = max(1, params.limit // params.batch_size)
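
The renamed fields plug into the same arithmetic as before; a worked example using the values from the test below (not a real training config):

```python
# Mirrors get_step_counts() above.
n_examples_train, n_examples_eval, batch_size = 1000, 100, 10

steps_per_epoch = n_examples_train // batch_size  # 100
steps_per_eval = n_examples_eval // batch_size    # 10

# With params.limit set (e.g. limit=100), the limit drives the epoch size:
limit = 100
steps_per_epoch = max(1, limit // batch_size)     # 10
```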
14 changes: 7 additions & 7 deletions deepconsensus/models/model_train_custom_loop_test.py
@@ -81,25 +81,25 @@ class GetStepCountsTest(parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='simple',
n_train_examples=1000,
n_eval_examples=100,
n_examples_train=1000,
n_examples_eval=100,
batch_size=10,
limit=-1,
expected_step_counts=(100, 10)),
dict(
testcase_name='with_limit',
n_train_examples=1000,
n_eval_examples=100,
n_examples_train=1000,
n_examples_eval=100,
batch_size=10,
limit=100,
expected_step_counts=(10, 10)),
)
def test_get_step_counts(self, n_train_examples, n_eval_examples, batch_size,
def test_get_step_counts(self, n_examples_train, n_examples_eval, batch_size,
limit, expected_step_counts):
params = model_configs.get_config('fc+test')
with params.unlocked():
params.n_train_examples = n_train_examples
params.n_eval_examples = n_eval_examples
params.n_examples_train = n_examples_train
params.n_examples_eval = n_examples_eval
params.limit = limit
params.batch_size = batch_size

2 changes: 1 addition & 1 deletion deepconsensus/models/networks.py
@@ -356,7 +356,7 @@ def encode(self, inputs: tf.Tensor, attention_bias: tf.Tensor,
tf.cast(inputs[:, :, i], tf.int32))
embedded_inputs.append(embedded)

# <internal>
# TODO: experiment with computing a weighted average using snr as
# weights to aggregate subread-level embeddings (instead of concatenating).
if self.params.use_sn:
# The last four elements in the last dimension in the inputs tensor
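
As context for the TODO above, here is one way the SNR-weighted aggregation could look; this is a sketch of the idea only, with made-up shapes, and is not code from this commit:

```python
import tensorflow as tf

# (batch, subreads, length, dim) per-subread embeddings and (batch, subreads)
# signal-to-noise values; aggregate with an SNR-weighted average instead of
# concatenating along the feature axis.
embeddings = tf.random.normal([2, 20, 120, 64])
snr = tf.random.uniform([2, 20])

weights = tf.nn.softmax(snr, axis=1)[:, :, tf.newaxis, tf.newaxis]
aggregated = tf.reduce_sum(weights * embeddings, axis=1)  # (batch, length, dim)
```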
4 changes: 2 additions & 2 deletions deepconsensus/models/networks_test.py
@@ -113,8 +113,8 @@ def test_predict_and_model_fn_equal(self, config_name, inference):
input_example = get_input_example(config, inference=inference)
softmax_output_predict = model.predict(input_example)
softmax_output = model(input_example, training=False).numpy()
self.assertTrue(np.array_equal(softmax_output_predict, softmax_output))

self.assertTrue(
np.allclose(softmax_output_predict, softmax_output, rtol=1e-05))

if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion deepconsensus/postprocess/stitch_utils.py
@@ -47,7 +47,7 @@ def get_full_sequence(deepconsensus_outputs: Iterable[DCModelOutput],
example_width: int,
fill_n: bool = False):
"""Stitch together windows of predictions into a full sequence."""
# <internal>
# TODO: Check if sorting is still necessary.
sorted_deepconsensus_outputs = sorted(
deepconsensus_outputs, key=lambda dc: dc.window_pos)
# Build up the full sequence from the sorted windows.
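
A toy illustration (hypothetical data) of why the sort matters before stitching: windows can arrive out of order, and concatenation only makes sense once they are ordered by window position.

```python
# Each tuple is (window_pos, predicted_sequence) for a made-up read.
windows = [(8, 'GGTT'), (0, 'ACGTACGT'), (12, 'TTAA')]
full_sequence = ''.join(seq for _, seq in sorted(windows, key=lambda w: w[0]))
print(full_sequence)  # -> 'ACGTACGTGGTTTTAA'
```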
2 changes: 1 addition & 1 deletion deepconsensus/preprocess/preprocess.py
@@ -74,7 +74,7 @@
'The output filename must end in .tfrecord.gz'))
flags.DEFINE_string('truth_to_ccs', None, 'Input truth bam aligned to ccs.')
flags.DEFINE_string('truth_bed', None, 'Input truth bedfile.')
# <internal>
# TODO TODO
flags.DEFINE_string('truth_split', None,
'Input file defining train/eval/test splits.')
flags.DEFINE_integer(
31 changes: 19 additions & 12 deletions deepconsensus/preprocess/utils.py
@@ -488,7 +488,6 @@ def iter_examples(self) -> 'DcExample':
if start_pos > self.ccs_width:
break
if window.is_empty:
logging.warning('window at %d has no ccs alignment.', start_pos)
self.counter['n_examples_no_ccs_idx'] += 1
continue
# If the label extends beyond width + padding,
@@ -595,15 +594,15 @@ def __repr__(self):
end = preview.ccs.ccs_bounds.stop
output = ''
output += (f'{self.name} CCS({start}-{end}) {self.label_coords}'.strip() +
f'\n{"-"*(preview.width+21)}\n')
f'\n{"-"*(preview.width+24)}\n')
for subread in self.subreads:
subread_range = subread.name.split('/')[2]
output += f'{subread_range:<20} >{str(subread)}\n'
output += f'{"CCS":<20} >{str(preview.ccs)}\n'
output += f'{subread_range:<20} {subread.strand} >{str(subread)}\n'
output += f'{"CCS":<22} >{str(preview.ccs)}\n'

if self.is_training:
label = str(self.label)
output += f'{"Label":<20} >{label}\n'
output += f'{"Label":<22} >{label}\n'
return output


@@ -775,6 +774,7 @@ def expand_clip_indent(read: pysam.AlignedSegment,
* Expand sequence by placing gaps where deletions are present in alignment.
* Remove bases that are part of soft-clips.
* Indent alignment if start position is > 0.
* Reverse ip/pw values when the strand is reverse.
Args:
read: a pysam aligned segment representing a subread, ccs, or label aln.
@@ -798,11 +798,22 @@
# Fill read objects based on aligned read idx positions.
new_seq[read_idx >= 0] = list(read.seq)

if read.is_reverse:
strand = dc_constants.Strand.REVERSE
else:
strand = dc_constants.Strand.FORWARD

# pw/ip values are never set for labels.
# truth_range is used to test if we are working with a label Read.
if not truth_range:
new_pw[read_idx >= 0] = read.get_tag('pw')
new_ip[read_idx >= 0] = read.get_tag('ip')
# Reverse ip/pw values if the strand is reversed.
pw_vals = read.get_tag('pw')
ip_vals = read.get_tag('ip')
if strand == dc_constants.Strand.REVERSE:
pw_vals = pw_vals[::-1]
ip_vals = ip_vals[::-1]
new_pw[read_idx >= 0] = pw_vals
new_ip[read_idx >= 0] = ip_vals
sn = np.array(read.get_tag('sn'))
else:
sn = np.empty(0)
@@ -812,16 +823,12 @@
new_cigar = np.fromiter(cigar_seq, dtype=np.uint8)
# Filter hard_clip from cigar.
new_cigar = new_cigar[new_cigar != dc_constants.PYSAM_CHARD_CLIP]
if read.is_reverse:
strand = dc_constants.Strand.REVERSE
else:
strand = dc_constants.Strand.FORWARD

# Trim sequence if it is soft-padded.
if np.sum(new_cigar == dc_constants.PYSAM_CSOFT_CLIP) > 0:
new_seq[new_cigar ==
dc_constants.PYSAM_CSOFT_CLIP] = dc_constants.GAP_OR_PAD
# <internal>
# TODO: binary search ignoring -1 vals here.
qstart = np.where(read_idx == read.query_alignment_start)[0][0]
qend = np.where(read_idx == read.query_alignment_end - 1)[0][0] + 1
# Trim soft-padded segments from truth regions.
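
The TODO above suggests replacing the linear `np.where` scans with a binary search over the non-gap entries of `read_idx`; a standalone sketch of that idea (an assumption, not this commit's implementation):

```python
import numpy as np

# read_idx holds the query index per expanded position, with -1 at gaps.
read_idx = np.array([-1, 0, 1, -1, 2, 3, 4, -1])
query_alignment_start = 2

non_gap = np.flatnonzero(read_idx >= 0)          # positions of real bases
pos = np.searchsorted(read_idx[non_gap], query_alignment_start)
qstart = non_gap[pos]

assert qstart == np.where(read_idx == query_alignment_start)[0][0]  # both -> 4
```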