Skip to content

Commit c00148c

Browse files
committed
utility routine to copy dense tensors
1 parent b1c8bd6 commit c00148c

1 file changed

Lines changed: 25 additions & 2 deletions

File tree

src/tamm/tamm_utils.hpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,8 +1615,11 @@ template<typename TensorType>
16151615
void from_block_cyclic_tensor(Tensor<TensorType> bc_tensor, Tensor<TensorType> tensor,
16161616
bool is_bc = true) {
16171617
const auto ndims = bc_tensor.num_modes();
1618-
EXPECTS(ndims == 2);
1619-
if(is_bc) EXPECTS(bc_tensor.is_block_cyclic());
1618+
1619+
if(is_bc) {
1620+
EXPECTS(ndims == 2);
1621+
EXPECTS(bc_tensor.is_block_cyclic());
1622+
}
16201623
EXPECTS(bc_tensor.kind() == TensorBase::TensorKind::dense);
16211624
EXPECTS(bc_tensor.distribution().kind() == DistributionKind::dense);
16221625

@@ -1660,6 +1663,26 @@ void from_block_cyclic_tensor(Tensor<TensorType> bc_tensor, Tensor<TensorType> t
16601663
block_for(ec, tensor(), tamm_bc_lambda);
16611664
}
16621665

1666+
// tamm does not support set, add, mult ops for dense tensors
1667+
template<typename TensorType>
1668+
void copy_dense_tensor(Tensor<TensorType> stensor, Tensor<TensorType> dtensor) {
1669+
EXPECTS(stensor.kind() == TensorBase::TensorKind::dense);
1670+
EXPECTS(stensor.distribution().kind() == DistributionKind::dense);
1671+
EXPECTS(dtensor.kind() == TensorBase::TensorKind::dense);
1672+
EXPECTS(dtensor.distribution().kind() == DistributionKind::dense);
1673+
1674+
// stensor might be on a smaller process group
1675+
ExecutionContext& ec = get_ec(stensor());
1676+
1677+
auto tamm_bc_lambda = [&](const IndexVector& blockid) {
1678+
std::vector<TensorType> buffer(stensor.block_size(blockid));
1679+
stensor.get(blockid, buffer);
1680+
dtensor.put(blockid, buffer);
1681+
};
1682+
1683+
block_for(ec, dtensor(), tamm_bc_lambda);
1684+
}
1685+
16631686
// convert dense tamm tensor to regular tamm tensor
16641687
template<typename TensorType>
16651688
void from_dense_tensor(Tensor<TensorType> d_tensor, Tensor<TensorType> tensor) {

0 commit comments

Comments
 (0)