Skip to content

Commit

Permalink
- raygen shader working with rw texture and onstant buffers, ray trac…
Browse files Browse the repository at this point in the history
…ing is hanging on TraceRay related to scene
  • Loading branch information
polymonster authored and GBDixonAlex committed Feb 1, 2025
1 parent 2621ef9 commit 39d1537
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 45 deletions.
61 changes: 57 additions & 4 deletions examples/raytraced_triangle/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use gfx::BufferUsage;
use gfx::RaytracingBLASInfo;
use gfx::RaytracingInstanceInfo;
use gfx::RaytracingTLASInfo;
use hotline_rs::gfx::RaytracingTLAS;
use hotline_rs::gfx::Texture;
use hotline_rs::gfx::Pipeline;
use hotline_rs::*;

use gfx::CmdBuf;
Expand All @@ -16,6 +19,11 @@ use os::Window;
use os::win32 as os_platform;
use gfx::d3d12 as gfx_platform;

struct RaytracingViewport {
viewport: [f32; 4],
scissor: [f32; 4]
}

fn main() -> Result<(), hotline_rs::Error> {
let mut app = os_platform::App::create(os::AppInfo {
name: String::from("raytraced_triangle"),
Expand Down Expand Up @@ -110,18 +118,20 @@ fn main() -> Result<(), hotline_rs::Error> {
build_flags: AccelerationStructureBuildFlags::PREFER_FAST_TRACE
})?;

let window_rect = window.get_viewport_rect();

// unordered access rw texture
let rw_info = gfx::TextureInfo {
format: gfx::Format::RGBA8n,
tex_type: gfx::TextureType::Texture2D,
width: 512,
height: 512,
width: window_rect.width as u64,
height: window_rect.height as u64,
depth: 1,
array_layers: 1,
mip_levels: 1,
samples: 1,
usage: gfx::TextureUsage::SHADER_RESOURCE | gfx::TextureUsage::UNORDERED_ACCESS,
initial_state: gfx::ResourceState::ShaderResource,
initial_state: gfx::ResourceState::UnorderedAccess,
};
let raytracing_output = device.create_texture::<u8>(&rw_info, None).unwrap();

Expand All @@ -140,7 +150,36 @@ fn main() -> Result<(), hotline_rs::Error> {

let raytracing_pipeline = pmfx.get_raytracing_pipeline("raytracing")?;
cmd.set_raytracing_pipeline(&raytracing_pipeline.pipeline);
cmd.set_heap(&raytracing_pipeline.pipeline, &device.get_shader_heap()); // TODO: we are here.

// bind rw tex on u0
let uav0 = raytracing_output.get_uav_index().expect("expect raytracing_output to have a uav");
if let Some(u0) = raytracing_pipeline.pipeline.get_pipeline_slot(0, 0, gfx::DescriptorType::UnorderedAccess) {
cmd.set_binding(&raytracing_pipeline.pipeline, device.get_shader_heap(), u0.index, uav0);
}

// set push constants on b0
let border = 0.1;
let aspect = window_rect.width as f32 / window_rect.height as f32;
cmd.push_compute_constants(0, 8, 0, gfx::as_u8_slice(&RaytracingViewport {
viewport: [
-1.0 + border,
-1.0 + border * aspect,
1.0 - border,
1.0 - border * aspect
],
scissor: [
-1.0 + border / aspect,
-1.0 + border,
1.0 - border / aspect,
1.0 - border
]
}));

// bind tlas on t0
let srv0 = tlas.get_srv_index().expect("expect tlas to have an srv");
if let Some(t0) = raytracing_pipeline.pipeline.get_pipeline_slot(0, 0, gfx::DescriptorType::ShaderResource) {
cmd.set_binding(&raytracing_pipeline.pipeline, device.get_shader_heap(), t0.index, srv0);
}

cmd.dispatch_rays(&raytracing_pipeline.sbt, gfx::Size3 {
x: window_rect.width as u32,
Expand All @@ -155,6 +194,13 @@ fn main() -> Result<(), hotline_rs::Error> {
state_after: gfx::ResourceState::CopyDst,
});

cmd.transition_barrier(&gfx::TransitionBarrier {
texture: Some(&raytracing_output),
buffer: None,
state_before: gfx::ResourceState::UnorderedAccess,
state_after: gfx::ResourceState::CopySrc,
});

cmd.copy_texture_region(&swap_chain.get_backbuffer_texture(), 0, 0, 0, 0, &raytracing_output, None);

cmd.transition_barrier(&gfx::TransitionBarrier {
Expand All @@ -164,6 +210,13 @@ fn main() -> Result<(), hotline_rs::Error> {
state_after: gfx::ResourceState::Present,
});

cmd.transition_barrier(&gfx::TransitionBarrier {
texture: Some(&raytracing_output),
buffer: None,
state_before: gfx::ResourceState::CopySrc,
state_after: gfx::ResourceState::UnorderedAccess,
});

cmd.close()?;

// execute command buffer
Expand Down
1 change: 1 addition & 0 deletions shaders/raytracing_example.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void MyRaygenShader()
lerp(g_rayGenCB.viewport.top, g_rayGenCB.viewport.bottom, lerpValues.y),
0.0f);


if (IsInsideViewport(origin.xy, g_rayGenCB.stencil))
{
// Trace the ray.
Expand Down
3 changes: 3 additions & 0 deletions shaders/raytracing_example.pmfx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
geometry: Triangles
}
]
push_constants: [
"g_rayGenCB"
]
}
}
}
13 changes: 10 additions & 3 deletions src/gfx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1036,9 +1036,6 @@ pub trait RaytracingPipeline<D: Device>: Send + Sync {}
/// An opaque shader table binding type..
pub trait RaytracingShaderBindingTable<D: Device>: Send + Sync {}

/// An opaque top level acceleration structure for ray tracing geometry
pub trait RaytracingTLAS<D: Device>: Send + Sync {}

/// An opaque bottom level acceleration structure for ray tracing geometry
pub trait RaytracingBLAS<D: Device>: Send + Sync {}

Expand Down Expand Up @@ -1390,6 +1387,8 @@ pub trait CmdBuf<D: Device>: Send + Sync + Clone {
/// Binds the heap with offset (texture srv, uav) on to the `slot` of a pipeline.
/// this is like a traditional bindful render architecture `cmd.set_binding(pipeline, heap, 0, texture1_id)`
fn set_binding<T: Pipeline>(&self, pipeline: &T, heap: &D::Heap, slot: u32, offset: usize);
// TODO:
fn set_tlas(&self, tlas: &D::RaytracingTLAS);
/// Push a small amount of data into the command buffer for a render pipeline, num values and dest offset are the numbr of 32bit values
fn push_render_constants<T: Sized>(&self, slot: u32, num_values: u32, dest_offset: u32, data: &[T]);
/// Push a small amount of data into the command buffer for a compute pipeline, num values and dest offset are the numbr of 32bit values
Expand Down Expand Up @@ -1500,6 +1499,14 @@ pub trait Texture<D: Device>: Send + Sync {
fn get_shader_heap_id(&self) -> Option<u16>;
}

/// An opaque top level acceleration structure for ray tracing geometry
pub trait RaytracingTLAS<D: Device>: Send + Sync {
/// Return the index to access in a shader (if the resource has msaa this is the resolved view)
fn get_srv_index(&self) -> Option<usize>;
/// Return the id of the shader heap
fn get_shader_heap_id(&self) -> u16;
}

/// An opaque shader heap type, use to create views of resources for binding and access in shaders
pub trait Heap<D: Device>: Send + Sync {
/// Deallocate a resource from the heap and mark space in free list for re-use
Expand Down
91 changes: 53 additions & 38 deletions src/gfx/d3d12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1725,7 +1725,8 @@ pub struct RaytracingBLAS {
}

pub struct RaytracingTLAS {
pub(crate) tlas_buffer: Buffer
pub(crate) tlas_buffer: Buffer,
pub(crate) shader_heap_id: u16
}

impl super::Device for Device {
Expand Down Expand Up @@ -2407,28 +2408,6 @@ impl super::Device for Device {
cbv_index = Some(heap.get_handle_index(&h));
}

if info.usage.contains(super::BufferUsage::SHADER_RESOURCE) {
let h = heap.allocate();
self.device.CreateShaderResourceView(
&buf,
Some(&D3D12_SHADER_RESOURCE_VIEW_DESC {
Format: dxgi_format,
ViewDimension: D3D12_SRV_DIMENSION_BUFFER,
Shader4ComponentMapping: D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
Anonymous: D3D12_SHADER_RESOURCE_VIEW_DESC_0 {
Buffer: D3D12_BUFFER_SRV {
FirstElement: 0,
NumElements: info.num_elements as u32,
StructureByteStride: info.stride as u32,
Flags: D3D12_BUFFER_SRV_FLAG_NONE
}
}
}),
h,
);
srv_index = Some(heap.get_handle_index(&h));
}

// create uav
if info.usage.contains(super::BufferUsage::UNORDERED_ACCESS) {
let h = heap.allocate();
Expand Down Expand Up @@ -2468,6 +2447,29 @@ impl super::Device for Device {
}
}

// srv
if info.usage.contains(super::BufferUsage::SHADER_RESOURCE) {
let h = heap.allocate();
self.device.CreateShaderResourceView(
&buf,
Some(&D3D12_SHADER_RESOURCE_VIEW_DESC {
Format: dxgi_format,
ViewDimension: D3D12_SRV_DIMENSION_BUFFER,
Shader4ComponentMapping: D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
Anonymous: D3D12_SHADER_RESOURCE_VIEW_DESC_0 {
Buffer: D3D12_BUFFER_SRV {
FirstElement: 0,
NumElements: info.num_elements as u32,
StructureByteStride: info.stride as u32,
Flags: D3D12_BUFFER_SRV_FLAG_NONE
}
}
}),
h,
);
srv_index = Some(heap.get_handle_index(&h));
}

Ok(Buffer {
resource: Some(buf),
vbv,
Expand Down Expand Up @@ -3215,7 +3217,7 @@ impl super::Device for Device {

Ok(RaytracingPipeline {
state_object,
root_signature: root_signature.root_signature.clone(), // TODO: we
root_signature: root_signature.root_signature.clone(), // TODO:
lookup: root_signature
})
}
Expand Down Expand Up @@ -3264,12 +3266,13 @@ impl super::Device for Device {
&mut table_buffer,
).expect("hotline_rs::gfx::d3d12: failed to create a shader binding table buffer");

//
if let Some(resource) = table_buffer.as_ref() {
let range = D3D12_RANGE { Begin: 0, End: 0 };
let mut map_data = std::ptr::null_mut();
resource.Map(0, Some(&range), Some(&mut map_data)).expect("hotline_rs::gfx::d3d12: failed to map buffer data for the shader binding table");
std::ptr::copy_nonoverlapping(idents.as_ptr() as *mut _, map_data, buffer_size);
for ident in &idents {
std::ptr::copy_nonoverlapping(*ident as *mut _, map_data, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES as usize);
}
resource.Unmap(0, None);
}

Expand Down Expand Up @@ -3301,8 +3304,8 @@ impl super::Device for Device {
Ok(RaytracingShaderBindingTable {
ray_generation: create_shader_table(vec![raygen_id]),
miss: create_shader_table(miss_ids),
hit_group: create_shader_table(callable_ids),
callable: create_shader_table(hit_group_ids),
hit_group: create_shader_table(hit_group_ids),
callable: create_shader_table(callable_ids),
})
}
}
Expand Down Expand Up @@ -3480,7 +3483,7 @@ impl super::Device for Device {

// UAV buffer the tlas
let tlas_buffer = self.create_buffer::<u8>(&BufferInfo {
usage: super::BufferUsage::UNORDERED_ACCESS | super::BufferUsage::BUFFER_ONLY,
usage: super::BufferUsage::UNORDERED_ACCESS | super::BufferUsage::BUFFER_ONLY | super::BufferUsage::SHADER_RESOURCE,
cpu_access: super::CpuAccessFlags::NONE,
format: super::Format::Unknown,
stride: prebuild_info.ResultDataMaxSizeInBytes as usize,
Expand All @@ -3500,7 +3503,7 @@ impl super::Device for Device {
unsafe {
let cmd = self.command_list.cast::<ID3D12GraphicsCommandList4>()
.expect("hotline_rs::gfx::d3d12: expected ID3D12GraphicsCommandList4 availability to create raytracing blas");
cmd.BuildRaytracingAccelerationStructure(&tlas_desc, None);
cmd.BuildRaytracingAccelerationStructure(&tlas_desc, None);
cmd.Close()?;
}

Expand All @@ -3509,7 +3512,8 @@ impl super::Device for Device {

// return the result
Ok(RaytracingTLAS {
tlas_buffer
tlas_buffer,
shader_heap_id: self.shader_heap.as_ref().map(|x| x.id).unwrap_or(0)
})
}

Expand Down Expand Up @@ -4298,12 +4302,14 @@ impl super::CmdBuf<Device> for CmdBuf {
}
}

fn set_tlas(&self, tlas: &RaytracingTLAS) {
let cmd = self.cmd().cast::<ID3D12GraphicsCommandList4>().unwrap();
unsafe {
cmd.SetComputeRootShaderResourceView(0, tlas.tlas_buffer.resource.as_ref().unwrap().GetGPUVirtualAddress());
}
}

fn dispatch_rays(&self, sbt: &RaytracingShaderBindingTable, numthreads: Size3) {
/*
// Bind the heaps, acceleration structure and dispatch rays.
commandList->SetComputeRootShaderResourceView(GlobalRootSignatureParams::AccelerationStructureSlot, m_topLevelAccelerationStructure->GetGPUVirtualAddress());
*/
unsafe {
let dispatch_desc = D3D12_DISPATCH_RAYS_DESC {
RayGenerationShaderRecord: D3D12_GPU_VIRTUAL_ADDRESS_RANGE {
Expand Down Expand Up @@ -4742,6 +4748,16 @@ impl super::Texture<Device> for Texture {
}
}

impl super::RaytracingTLAS<Device> for RaytracingTLAS {
fn get_srv_index(&self) -> Option<usize> {
self.tlas_buffer.srv_index
}

fn get_shader_heap_id(&self) -> u16 {
self.shader_heap_id
}
}

impl super::Pipeline for RenderPipeline {
fn get_pipeline_slot(&self, register: u32, space: u32, descriptor_type: DescriptorType) -> Option<&super::PipelineSlotInfo> {
let h = get_binding_descriptor_hash(register, space, descriptor_type);
Expand Down Expand Up @@ -4976,5 +4992,4 @@ impl super::ComputePipeline<Device> for ComputePipeline {}
impl super::RaytracingPipeline<Device> for RaytracingPipeline {}
impl super::CommandSignature<Device> for CommandSignature {}
impl super::RaytracingShaderBindingTable<Device> for RaytracingShaderBindingTable {}
impl super::RaytracingBLAS<Device> for RaytracingBLAS {}
impl super::RaytracingTLAS<Device> for RaytracingTLAS {}
impl super::RaytracingBLAS<Device> for RaytracingBLAS {}

0 comments on commit 39d1537

Please sign in to comment.