Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix wgpu shader object uniform data handling #91

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/wgpu/wgpu-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ Context::~Context()
}
}

DeviceImpl::~DeviceImpl() {}
// Destructor: explicitly empties the shader object layout cache and clears the
// queue reference in the destructor body instead of leaving both to implicit
// member destruction.
// NOTE(review): presumably these must be released before the remaining members
// (e.g. the underlying device/context) are torn down — confirm against the
// member declaration order in the header.
DeviceImpl::~DeviceImpl()
{
// Overwrite the cache with a value-initialized object of its own type, which
// drops whatever entries it currently holds.
m_shaderObjectLayoutCache = decltype(m_shaderObjectLayoutCache)();
// Null out the queue reference held by the device.
m_queue.setNull();
}

Result DeviceImpl::getNativeDeviceHandles(DeviceNativeHandles* outHandles)
{
Expand Down
76 changes: 26 additions & 50 deletions src/wgpu/wgpu-shader-object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ Result ShaderObjectImpl::init(IDevice* device, ShaderObjectLayoutImpl* layout)

m_layout = layout;

m_constantBufferTransientHeap = nullptr;
m_constantBufferTransientHeapVersion = 0;
m_isConstantBufferDirty = true;

// If the layout tells us that there is any uniform data,
Expand Down Expand Up @@ -193,21 +191,15 @@ Result ShaderObjectImpl::init(IDevice* device, ShaderObjectLayoutImpl* layout)
return SLANG_OK;
}

Result ShaderObjectImpl::_writeOrdinaryData(
PassEncoderImpl* encoder,
IBuffer* buffer,
Offset offset,
Size destSize,
ShaderObjectLayoutImpl* specializedLayout
)
Result ShaderObjectImpl::_writeOrdinaryData(uint8_t* destData, Size destSize, ShaderObjectLayoutImpl* specializedLayout)
{
auto src = m_data.getBuffer();
// TODO: Change size_t to Count?
auto srcSize = size_t(m_data.getCount());

SLANG_RHI_ASSERT(srcSize <= destSize);

encoder->uploadBufferDataImpl(buffer, offset, srcSize, src);
std::memcpy(destData, src, srcSize);

// In the case where this object has any sub-objects of
// existential/interface type, we need to recurse on those objects
Expand Down Expand Up @@ -284,13 +276,7 @@ Result ShaderObjectImpl::_writeOrdinaryData(

auto subObjectOffset = subObjectRangePendingDataOffset + i * subObjectRangePendingDataStride;

subObject->_writeOrdinaryData(
encoder,
buffer,
offset + subObjectOffset,
destSize - subObjectOffset,
subObjectLayout
);
subObject->_writeOrdinaryData(destData + subObjectOffset, destSize - subObjectOffset, subObjectLayout);
}
}

Expand Down Expand Up @@ -385,49 +371,40 @@ void ShaderObjectImpl::writeSamplerDescriptor(
}
}

bool ShaderObjectImpl::shouldAllocateConstantBuffer(TransientResourceHeapImpl* transientHeap)
{
return m_isConstantBufferDirty || m_constantBufferTransientHeap != transientHeap ||
m_constantBufferTransientHeapVersion != transientHeap->getVersion();
}

Result ShaderObjectImpl::_ensureOrdinaryDataBufferCreatedIfNeeded(
PassEncoderImpl* encoder,
ShaderObjectLayoutImpl* specializedLayout
)
{
// If data has been changed since last allocation/filling of constant buffer,
// we will need to allocate a new one.
//
if (!shouldAllocateConstantBuffer(encoder->m_commandBuffer->m_transientHeap))
{
return SLANG_OK;
}
m_isConstantBufferDirty = false;
m_constantBufferTransientHeap = encoder->m_commandBuffer->m_transientHeap;
m_constantBufferTransientHeapVersion = encoder->m_commandBuffer->m_transientHeap->getVersion();

m_constantBufferSize = specializedLayout->getTotalOrdinaryDataSize();
if (m_constantBufferSize == 0)
{
return SLANG_OK;
}

// Once we have computed how large the buffer should be, we can allocate
// it from the transient resource heap.
//
SLANG_RETURN_ON_FAIL(encoder->m_commandBuffer->m_transientHeap
->allocateConstantBuffer(m_constantBufferSize, m_constantBuffer, m_constantBufferOffset));

// Once the buffer is allocated, we can use `_writeOrdinaryData` to fill it in.
// For simplicity we always create a new buffer when the data changes.
//
// Note that `_writeOrdinaryData` is potentially recursive in the case
// where this object contains interface/existential-type fields, so we
// don't need or want to inline it into this call site.
//
SLANG_RETURN_ON_FAIL(
_writeOrdinaryData(encoder, m_constantBuffer, m_constantBufferOffset, m_constantBufferSize, specializedLayout)
);
if (m_isConstantBufferDirty)
{
// First, fill in a CPU buffer using `_writeOrdinaryData`.
//
// Note that `_writeOrdinaryData` is potentially recursive in the case
// where this object contains interface/existential-type fields, so we
// don't need or want to inline it into this call site.
//
m_constantBufferData.resize(m_constantBufferSize);
SLANG_RETURN_ON_FAIL(_writeOrdinaryData(m_constantBufferData.data(), m_constantBufferSize, specializedLayout));
// With all the data collected we create a new constant buffer.
//
BufferDesc bufferDesc = {};
bufferDesc.size = m_constantBufferSize;
bufferDesc.usage = BufferUsage::ConstantBuffer;
bufferDesc.defaultState = ResourceState::ConstantBuffer;
bufferDesc.memoryType = MemoryType::DeviceLocal;
ComPtr<IBuffer> buffer;
SLANG_RETURN_ON_FAIL(encoder->m_device->createBuffer(bufferDesc, nullptr, buffer.writeRef()));
m_constantBuffer = checked_cast<BufferImpl*>(buffer.get());
}

return SLANG_OK;
}
Expand Down Expand Up @@ -685,8 +662,7 @@ Result ShaderObjectImpl::bindOrdinaryDataBufferIfNeeded(
//
if (m_constantBuffer && m_constantBufferSize > 0)
{
auto bufferImpl = checked_cast<BufferImpl*>(m_constantBuffer);
writeBufferDescriptor(context, ioOffset, bufferImpl, m_constantBufferOffset, m_constantBufferSize);
writeBufferDescriptor(context, ioOffset, m_constantBuffer, 0, m_constantBufferSize);
ioOffset.binding++;
}

Expand Down
24 changes: 6 additions & 18 deletions src/wgpu/wgpu-shader-object.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,7 @@ class ShaderObjectImpl : public ShaderObjectBaseImpl<ShaderObjectImpl, ShaderObj

/// Write the uniform/ordinary data of this object into the CPU memory at `destData`,
/// which must have room for at least `destSize` bytes
Result _writeOrdinaryData(
PassEncoderImpl* encoder,
IBuffer* buffer,
Offset offset,
Size destSize,
ShaderObjectLayoutImpl* specializedLayout
);
Result _writeOrdinaryData(uint8_t* destData, Size destSize, ShaderObjectLayoutImpl* specializedLayout);

public:
/// Write a single descriptor using the Vulkan API
Expand Down Expand Up @@ -105,8 +99,6 @@ class ShaderObjectImpl : public ShaderObjectBaseImpl<ShaderObjectImpl, ShaderObj
span<RefPtr<SamplerImpl>> samplers
);

bool shouldAllocateConstantBuffer(TransientResourceHeapImpl* transientHeap);

/// Ensure that `m_constantBuffer` has been created, if it is needed
Result _ensureOrdinaryDataBufferCreatedIfNeeded(
PassEncoderImpl* encoder,
Expand Down Expand Up @@ -165,19 +157,15 @@ class ShaderObjectImpl : public ShaderObjectBaseImpl<ShaderObjectImpl, ShaderObj
std::vector<ResourceSlot> m_resources;
std::vector<RefPtr<SamplerImpl>> m_samplers;

// The transient constant buffer that holds the GPU copy of the constant data,
// weak referenced.
IBuffer* m_constantBuffer = nullptr;
// The offset into the transient constant buffer where the constant data starts.
Offset m_constantBufferOffset = 0;
// The size of the constant buffer for this object.
Size m_constantBufferSize = 0;
// The constant buffer containing all the ordinary data for this object.
RefPtr<BufferImpl> m_constantBuffer;
// A CPU memory buffer containing the ordinary data for this object.
std::vector<uint8_t> m_constantBufferData;

/// Dirty bit tracking whether the constant buffer needs to be updated.
bool m_isConstantBufferDirty = true;
/// The transient heap from which the constant buffer is allocated.
TransientResourceHeapImpl* m_constantBufferTransientHeap;
/// The version of the transient heap when the constant buffer is allocated.
uint64_t m_constantBufferTransientHeapVersion;

/// Get the layout of this shader object with specialization arguments considered
///
Expand Down