diff --git a/Cargo.lock b/Cargo.lock index 3ec5fff..b75d8f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1329,7 +1329,7 @@ checksum = "bff34eb29ff4b8a8688bc7299f14fb6b597461ca80fec03ed7d22939ab33e48f" [[package]] name = "bevy_naga_reflect" version = "0.1.0" -source = "git+https://github.com/tychedelia/bevy_naga_reflect#60010545e20027c7ae2ca084b21ce014664ccd36" +source = "git+https://github.com/tychedelia/bevy_naga_reflect#1d6bfcdddaf44e7a3ed2c4a946e6af2ace2f9f44" dependencies = [ "bevy", "naga", @@ -1418,7 +1418,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-time", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2023,9 +2023,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.4" +version = "1.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d2d5991425dfd0785aed03aedcf0b321d61975c9b5b3689c774a2610ae0b51e" +checksum = "0aa83c34e62843d924f905e0f5c866eb1dd6545fc4d719e803d9ba6030371fce" dependencies = [ "arrayref", "arrayvec", @@ -2178,9 +2178,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.60" +version = "1.2.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" +checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" dependencies = [ "find-msvc-tools", "jobserver", @@ -2721,9 +2721,9 @@ checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f" [[package]] name = "data-encoding" -version = "2.10.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" +checksum = "a4ae5f15dda3c708c0ade84bfee31ccab44a3da4f88015ed22f63732abe300c8" [[package]] name = "derive_more" @@ -4156,9 +4156,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.185" +version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" [[package]] name = "libfuzzer-sys" @@ -5752,7 +5752,6 @@ dependencies = [ "objc2 0.6.4", "objc2-app-kit 0.3.2", "processing_core", - "processing_midi", "raw-window-handle", "wasm-bindgen", "wasm-bindgen-futures", @@ -5812,7 +5811,7 @@ checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" [[package]] name = "pyo3" version = "0.28.3" -source = "git+https://github.com/PyO3/pyo3?branch=main#999560aadb5d4d4bdb10670169fc9294663a6313" +source = "git+https://github.com/PyO3/pyo3?branch=main#20781441337e84362bce32e43363a199a6182aab" dependencies = [ "inventory", "libc", @@ -5826,7 +5825,7 @@ dependencies = [ [[package]] name = "pyo3-build-config" version = "0.28.3" -source = "git+https://github.com/PyO3/pyo3?branch=main#999560aadb5d4d4bdb10670169fc9294663a6313" +source = "git+https://github.com/PyO3/pyo3?branch=main#20781441337e84362bce32e43363a199a6182aab" dependencies = [ "target-lexicon", ] @@ -5834,7 +5833,7 @@ dependencies = [ [[package]] name = "pyo3-ffi" version = "0.28.3" -source = "git+https://github.com/PyO3/pyo3?branch=main#999560aadb5d4d4bdb10670169fc9294663a6313" +source = "git+https://github.com/PyO3/pyo3?branch=main#20781441337e84362bce32e43363a199a6182aab" dependencies = [ "libc", "pyo3-build-config", @@ -5843,7 +5842,7 @@ dependencies = [ [[package]] name = "pyo3-introspection" version = "0.28.3" -source = "git+https://github.com/PyO3/pyo3?branch=main#999560aadb5d4d4bdb10670169fc9294663a6313" +source = "git+https://github.com/PyO3/pyo3?branch=main#20781441337e84362bce32e43363a199a6182aab" dependencies = [ "anyhow", "goblin", @@ -5854,7 +5853,7 @@ dependencies = [ [[package]] name = "pyo3-macros" version = "0.28.3" -source = "git+https://github.com/PyO3/pyo3?branch=main#999560aadb5d4d4bdb10670169fc9294663a6313" +source = "git+https://github.com/PyO3/pyo3?branch=main#20781441337e84362bce32e43363a199a6182aab" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -5865,7 +5864,7 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" version = "0.28.3" -source = "git+https://github.com/PyO3/pyo3?branch=main#999560aadb5d4d4bdb10670169fc9294663a6313" +source = "git+https://github.com/PyO3/pyo3?branch=main#20781441337e84362bce32e43363a199a6182aab" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 5919779..41c9f3b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -159,6 +159,10 @@ path = "examples/blend_modes.rs" name = "camera_controllers" path = "examples/camera_controllers.rs" +[[example]] +name = "compute_readback" +path = "examples/compute_readback.rs" + [profile.wasm-release] inherits = "release" opt-level = "z" diff --git a/crates/processing_core/src/error.rs b/crates/processing_core/src/error.rs index 7cc1e9e..2fc9607 100644 --- a/crates/processing_core/src/error.rs +++ b/crates/processing_core/src/error.rs @@ -32,8 +32,8 @@ pub enum ProcessingError { TransformNotFound, #[error("Material not found")] MaterialNotFound, - #[error("Unknown material property: {0}")] - UnknownMaterialProperty(String), + #[error("Unknown shader property: {0}")] + UnknownShaderProperty(String), #[error("GLTF load error: {0}")] GltfLoadError(String), #[error("Webcam not connected")] @@ -46,4 +46,14 @@ pub enum ProcessingError { MidiPortNotFound(usize), #[error("CUDA error: {0}")] CudaError(String), + #[error("Compute shader not found")] + ComputeNotFound, + #[error("Buffer not found")] + BufferNotFound, + #[error("Buffer map error: {0}")] + BufferMapError(String), + #[error("Pipeline compile error: {0}")] + PipelineCompileError(String), + #[error("Pipeline not ready after {0} frames")] + PipelineNotReady(u32), } diff --git a/crates/processing_core/src/lib.rs b/crates/processing_core/src/lib.rs index 9de4d39..2a9118b 100644 --- a/crates/processing_core/src/lib.rs +++ b/crates/processing_core/src/lib.rs @@ -15,7 +15,9 @@ thread_local! { pub fn app_mut(cb: impl FnOnce(&mut App) -> error::Result) -> error::Result { let res = APP.with(|app_cell| { - let mut app_borrow = app_cell.borrow_mut(); + let mut app_borrow = app_cell + .try_borrow_mut() + .map_err(|_| error::ProcessingError::AppAccess)?; let app = app_borrow .as_mut() .ok_or(error::ProcessingError::AppAccess)?; diff --git a/crates/processing_ffi/src/lib.rs b/crates/processing_ffi/src/lib.rs index 7c76491..d2b9d24 100644 --- a/crates/processing_ffi/src/lib.rs +++ b/crates/processing_ffi/src/lib.rs @@ -10,6 +10,12 @@ use crate::color::Color; mod color; mod error; +unsafe fn cstr_to_str<'a>(ptr: *const std::ffi::c_char) -> Result<&'a str, ProcessingError> { + unsafe { std::ffi::CStr::from_ptr(ptr) } + .to_str() + .map_err(|_| ProcessingError::InvalidArgument("non-UTF8 C string".to_string())) +} + /// Initialize libProcessing. /// /// SAFETY: @@ -1776,12 +1782,12 @@ pub unsafe extern "C" fn processing_material_set_float( value: f32, ) { error::clear_error(); - let name = unsafe { std::ffi::CStr::from_ptr(name) }.to_str().unwrap(); error::check(|| { + let name = unsafe { cstr_to_str(name) }?; material_set( Entity::from_bits(mat_id), name, - material::MaterialValue::Float(value), + shader_value::ShaderValue::Float(value), ) }); } @@ -1800,12 +1806,12 @@ pub unsafe extern "C" fn processing_material_set_float4( a: f32, ) { error::clear_error(); - let name = unsafe { std::ffi::CStr::from_ptr(name) }.to_str().unwrap(); error::check(|| { + let name = unsafe { cstr_to_str(name) }?; material_set( Entity::from_bits(mat_id), name, - material::MaterialValue::Float4([r, g, b, a]), + shader_value::ShaderValue::Float4([r, g, b, a]), ) }); } @@ -1824,6 +1830,172 @@ pub extern "C" fn processing_material(window_id: u64, mat_id: u64) { error::check(|| graphics_record_command(window_entity, DrawCommand::Material(mat_entity))); } +// Shader + +/// Create a shader from WGSL source. +/// +/// # Safety +/// - `source` must be non-null +#[unsafe(no_mangle)] +pub unsafe extern "C" fn processing_shader_create(source: *const std::ffi::c_char) -> u64 { + error::clear_error(); + error::check(|| { + let source = unsafe { cstr_to_str(source) }?; + shader_create(source) + }) + .map(|e| e.to_bits()) + .unwrap_or(0) +} + +/// Load a shader from a file path. +/// +/// # Safety +/// - `path` must be non-null +#[unsafe(no_mangle)] +pub unsafe extern "C" fn processing_shader_load(path: *const std::ffi::c_char) -> u64 { + error::clear_error(); + error::check(|| { + let path = unsafe { cstr_to_str(path) }?; + shader_load(path) + }) + .map(|e| e.to_bits()) + .unwrap_or(0) +} + +#[unsafe(no_mangle)] +pub extern "C" fn processing_shader_destroy(shader_id: u64) { + error::clear_error(); + error::check(|| shader_destroy(Entity::from_bits(shader_id))); +} + +// Buffer + +#[unsafe(no_mangle)] +pub extern "C" fn processing_buffer_create(size: u64) -> u64 { + error::clear_error(); + error::check(|| buffer_create(size)) + .map(|e| e.to_bits()) + .unwrap_or(0) +} + +/// Create a buffer initialized with data. +/// +/// # Safety +/// - `data` must point to `len` valid bytes +#[unsafe(no_mangle)] +pub unsafe extern "C" fn processing_buffer_create_with_data(data: *const u8, len: u64) -> u64 { + error::clear_error(); + let bytes = unsafe { std::slice::from_raw_parts(data, len as usize) }.to_vec(); + error::check(|| buffer_create_with_data(bytes)) + .map(|e| e.to_bits()) + .unwrap_or(0) +} + +/// Write data to a buffer. +/// +/// # Safety +/// - `data` must point to `len` valid bytes +#[unsafe(no_mangle)] +pub unsafe extern "C" fn processing_buffer_write(buf_id: u64, data: *const u8, len: u64) { + error::clear_error(); + let bytes = unsafe { std::slice::from_raw_parts(data, len as usize) }.to_vec(); + error::check(|| buffer_write(Entity::from_bits(buf_id), bytes)); +} + +/// Returns the byte length of a buffer, or 0 if the buffer does not exist +/// (in which case the error is set). +#[unsafe(no_mangle)] +pub extern "C" fn processing_buffer_size(buf_id: u64) -> u64 { + error::clear_error(); + error::check(|| buffer_size(Entity::from_bits(buf_id))).unwrap_or(0) +} + +/// Read buffer contents into a caller-provided buffer. +/// +/// # Safety +/// - `out` must be valid for writes of `out_len` bytes (may be null if +/// `out_len == 0`, in which case this acts as a size query). +#[unsafe(no_mangle)] +pub unsafe extern "C" fn processing_buffer_read(buf_id: u64, out: *mut u8, out_len: u64) -> u64 { + error::clear_error(); + let Some(data) = error::check(|| buffer_read(Entity::from_bits(buf_id))) else { + return 0; + }; + let needed = data.len() as u64; + if needed <= out_len { + unsafe { std::ptr::copy_nonoverlapping(data.as_ptr(), out, data.len()) }; + } + needed +} + +#[unsafe(no_mangle)] +pub extern "C" fn processing_buffer_destroy(buf_id: u64) { + error::clear_error(); + error::check(|| buffer_destroy(Entity::from_bits(buf_id))); +} + +// Compute + +#[unsafe(no_mangle)] +pub extern "C" fn processing_compute_create(shader_id: u64) -> u64 { + error::clear_error(); + error::check(|| compute_create(Entity::from_bits(shader_id))) + .map(|e| e.to_bits()) + .unwrap_or(0) +} + +/// Set a float property on a compute shader. +/// +/// # Safety +/// - `name` must be non-null +#[unsafe(no_mangle)] +pub unsafe extern "C" fn processing_compute_set_float( + compute_id: u64, + name: *const std::ffi::c_char, + value: f32, +) { + error::clear_error(); + error::check(|| { + let name = unsafe { cstr_to_str(name) }?; + compute_set( + Entity::from_bits(compute_id), + name, + shader_value::ShaderValue::Float(value), + ) + }); +} + +/// # Safety +/// `name` must be a valid null-terminated C string. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn processing_compute_set_buffer( + compute_id: u64, + name: *const std::ffi::c_char, + buf_id: u64, +) { + error::clear_error(); + error::check(|| { + let name = unsafe { cstr_to_str(name) }?; + compute_set( + Entity::from_bits(compute_id), + name, + shader_value::ShaderValue::Buffer(Entity::from_bits(buf_id)), + ) + }); +} + +#[unsafe(no_mangle)] +pub extern "C" fn processing_compute_dispatch(compute_id: u64, x: u32, y: u32, z: u32) { + error::clear_error(); + error::check(|| compute_dispatch(Entity::from_bits(compute_id), x, y, z)); +} + +#[unsafe(no_mangle)] +pub extern "C" fn processing_compute_destroy(compute_id: u64) { + error::clear_error(); + error::check(|| compute_destroy(Entity::from_bits(compute_id))); +} + // Mouse buttons pub const PROCESSING_MOUSE_LEFT: u8 = 0; pub const PROCESSING_MOUSE_MIDDLE: u8 = 1; diff --git a/crates/processing_pyo3/examples/compute.py b/crates/processing_pyo3/examples/compute.py new file mode 100644 index 0000000..26fb5e0 --- /dev/null +++ b/crates/processing_pyo3/examples/compute.py @@ -0,0 +1,106 @@ +import struct + +from mewnala import Graphics, Shader, Compute, Buffer + +g = Graphics.new_offscreen(1, 1, "", None) +g.begin_draw() + +shader = Shader(""" +@group(0) @binding(0) +var output: array; + +@compute @workgroup_size(1) +fn main() { + output[0] = 1u; + output[1] = 2u; + output[2] = 3u; + output[3] = 4u; +} +""") + +buf = Buffer(size=16) +compute = Compute(shader) +compute.set(output=buf) +compute.dispatch(1, 1, 1) + +data = buf.read() +assert isinstance(data, bytes), f"expected bytes, got {type(data)}" +assert list(struct.unpack("<4I", data)) == [1, 2, 3, 4] +print("PASS") + + +buf2 = Buffer(data=[10.0, 20.0, 30.0, 40.0]) +assert len(buf2) == 4 +assert buf2[0] == 10.0 +assert buf2[-1] == 40.0 +assert buf2[1:3] == [20.0, 30.0] + +buf2[2] = 99.0 +assert buf2[2] == 99.0 + +buf2[0:2] = [111.0, 222.0] +assert buf2[0] == 111.0 +assert buf2[1] == 222.0 +print("PASS") + + +double_shader = Shader(""" +@group(0) @binding(0) +var data: array; + +@compute @workgroup_size(4) +fn main(@builtin(global_invocation_id) id: vec3) { + data[id.x] = data[id.x] * 2.0; +} +""") + +buf3 = Buffer(data=[1.0, 2.0, 3.0, 4.0]) +compute3 = Compute(double_shader) +compute3.set(data=buf3) +compute3.dispatch(1, 1, 1) +assert buf3.read() == [2.0, 4.0, 6.0, 8.0] +print("PASS") + + +compute3.dispatch(1, 1, 1) +assert buf3.read() == [4.0, 8.0, 12.0, 16.0] +print("PASS") + + +wg_shader = Shader(""" +@group(0) @binding(0) +var output: array; + +@compute @workgroup_size(4) +fn main(@builtin(global_invocation_id) id: vec3) { + output[id.x] = id.x + 1u; +} +""") + +buf5 = Buffer(size=32) +compute5 = Compute(wg_shader) +compute5.set(output=buf5) +compute5.dispatch(2, 1, 1) +assert list(struct.unpack("<8I", buf5.read())) == [1, 2, 3, 4, 5, 6, 7, 8] +print("PASS") + + +copy_shader = Shader(""" +@group(0) @binding(0) var src: array; +@group(0) @binding(1) var dst: array; + +@compute @workgroup_size(4) +fn main(@builtin(global_invocation_id) id: vec3) { + dst[id.x] = src[id.x] * 10.0; +} +""") + +src_buf = Buffer(data=[1.0, 2.0, 3.0, 4.0]) +dst_buf = Buffer(size=16) +compute6 = Compute(copy_shader) +compute6.set(src=src_buf, dst=dst_buf) +compute6.dispatch(1, 1, 1) +assert list(struct.unpack("<4f", dst_buf.read())) == [10.0, 20.0, 30.0, 40.0] +print("PASS") + +g.end_draw() \ No newline at end of file diff --git a/crates/processing_pyo3/src/compute.rs b/crates/processing_pyo3/src/compute.rs new file mode 100644 index 0000000..83e6f28 --- /dev/null +++ b/crates/processing_pyo3/src/compute.rs @@ -0,0 +1,280 @@ +use bevy::prelude::Entity; +use processing::prelude::*; +use pyo3::{ + exceptions::{PyIndexError, PyRuntimeError, PyTypeError, PyValueError}, + prelude::*, + types::{PyBytes, PyList, PySlice, PySliceIndices}, +}; + +use shader_value::ShaderValue; + +use crate::material::py_to_shader_value; +use crate::shader::Shader; + +#[pyclass(unsendable)] +pub struct Buffer { + pub(crate) entity: Entity, + element_type: Option, + size: u64, +} + +#[pymethods] +impl Buffer { + #[new] + #[pyo3(signature = (size=None, data=None))] + pub fn new(size: Option, data: Option<&Bound<'_, PyAny>>) -> PyResult { + let (entity, size, element_type) = if let Some(data) = data { + let (bytes, element_type) = shader_values_to_bytes(data)?; + let size = bytes.len() as u64; + let entity = buffer_create_with_data(bytes) + .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; + (entity, size, element_type) + } else { + let size = size.unwrap_or(0); + let entity = + buffer_create(size).map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; + (entity, size, None) + }; + Ok(Self { + entity, + element_type, + size, + }) + } + + pub fn __len__(&self) -> usize { + match &self.element_type { + Some(et) => et + .byte_size() + .map(|s| self.size as usize / s) + .unwrap_or(self.size as usize), + None => self.size as usize, + } + } + + pub fn __getitem__(&self, py: Python<'_>, index: &Bound<'_, PyAny>) -> PyResult> { + let Some(ref et) = self.element_type else { + return Err(PyTypeError::new_err("no element type; write values first")); + }; + let elem_size = et.byte_size().unwrap() as u64; + + let read = |i: isize| -> PyResult> { + let bytes = buffer_read_element(self.entity, i as u64 * elem_size, elem_size) + .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; + let sv = et + .read_from_bytes(&bytes) + .ok_or_else(|| PyRuntimeError::new_err("failed to decode element"))?; + shader_value_to_py(py, &sv) + }; + + if let Ok(i) = index.extract::() { + Ok(read(self.normalize_index(i)? as isize)?.into()) + } else if let Ok(slice) = index.cast::() { + let indices = slice.indices(self.__len__() as isize)?; + let values = slice_positions(&indices) + .map(read) + .collect::>>()?; + Ok(PyList::new(py, values)?.into()) + } else { + Err(PyTypeError::new_err("index must be int or slice")) + } + } + + pub fn __setitem__( + &mut self, + index: &Bound<'_, PyAny>, + value: &Bound<'_, PyAny>, + ) -> PyResult<()> { + if let Ok(i) = index.extract::() { + let sv = py_to_shader_value(value)?; + self.check_element_type(&sv)?; + let bytes = sv + .to_bytes() + .ok_or_else(|| PyTypeError::new_err("unsupported value type for buffer"))?; + let elem_size = bytes.len() as u64; + let i = self.normalize_index(i)?; + buffer_write_element(self.entity, i as u64 * elem_size, bytes) + .map_err(|e| PyRuntimeError::new_err(format!("{e}"))) + } else if let Ok(slice) = index.cast::() { + let (src_bytes, element_type) = shader_values_to_bytes(value)?; + let et = element_type + .ok_or_else(|| PyTypeError::new_err("unsupported value type for buffer"))?; + let elem_size = et.byte_size().unwrap() as u64; + self.check_element_type(&et)?; + let indices = slice.indices(self.__len__() as isize)?; + let src_elems = src_bytes.len() as u64 / elem_size; + if indices.slicelength as u64 != src_elems { + return Err(PyValueError::new_err(format!( + "slice length {} does not match value length {}", + indices.slicelength, src_elems + ))); + } + for (pos, chunk) in + slice_positions(&indices).zip(src_bytes.chunks_exact(elem_size as usize)) + { + buffer_write_element(self.entity, pos as u64 * elem_size, chunk.to_vec()) + .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; + } + Ok(()) + } else { + Err(PyTypeError::new_err("index must be int or slice")) + } + } + + pub fn write(&mut self, values: &Bound<'_, PyAny>) -> PyResult<()> { + let (bytes, element_type) = shader_values_to_bytes(values)?; + self.element_type = element_type; + buffer_write(self.entity, bytes).map_err(|e| PyRuntimeError::new_err(format!("{e}"))) + } + + pub fn read<'py>(&mut self, py: Python<'py>) -> PyResult> { + let data = buffer_read(self.entity).map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; + + let Some(ref template) = self.element_type else { + return Ok(PyBytes::new(py, &data).into_any()); + }; + + let elem_size = template + .byte_size() + .ok_or_else(|| PyRuntimeError::new_err("unsupported element type"))?; + + let values = data + .chunks_exact(elem_size) + .map(|chunk| { + let sv = template + .read_from_bytes(chunk) + .ok_or_else(|| PyRuntimeError::new_err("failed to decode bytes"))?; + shader_value_to_py(py, &sv) + }) + .collect::>>()?; + + Ok(PyList::new(py, values)?.into_any()) + } +} + +impl Buffer { + fn check_element_type(&mut self, sv: &ShaderValue) -> PyResult<()> { + match &self.element_type { + Some(existing) if std::mem::discriminant(existing) != std::mem::discriminant(sv) => { + Err(PyTypeError::new_err(format!( + "buffer element type mismatch: expected {existing:?}, got {sv:?}" + ))) + } + Some(_) => Ok(()), + None => { + self.element_type = Some(sv.clone()); + Ok(()) + } + } + } + + fn normalize_index(&self, i: isize) -> PyResult { + let len = self.__len__() as isize; + let i = if i < 0 { len + i } else { i }; + if i < 0 || i >= len { + Err(PyIndexError::new_err("buffer index out of range")) + } else { + Ok(i as usize) + } + } +} + +impl Drop for Buffer { + fn drop(&mut self) { + let _ = buffer_destroy(self.entity); + } +} + +fn slice_positions(indices: &PySliceIndices) -> impl Iterator + use<> { + let PySliceIndices { + start, + step, + slicelength, + .. + } = *indices; + (0..slicelength as isize).map(move |i| start + i * step) +} + +fn shader_values_to_bytes(values: &Bound<'_, PyAny>) -> PyResult<(Vec, Option)> { + let mut bytes = Vec::new(); + let mut element_type: Option = None; + for item in values.try_iter()? { + let sv = py_to_shader_value(&item?)?; + if let Some(ref existing) = element_type + && std::mem::discriminant(existing) != std::mem::discriminant(&sv) + { + return Err(PyTypeError::new_err(format!( + "buffer elements must all share the same type: expected {existing:?}, got {sv:?}" + ))); + } + let b = sv + .to_bytes() + .ok_or_else(|| PyTypeError::new_err("unsupported value type for buffer"))?; + element_type.get_or_insert(sv); + bytes.extend_from_slice(&b); + } + Ok((bytes, element_type)) +} + +fn shader_value_to_py<'py>(py: Python<'py>, sv: &ShaderValue) -> PyResult> { + fn list<'py, T: pyo3::IntoPyObject<'py> + Copy>( + py: Python<'py>, + xs: &[T], + ) -> PyResult> { + Ok(PyList::new(py, xs.iter().copied())?.into_any()) + } + match sv { + ShaderValue::Float(v) => Ok(v.into_pyobject(py)?.into_any()), + ShaderValue::Int(v) => Ok(v.into_pyobject(py)?.into_any()), + ShaderValue::UInt(v) => Ok(v.into_pyobject(py)?.into_any()), + ShaderValue::Float2(v) => list(py, v), + ShaderValue::Float3(v) => list(py, v), + ShaderValue::Float4(v) => list(py, v), + ShaderValue::Int2(v) => list(py, v), + ShaderValue::Int3(v) => list(py, v), + ShaderValue::Int4(v) => list(py, v), + ShaderValue::Mat4(v) => list(py, v), + ShaderValue::Texture(_) | ShaderValue::Buffer(_) => Err(PyRuntimeError::new_err( + "cannot convert Texture/Buffer to Python value", + )), + } +} + +#[pyclass(unsendable)] +pub struct Compute { + pub(crate) entity: Entity, +} + +#[pymethods] +impl Compute { + #[new] + pub fn new(shader: &Shader) -> PyResult { + let entity = + compute_create(shader.entity).map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; + Ok(Self { entity }) + } + + #[pyo3(signature = (**kwargs))] + pub fn set(&self, kwargs: Option<&Bound<'_, pyo3::types::PyDict>>) -> PyResult<()> { + let Some(kwargs) = kwargs else { + return Ok(()); + }; + for (key, value) in kwargs.iter() { + let name: String = key.extract()?; + let value = py_to_shader_value(&value)?; + compute_set(self.entity, &name, value) + .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; + } + Ok(()) + } + + pub fn dispatch(&self, x: u32, y: u32, z: u32) -> PyResult<()> { + compute_dispatch(self.entity, x, y, z).map_err(|e| PyRuntimeError::new_err(format!("{e}"))) + } +} + +impl Drop for Compute { + fn drop(&mut self) { + let _ = compute_destroy(self.entity); + } +} diff --git a/crates/processing_pyo3/src/lib.rs b/crates/processing_pyo3/src/lib.rs index af5324e..930545f 100644 --- a/crates/processing_pyo3/src/lib.rs +++ b/crates/processing_pyo3/src/lib.rs @@ -9,6 +9,7 @@ //! To allow Python users to create a similar experience, we provide module-level //! functions that forward to a singleton Graphics object pub(crate) behind the scenes. pub(crate) mod color; +pub(crate) mod compute; #[cfg(feature = "cuda")] pub(crate) mod cuda; mod glfw; @@ -25,6 +26,7 @@ mod time; #[cfg(feature = "webcam")] mod webcam; +use compute::{Buffer, Compute}; use graphics::{ Geometry, Graphics, Image, Light, PyBlendMode, Topology, get_graphics, get_graphics_mut, }; @@ -320,6 +322,10 @@ fn detect_environment(py: Python<'_>) -> PyResult { mod mewnala { use super::*; + #[pymodule_export] + use super::Buffer; + #[pymodule_export] + use super::Compute; #[pymodule_export] use super::Geometry; #[pymodule_export] diff --git a/crates/processing_pyo3/src/material.rs b/crates/processing_pyo3/src/material.rs index 757d2d9..b7e9bc5 100644 --- a/crates/processing_pyo3/src/material.rs +++ b/crates/processing_pyo3/src/material.rs @@ -3,6 +3,7 @@ use processing::prelude::*; use pyo3::types::PyDict; use pyo3::{exceptions::PyRuntimeError, prelude::*}; +use crate::compute::Buffer; use crate::math::{PyVec2, PyVec3, PyVec4}; use crate::shader::Shader; @@ -11,34 +12,38 @@ pub struct Material { pub(crate) entity: Entity, } -fn py_to_material_value(value: &Bound<'_, PyAny>) -> PyResult { +pub(crate) fn py_to_shader_value(value: &Bound<'_, PyAny>) -> PyResult { if let Ok(v) = value.extract::() { - return Ok(material::MaterialValue::Float(v)); + return Ok(shader_value::ShaderValue::Float(v)); } if let Ok(v) = value.extract::() { - return Ok(material::MaterialValue::Int(v)); + return Ok(shader_value::ShaderValue::Int(v)); } // Accept PyVec types if let Ok(v) = value.extract::>() { - return Ok(material::MaterialValue::Float4(v.0.to_array())); + return Ok(shader_value::ShaderValue::Float4(v.0.to_array())); } if let Ok(v) = value.extract::>() { - return Ok(material::MaterialValue::Float3(v.0.to_array())); + return Ok(shader_value::ShaderValue::Float3(v.0.to_array())); } if let Ok(v) = value.extract::>() { - return Ok(material::MaterialValue::Float2(v.0.to_array())); + return Ok(shader_value::ShaderValue::Float2(v.0.to_array())); + } + + if let Ok(buf) = value.extract::>() { + return Ok(shader_value::ShaderValue::Buffer(buf.entity)); } // Fall back to raw arrays if let Ok(v) = value.extract::<[f32; 4]>() { - return Ok(material::MaterialValue::Float4(v)); + return Ok(shader_value::ShaderValue::Float4(v)); } if let Ok(v) = value.extract::<[f32; 3]>() { - return Ok(material::MaterialValue::Float3(v)); + return Ok(shader_value::ShaderValue::Float3(v)); } if let Ok(v) = value.extract::<[f32; 2]>() { - return Ok(material::MaterialValue::Float2(v)); + return Ok(shader_value::ShaderValue::Float2(v)); } Err(PyRuntimeError::new_err(format!( @@ -63,8 +68,8 @@ impl Material { if let Some(kwargs) = kwargs { for (key, value) in kwargs.iter() { let name: String = key.extract()?; - let mat_value = py_to_material_value(&value)?; - material_set(mat.entity, &name, mat_value) + let value = py_to_shader_value(&value)?; + material_set(mat.entity, &name, value) .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; } } @@ -78,8 +83,8 @@ impl Material { }; for (key, value) in kwargs.iter() { let name: String = key.extract()?; - let mat_value = py_to_material_value(&value)?; - material_set(self.entity, &name, mat_value) + let value = py_to_shader_value(&value)?; + material_set(self.entity, &name, value) .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; } Ok(()) diff --git a/crates/processing_render/Cargo.toml b/crates/processing_render/Cargo.toml index a2097a3..11e25b4 100644 --- a/crates/processing_render/Cargo.toml +++ b/crates/processing_render/Cargo.toml @@ -21,7 +21,6 @@ raw-window-handle = "0.6" half = "2.7" crossbeam-channel = "0.5" processing_core = { workspace = true } -processing_midi = { workspace = true } [build-dependencies] wesl = { workspace = true, features = ["package"] } diff --git a/crates/processing_render/src/compute.rs b/crates/processing_render/src/compute.rs new file mode 100644 index 0000000..4d0a80e --- /dev/null +++ b/crates/processing_render/src/compute.rs @@ -0,0 +1,387 @@ +use std::collections::BTreeSet; + +use bevy::asset::RenderAssetUsages; +use bevy::reflect::PartialReflect; +use bevy::{ + prelude::*, + render::{ + RenderApp, + render_asset::RenderAssets, + render_resource::{ + BindGroupLayoutDescriptor, Buffer as WgpuBuffer, BufferDescriptor, BufferUsages, + CachedComputePipelineId, CachedPipelineState, CommandEncoderDescriptor, + ComputePassDescriptor, ComputePipelineDescriptor, MapMode, PipelineCache, PollType, + }, + renderer::{RenderDevice, RenderQueue}, + storage::{GpuShaderBuffer, ShaderBuffer}, + texture::GpuImage, + }, +}; + +use bevy_naga_reflect::dynamic_shader::DynamicShader; + +use crate::image::Image as PImage; +use crate::material::custom::{Shader, apply_reflect_field, shader_value_to_reflect}; +use crate::shader_value::ShaderValue; +use processing_core::error::{ProcessingError, Result}; + +pub struct ComputePlugin; + +impl Plugin for ComputePlugin { + fn build(&self, app: &mut App) { + app.add_systems(Last, invalidate_rw_buffers); + } +} + +#[derive(Component)] +pub struct Buffer { + pub handle: Handle, + pub readback_buffer: WgpuBuffer, + pub size: u64, + pub synced: bool, + pub bound_rw: bool, +} + +fn readback_buffer(device: &RenderDevice, size: u64) -> WgpuBuffer { + device.create_buffer(&BufferDescriptor { + label: Some("Buffer Readback"), + size, + usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ, + mapped_at_creation: false, + }) +} + +pub fn create_buffer( + In(size): In, + mut commands: Commands, + mut buffers: ResMut>, + render_device: Res, +) -> Entity { + let handle = buffers.add(ShaderBuffer::new( + &vec![0u8; size as usize], + RenderAssetUsages::all(), + )); + commands + .spawn(Buffer { + handle, + readback_buffer: readback_buffer(&render_device, size), + size, + synced: true, + bound_rw: false, + }) + .id() +} + +pub fn create_buffer_with_data( + In(data): In>, + mut commands: Commands, + mut buffers: ResMut>, + render_device: Res, +) -> Entity { + let size = data.len() as u64; + let handle = buffers.add(ShaderBuffer::new(&data, RenderAssetUsages::all())); + commands + .spawn(Buffer { + handle, + readback_buffer: readback_buffer(&render_device, size), + size, + synced: true, + bound_rw: false, + }) + .id() +} + +pub fn write_buffer_cpu( + In((handle, offset, data)): In<(Handle, u64, Vec)>, + mut buffers: ResMut>, +) -> Result<()> { + let mut asset = buffers + .get_mut(&handle) + .ok_or(ProcessingError::BufferNotFound)?; + let dst = asset.data.as_mut().ok_or(ProcessingError::BufferNotFound)?; + let start = offset as usize; + let end = start + data.len(); + dst[start..end].copy_from_slice(&data); + Ok(()) +} + +/// Caller must write bytes back via `get_mut_untracked` to avoid triggering +/// a re-upload. +pub fn read_buffer_gpu( + In((handle, readback_buffer, size)): In<(Handle, WgpuBuffer, u64)>, + gpu_buffers: Res>, + render_device: Res, + render_queue: Res, +) -> Result> { + let gpu_buffer = &gpu_buffers + .get(&handle) + .ok_or(ProcessingError::BufferNotFound)? + .buffer; + + let mut encoder = render_device.create_command_encoder(&CommandEncoderDescriptor::default()); + encoder.copy_buffer_to_buffer(gpu_buffer, 0, &readback_buffer, 0, size); + render_queue.submit(std::iter::once(encoder.finish())); + + let buffer_slice = readback_buffer.slice(0..size); + let (s, r) = crossbeam_channel::bounded(1); + buffer_slice.map_async(MapMode::Read, move |result| { + let _ = s.send(result); + }); + render_device + .poll(PollType::wait_indefinitely()) + .map_err(|e| ProcessingError::BufferMapError(format!("poll failed: {e}")))?; + r.recv() + .map_err(|e| ProcessingError::BufferMapError(format!("map channel closed: {e}")))? + .map_err(|e| ProcessingError::BufferMapError(format!("map failed: {e}")))?; + + let bytes = buffer_slice.get_mapped_range().to_vec(); + readback_buffer.unmap(); + Ok(bytes) +} + +pub fn invalidate_rw_buffers(mut buffers: Query<&mut Buffer>) { + for mut buf in &mut buffers { + if buf.bound_rw && buf.synced { + buf.synced = false; + } + } +} + +pub fn destroy_buffer(In(entity): In, mut commands: Commands) -> Result<()> { + commands.entity(entity).despawn(); + Ok(()) +} + +#[derive(Component)] +pub struct Compute { + pub shader: DynamicShader, + pub entry_point: String, + pub pipeline_id: CachedComputePipelineId, + pub bind_group_layout_descriptors: Vec<(u32, BindGroupLayoutDescriptor)>, +} + +fn queue_pipeline( + In(descriptor): In, + pipeline_cache: Res, +) -> CachedComputePipelineId { + pipeline_cache.queue_compute_pipeline(descriptor) +} + +fn pump_pipeline( + In(id): In, + mut pipeline_cache: ResMut, +) -> Result { + pipeline_cache.process_queue(); + match pipeline_cache.get_compute_pipeline_state(id) { + CachedPipelineState::Ok(_) => Ok(true), + CachedPipelineState::Err(e) => Err(ProcessingError::PipelineCompileError(format!("{e}"))), + _ => Ok(false), + } +} + +pub fn create_compute(app: &mut App, shader_entity: Entity) -> Result { + let (module, shader_handle) = { + let program = app + .world() + .get::(shader_entity) + .ok_or(ProcessingError::ShaderNotFound)?; + (program.module.clone(), program.shader_handle.clone()) + }; + + let compute_ep = module + .entry_points + .iter() + .find(|ep| ep.stage == naga::ShaderStage::Compute) + .ok_or_else(|| { + ProcessingError::ShaderCompilationError( + "Shader has no @compute entry point".to_string(), + ) + })?; + let entry_point = compute_ep.name.clone(); + + let mut shader = DynamicShader::new(module) + .map_err(|e| ProcessingError::ShaderCompilationError(e.to_string()))?; + shader.init(); + + let reflection = shader.reflection(); + let groups: BTreeSet = reflection.parameters().map(|p| p.group()).collect(); + + let bind_group_layout_descriptors: Vec<(u32, BindGroupLayoutDescriptor)> = groups + .iter() + .map(|&group| { + let entries = reflection.bind_group_layout(group); + ( + group, + BindGroupLayoutDescriptor { + label: "compute_bind_group_layout".into(), + entries, + }, + ) + }) + .collect(); + + let max_group = groups.iter().last().copied().map_or(0, |g| g + 1); + let mut layout_descriptors = vec![BindGroupLayoutDescriptor::default(); max_group as usize]; + for (group, desc) in &bind_group_layout_descriptors { + layout_descriptors[*group as usize] = desc.clone(); + } + + let descriptor = ComputePipelineDescriptor { + label: Some("processing_compute".into()), + layout: layout_descriptors, + immediate_size: 0, + shader: shader_handle.clone(), + shader_defs: Vec::new(), + entry_point: Some(entry_point.clone().into()), + zero_initialize_workgroup_memory: true, + }; + + let pipeline_id = app + .sub_app_mut(RenderApp) + .world_mut() + .run_system_cached_with(queue_pipeline, descriptor) + .unwrap(); + + const MAX_WAIT: u32 = 64; + for _ in 0..MAX_WAIT { + app.update(); + let done = app + .sub_app_mut(RenderApp) + .world_mut() + .run_system_cached_with(pump_pipeline, pipeline_id) + .unwrap()?; + if done { + return Ok(app + .world_mut() + .spawn(Compute { + shader, + entry_point, + pipeline_id, + bind_group_layout_descriptors, + }) + .id()); + } + } + Err(ProcessingError::PipelineNotReady(MAX_WAIT)) +} + +pub fn set_compute_property( + In((entity, name, value)): In<(Entity, String, ShaderValue)>, + mut computes: Query<&mut Compute>, + mut p_buffers: Query<&mut Buffer>, + p_images: Query<&PImage>, +) -> Result<()> { + use bevy_naga_reflect::reflect::ParameterCategory; + + let mut compute = computes + .get_mut(entity) + .map_err(|_| ProcessingError::ComputeNotFound)?; + + let category = compute + .shader + .reflection() + .parameter(&name) + .map(|p| p.category()) + .ok_or_else(|| ProcessingError::UnknownShaderProperty(name.clone()))?; + + match (&value, category) { + (ShaderValue::Buffer(buf_entity), ParameterCategory::Storage { read_only }) => { + let mut buffer = p_buffers + .get_mut(*buf_entity) + .map_err(|_| ProcessingError::BufferNotFound)?; + compute.shader.insert(&name, buffer.handle.clone()); + if !read_only { + buffer.bound_rw = true; + } + Ok(()) + } + (ShaderValue::Texture(img_entity), ParameterCategory::Texture) + | (ShaderValue::Texture(img_entity), ParameterCategory::StorageTexture) => { + let image = p_images + .get(*img_entity) + .map_err(|_| ProcessingError::ImageNotFound)?; + compute.shader.insert(&name, image.handle.clone()); + Ok(()) + } + (ShaderValue::Buffer(_), cat) | (ShaderValue::Texture(_), cat) => { + Err(ProcessingError::InvalidArgument(format!( + "property `{name}` expects {cat:?}, got {value:?}", + ))) + } + (_, ParameterCategory::Uniform) => { + let reflect_value: Box = shader_value_to_reflect(&value)?; + apply_reflect_field(&mut compute.shader, &name, &*reflect_value) + } + (_, cat) => Err(ProcessingError::InvalidArgument(format!( + "property `{name}` expects {cat:?}, got non-resource value" + ))), + } +} + +pub fn dispatch( + In((pipeline_id, layout_descriptors, shader, x, y, z)): In<( + CachedComputePipelineId, + Vec<(u32, BindGroupLayoutDescriptor)>, + DynamicShader, + u32, + u32, + u32, + )>, + pipeline_cache: Res, + render_device: Res, + render_queue: Res, + gpu_images: Res>, + gpu_buffers: Res>, +) -> Result<()> { + let pipeline = pipeline_cache + .get_compute_pipeline(pipeline_id) + .ok_or(ProcessingError::PipelineNotReady(0))? + .clone(); + + let reflection = shader.reflection(); + + let mut bind_groups = Vec::new(); + for (group, desc) in &layout_descriptors { + let layout = pipeline_cache.get_bind_group_layout(desc); + let bindings = + reflection.create_bindings(*group, &shader, &render_device, &gpu_images, &gpu_buffers); + + let bind_group_entries: Vec<_> = bindings + .iter() + .map( + |(binding, resource)| bevy::render::render_resource::BindGroupEntry { + binding: *binding, + resource: resource.get_binding(), + }, + ) + .collect(); + + let bind_group = render_device.create_bind_group( + Some("compute_bind_group"), + &layout, + &bind_group_entries, + ); + bind_groups.push(bind_group); + } + + let mut encoder = render_device.create_command_encoder(&CommandEncoderDescriptor::default()); + { + let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor { + label: Some("compute_pass"), + ..Default::default() + }); + pass.set_pipeline(&pipeline); + for ((group, _), bg) in layout_descriptors.iter().zip(bind_groups.iter()) { + pass.set_bind_group(*group, bg, &[]); + } + pass.dispatch_workgroups(x, y, z); + } + render_queue.submit(std::iter::once(encoder.finish())); + + Ok(()) +} + +pub fn destroy_compute(In(entity): In, mut commands: Commands) -> Result<()> { + commands.entity(entity).despawn(); + Ok(()) +} diff --git a/crates/processing_render/src/lib.rs b/crates/processing_render/src/lib.rs index 27636dd..3df687d 100644 --- a/crates/processing_render/src/lib.rs +++ b/crates/processing_render/src/lib.rs @@ -2,6 +2,7 @@ pub mod camera; pub mod color; +pub mod compute; pub mod geometry; pub mod gltf; pub mod graphics; @@ -10,6 +11,7 @@ pub mod light; pub mod material; pub mod monitor; pub mod render; +pub mod shader_value; pub mod sketch; pub(crate) mod surface; pub mod time; @@ -61,6 +63,7 @@ impl Plugin for ProcessingRenderPlugin { material::ProcessingMaterialPlugin, bevy::pbr::wireframe::WireframePlugin::default(), material::custom::CustomMaterialPlugin, + compute::ComputePlugin, camera::OrbitCameraPlugin, bevy::camera_controller::free_camera::FreeCameraPlugin, bevy::camera_controller::pan_camera::PanCameraPlugin, @@ -1280,7 +1283,7 @@ pub fn material_create_pbr() -> error::Result { pub fn material_set( entity: Entity, name: impl Into, - value: material::MaterialValue, + value: shader_value::ShaderValue, ) -> error::Result<()> { app_mut(|app| { app.world_mut() @@ -1479,3 +1482,220 @@ pub fn gltf_light(gltf_entity: Entity, index: usize) -> error::Result { .unwrap() }) } + +pub fn buffer_create(size: u64) -> error::Result { + app_mut(|app| { + let entity = app + .world_mut() + .run_system_cached_with(compute::create_buffer, size) + .unwrap(); + app.update(); + Ok(entity) + }) +} + +pub fn buffer_create_with_data(data: Vec) -> error::Result { + app_mut(|app| { + let entity = app + .world_mut() + .run_system_cached_with(compute::create_buffer_with_data, data) + .unwrap(); + app.update(); + Ok(entity) + }) +} + +pub fn buffer_size(entity: Entity) -> error::Result { + app_mut(|app| { + Ok(app + .world() + .get::(entity) + .ok_or(error::ProcessingError::BufferNotFound)? + .size) + }) +} + +pub fn buffer_write(entity: Entity, data: Vec) -> error::Result<()> { + buffer_write_range(entity, 0, data, true) +} + +pub fn buffer_write_element(entity: Entity, offset: u64, data: Vec) -> error::Result<()> { + buffer_write_range(entity, offset, data, false) +} + +fn ensure_buffer_synced(app: &mut App, entity: Entity) -> error::Result<()> { + let (handle, readback_buffer, size, synced) = { + let buf = app + .world() + .get::(entity) + .ok_or(error::ProcessingError::BufferNotFound)?; + ( + buf.handle.clone(), + buf.readback_buffer.clone(), + buf.size, + buf.synced, + ) + }; + if synced { + return Ok(()); + } + let bytes = app + .sub_app_mut(bevy::render::RenderApp) + .world_mut() + .run_system_cached_with( + compute::read_buffer_gpu, + (handle.clone(), readback_buffer, size), + ) + .unwrap()?; + + let world = app.world_mut(); + { + let mut buffers = world.resource_mut::>(); + let asset = buffers + .get_mut_untracked(handle.id()) + .ok_or(error::ProcessingError::BufferNotFound)?; + asset.data = Some(bytes); + } + + let mut buf = world + .get_mut::(entity) + .ok_or(error::ProcessingError::BufferNotFound)?; + buf.synced = true; + Ok(()) +} + +fn buffer_write_range( + entity: Entity, + offset: u64, + data: Vec, + exact_size: bool, +) -> error::Result<()> { + app_mut(|app| { + let (handle, size) = { + let buf = app + .world() + .get::(entity) + .ok_or(error::ProcessingError::BufferNotFound)?; + (buf.handle.clone(), buf.size) + }; + let end = offset.checked_add(data.len() as u64).ok_or_else(|| { + error::ProcessingError::InvalidArgument("offset + len overflow".to_string()) + })?; + if exact_size && (offset != 0 || end != size) { + return Err(error::ProcessingError::InvalidArgument(format!( + "buffer_write data length {} does not match buffer size {size}; \ + destroy and re-create to resize, or use buffer_write_element for partial writes", + data.len() + ))); + } + if end > size { + return Err(error::ProcessingError::InvalidArgument(format!( + "buffer write out of bounds: offset {offset} + len {} > size {size}", + data.len() + ))); + } + ensure_buffer_synced(app, entity)?; + app.world_mut() + .run_system_cached_with(compute::write_buffer_cpu, (handle, offset, data)) + .unwrap() + }) +} + +pub fn buffer_read_element(entity: Entity, offset: u64, len: u64) -> error::Result> { + buffer_read_range(entity, offset, len) +} + +pub fn buffer_read(entity: Entity) -> error::Result> { + let size = buffer_size(entity)?; + buffer_read_range(entity, 0, size) +} + +fn buffer_read_range(entity: Entity, offset: u64, len: u64) -> error::Result> { + app_mut(|app| { + let size = app + .world() + .get::(entity) + .ok_or(error::ProcessingError::BufferNotFound)? + .size; + let end = offset.checked_add(len).ok_or_else(|| { + error::ProcessingError::InvalidArgument("offset + len overflow".to_string()) + })?; + if end > size { + return Err(error::ProcessingError::InvalidArgument(format!( + "buffer read out of bounds: offset {offset} + len {len} > size {size}" + ))); + } + ensure_buffer_synced(app, entity)?; + let handle = app + .world() + .get::(entity) + .ok_or(error::ProcessingError::BufferNotFound)? + .handle + .clone(); + let buffers = app + .world() + .resource::>(); + let data = buffers + .get(&handle) + .and_then(|a| a.data.as_ref()) + .ok_or(error::ProcessingError::BufferNotFound)?; + Ok(data[offset as usize..(offset + len) as usize].to_vec()) + }) +} + +pub fn buffer_destroy(entity: Entity) -> error::Result<()> { + app_mut(|app| { + app.world_mut() + .run_system_cached_with(compute::destroy_buffer, entity) + .unwrap() + }) +} + +pub fn compute_create(shader_entity: Entity) -> error::Result { + app_mut(|app| compute::create_compute(app, shader_entity)) +} + +pub fn compute_set( + entity: Entity, + name: impl Into, + value: shader_value::ShaderValue, +) -> error::Result<()> { + app_mut(|app| { + app.world_mut() + .run_system_cached_with(compute::set_compute_property, (entity, name.into(), value)) + .unwrap() + }) +} + +pub fn compute_dispatch(entity: Entity, x: u32, y: u32, z: u32) -> error::Result<()> { + app_mut(|app| { + app.update(); + + let args = { + let c = app + .world() + .get::(entity) + .ok_or(error::ProcessingError::ComputeNotFound)?; + ( + c.pipeline_id, + c.bind_group_layout_descriptors.clone(), + c.shader.clone(), + x, + y, + z, + ) + }; + app.sub_app_mut(bevy::render::RenderApp) + .world_mut() + .run_system_cached_with(compute::dispatch, args) + .unwrap() + }) +} + +pub fn compute_destroy(entity: Entity) -> error::Result<()> { + app_mut(|app| { + app.world_mut() + .run_system_cached_with(compute::destroy_compute, entity) + .unwrap() + }) +} diff --git a/crates/processing_render/src/material/custom.rs b/crates/processing_render/src/material/custom.rs index 04080b8..0a3611d 100644 --- a/crates/processing_render/src/material/custom.rs +++ b/crates/processing_render/src/material/custom.rs @@ -51,8 +51,8 @@ use bevy_naga_reflect::dynamic_shader::DynamicShader; use bevy::shader::Shader as ShaderAsset; -use crate::material::MaterialValue; use crate::render::material::UntypedMaterial; +use crate::shader_value::ShaderValue; use processing_core::config::{Config, ConfigKey}; use processing_core::error::{ProcessingError, Result}; @@ -265,52 +265,58 @@ pub fn create_custom( Ok(commands.spawn(UntypedMaterial(handle.untyped())).id()) } -pub fn set_property( - material: &mut CustomMaterial, +pub fn set_property(material: &mut CustomMaterial, name: &str, value: &ShaderValue) -> Result<()> { + let reflect_value: Box = shader_value_to_reflect(value)?; + apply_reflect_field(&mut material.shader, name, &*reflect_value) +} + +pub(crate) fn apply_reflect_field( + shader: &mut DynamicShader, name: &str, - value: &MaterialValue, + value: &dyn PartialReflect, ) -> Result<()> { - let reflect_value: Box = material_value_to_reflect(value)?; - - if let Some(field) = material.shader.field_mut(name) { - field.apply(&*reflect_value); + if let Some(field) = shader.field_mut(name) { + field.apply(value); return Ok(()); } - let param_name = find_param_containing_field(&material.shader, name); + let param_name = find_param_containing_field(shader, name); if let Some(param_name) = param_name - && let Some(param) = material.shader.field_mut(¶m_name) + && let Some(param) = shader.field_mut(¶m_name) && let ReflectMut::Struct(s) = param.reflect_mut() && let Some(field) = s.field_mut(name) { - field.apply(&*reflect_value); + field.apply(value); return Ok(()); } - Err(ProcessingError::UnknownMaterialProperty(name.to_string())) + Err(ProcessingError::UnknownShaderProperty(name.to_string())) } -fn material_value_to_reflect(value: &MaterialValue) -> Result> { +pub(crate) fn shader_value_to_reflect(value: &ShaderValue) -> Result> { Ok(match value { - MaterialValue::Float(v) => Box::new(*v), - MaterialValue::Float2(v) => Box::new(Vec2::from_array(*v)), - MaterialValue::Float3(v) => Box::new(Vec3::from_array(*v)), - MaterialValue::Float4(v) => Box::new(Vec4::from_array(*v)), - MaterialValue::Int(v) => Box::new(*v), - MaterialValue::Int2(v) => Box::new(IVec2::from_array(*v)), - MaterialValue::Int3(v) => Box::new(IVec3::from_array(*v)), - MaterialValue::Int4(v) => Box::new(IVec4::from_array(*v)), - MaterialValue::UInt(v) => Box::new(*v), - MaterialValue::Mat4(v) => Box::new(Mat4::from_cols_array(v)), - MaterialValue::Texture(_) => { - return Err(ProcessingError::UnknownMaterialProperty( - "Texture properties not yet supported for custom materials".to_string(), + ShaderValue::Float(v) => Box::new(*v), + ShaderValue::Float2(v) => Box::new(Vec2::from_array(*v)), + ShaderValue::Float3(v) => Box::new(Vec3::from_array(*v)), + ShaderValue::Float4(v) => Box::new(Vec4::from_array(*v)), + ShaderValue::Int(v) => Box::new(*v), + ShaderValue::Int2(v) => Box::new(IVec2::from_array(*v)), + ShaderValue::Int3(v) => Box::new(IVec3::from_array(*v)), + ShaderValue::Int4(v) => Box::new(IVec4::from_array(*v)), + ShaderValue::UInt(v) => Box::new(*v), + ShaderValue::Mat4(v) => Box::new(Mat4::from_cols_array(v)), + ShaderValue::Texture(_) | ShaderValue::Buffer(_) => { + return Err(ProcessingError::InvalidArgument( + "Texture/Buffer must be bound via set_property, not as a uniform value".to_string(), )); } }) } -fn find_param_containing_field(shader: &DynamicShader, field_name: &str) -> Option { +pub(crate) fn find_param_containing_field( + shader: &DynamicShader, + field_name: &str, +) -> Option { for i in 0..shader.field_len() { if let Some(field) = shader.field_at(i) && let ReflectRef::Struct(s) = field.reflect_ref() diff --git a/crates/processing_render/src/material/mod.rs b/crates/processing_render/src/material/mod.rs index be48836..d7fd029 100644 --- a/crates/processing_render/src/material/mod.rs +++ b/crates/processing_render/src/material/mod.rs @@ -1,7 +1,9 @@ pub mod custom; pub mod pbr; +use crate::compute; use crate::render::material::UntypedMaterial; +use crate::shader_value::ShaderValue; use bevy::material::descriptor::RenderPipelineDescriptor; use bevy::material::specialize::SpecializedMeshPipelineError; use bevy::mesh::MeshVertexBufferLayoutRef; @@ -11,6 +13,7 @@ use bevy::pbr::{ use bevy::prelude::*; use bevy::render::render_resource::{AsBindGroup, BlendState}; use bevy::shader::ShaderRef; +use bevy_naga_reflect::reflect::ParameterCategory; use processing_core::error::{self, ProcessingError}; pub struct ProcessingMaterialPlugin; @@ -38,21 +41,6 @@ impl Plugin for ProcessingMaterialPlugin { #[derive(Resource)] pub struct DefaultMaterial(pub Entity); -#[derive(Debug, Clone)] -pub enum MaterialValue { - Float(f32), - Float2([f32; 2]), - Float3([f32; 3]), - Float4([f32; 4]), - Int(i32), - Int2([i32; 2]), - Int3([i32; 3]), - Int4([i32; 4]), - UInt(u32), - Mat4([f32; 16]), - Texture(Entity), -} - pub fn create_pbr( mut commands: Commands, mut materials: ResMut>>, @@ -69,10 +57,11 @@ pub fn create_pbr( } pub fn set_property( - In((entity, name, value)): In<(Entity, String, MaterialValue)>, + In((entity, name, value)): In<(Entity, String, ShaderValue)>, material_handles: Query<&UntypedMaterial>, mut extended_materials: ResMut>>, mut custom_materials: ResMut>, + mut p_buffers: Query<&mut compute::Buffer>, ) -> error::Result<()> { let untyped = material_handles .get(entity) @@ -93,6 +82,31 @@ pub fn set_property( let mut mat = custom_materials .get_mut(&handle) .ok_or(ProcessingError::MaterialNotFound)?; + + if let ShaderValue::Buffer(buf_entity) = &value { + let mut buffer = p_buffers + .get_mut(*buf_entity) + .map_err(|_| ProcessingError::BufferNotFound)?; + + let category = mat + .shader + .reflection() + .parameter(&name) + .map(|p| p.category()) + .ok_or_else(|| ProcessingError::UnknownShaderProperty(name.clone()))?; + + let ParameterCategory::Storage { read_only } = category else { + return Err(ProcessingError::InvalidArgument(format!( + "property `{name}` expects {category:?}, got Buffer" + ))); + }; + mat.shader.insert(&name, buffer.handle.clone()); + if !read_only { + buffer.bound_rw = true; + } + return Ok(()); + } + return custom::set_property(&mut mat, &name, &value); } diff --git a/crates/processing_render/src/material/pbr.rs b/crates/processing_render/src/material/pbr.rs index df3df23..c4f7a37 100644 --- a/crates/processing_render/src/material/pbr.rs +++ b/crates/processing_render/src/material/pbr.rs @@ -1,17 +1,17 @@ use bevy::prelude::*; -use super::MaterialValue; +use crate::shader_value::ShaderValue; use processing_core::error::{ProcessingError, Result}; /// Set a property on a StandardMaterial by name. pub fn set_property( material: &mut StandardMaterial, name: &str, - value: &MaterialValue, + value: &ShaderValue, ) -> Result<()> { match name { "base_color" | "color" => { - let MaterialValue::Float4(c) = value else { + let ShaderValue::Float4(c) = value else { return Err(ProcessingError::InvalidArgument(format!( "'{name}' expects Float4, got {value:?}" ))); @@ -19,7 +19,7 @@ pub fn set_property( material.base_color = Color::srgba(c[0], c[1], c[2], c[3]); } "metallic" => { - let MaterialValue::Float(v) = value else { + let ShaderValue::Float(v) = value else { return Err(ProcessingError::InvalidArgument(format!( "'{name}' expects Float, got {value:?}" ))); @@ -27,7 +27,7 @@ pub fn set_property( material.metallic = *v; } "roughness" | "perceptual_roughness" => { - let MaterialValue::Float(v) = value else { + let ShaderValue::Float(v) = value else { return Err(ProcessingError::InvalidArgument(format!( "'{name}' expects Float, got {value:?}" ))); @@ -35,7 +35,7 @@ pub fn set_property( material.perceptual_roughness = *v; } "reflectance" => { - let MaterialValue::Float(v) = value else { + let ShaderValue::Float(v) = value else { return Err(ProcessingError::InvalidArgument(format!( "'{name}' expects Float, got {value:?}" ))); @@ -43,7 +43,7 @@ pub fn set_property( material.reflectance = *v; } "emissive" => { - let MaterialValue::Float4(c) = value else { + let ShaderValue::Float4(c) = value else { return Err(ProcessingError::InvalidArgument(format!( "'{name}' expects Float4, got {value:?}" ))); @@ -51,7 +51,7 @@ pub fn set_property( material.emissive = LinearRgba::new(c[0], c[1], c[2], c[3]); } "unlit" => { - let MaterialValue::Float(v) = value else { + let ShaderValue::Float(v) = value else { return Err(ProcessingError::InvalidArgument(format!( "'{name}' expects Float, got {value:?}" ))); @@ -59,7 +59,7 @@ pub fn set_property( material.unlit = *v > 0.5; } "double_sided" => { - let MaterialValue::Float(v) = value else { + let ShaderValue::Float(v) = value else { return Err(ProcessingError::InvalidArgument(format!( "'{name}' expects Float, got {value:?}" ))); @@ -67,7 +67,7 @@ pub fn set_property( material.double_sided = *v > 0.5; } "alpha_mode" => { - let MaterialValue::Int(v) = value else { + let ShaderValue::Int(v) = value else { return Err(ProcessingError::InvalidArgument(format!( "'{name}' expects Int, got {value:?}" ))); @@ -88,7 +88,7 @@ pub fn set_property( }; } _ => { - return Err(ProcessingError::UnknownMaterialProperty(name.to_string())); + return Err(ProcessingError::UnknownShaderProperty(name.to_string())); } } Ok(()) diff --git a/crates/processing_render/src/shader_value.rs b/crates/processing_render/src/shader_value.rs new file mode 100644 index 0000000..9d2c77d --- /dev/null +++ b/crates/processing_render/src/shader_value.rs @@ -0,0 +1,82 @@ +use bevy::prelude::*; + +#[derive(Debug, Clone)] +pub enum ShaderValue { + Float(f32), + Float2([f32; 2]), + Float3([f32; 3]), + Float4([f32; 4]), + Int(i32), + Int2([i32; 2]), + Int3([i32; 3]), + Int4([i32; 4]), + UInt(u32), + Mat4([f32; 16]), + Texture(Entity), + Buffer(Entity), +} + +impl ShaderValue { + pub fn to_bytes(&self) -> Option> { + match self { + ShaderValue::Float(v) => Some(v.to_le_bytes().to_vec()), + ShaderValue::Float2(v) => Some(v.iter().flat_map(|f| f.to_le_bytes()).collect()), + ShaderValue::Float3(v) => Some(v.iter().flat_map(|f| f.to_le_bytes()).collect()), + ShaderValue::Float4(v) => Some(v.iter().flat_map(|f| f.to_le_bytes()).collect()), + ShaderValue::Int(v) => Some(v.to_le_bytes().to_vec()), + ShaderValue::Int2(v) => Some(v.iter().flat_map(|i| i.to_le_bytes()).collect()), + ShaderValue::Int3(v) => Some(v.iter().flat_map(|i| i.to_le_bytes()).collect()), + ShaderValue::Int4(v) => Some(v.iter().flat_map(|i| i.to_le_bytes()).collect()), + ShaderValue::UInt(v) => Some(v.to_le_bytes().to_vec()), + ShaderValue::Mat4(v) => Some(v.iter().flat_map(|f| f.to_le_bytes()).collect()), + ShaderValue::Texture(_) | ShaderValue::Buffer(_) => None, + } + } + + pub fn byte_size(&self) -> Option { + match self { + ShaderValue::Float(_) | ShaderValue::Int(_) | ShaderValue::UInt(_) => Some(4), + ShaderValue::Float2(_) | ShaderValue::Int2(_) => Some(8), + ShaderValue::Float3(_) | ShaderValue::Int3(_) => Some(12), + ShaderValue::Float4(_) | ShaderValue::Int4(_) => Some(16), + ShaderValue::Mat4(_) => Some(64), + ShaderValue::Texture(_) | ShaderValue::Buffer(_) => None, + } + } + + pub fn read_from_bytes(&self, bytes: &[u8]) -> Option { + fn f32s(bytes: &[u8]) -> Option<[f32; N]> { + let mut arr = [0f32; N]; + for i in 0..N { + arr[i] = f32::from_le_bytes(bytes[i * 4..(i + 1) * 4].try_into().ok()?); + } + Some(arr) + } + fn i32s(bytes: &[u8]) -> Option<[i32; N]> { + let mut arr = [0i32; N]; + for i in 0..N { + arr[i] = i32::from_le_bytes(bytes[i * 4..(i + 1) * 4].try_into().ok()?); + } + Some(arr) + } + match self { + ShaderValue::Float(_) => Some(ShaderValue::Float(f32::from_le_bytes( + bytes[..4].try_into().ok()?, + ))), + ShaderValue::Float2(_) => Some(ShaderValue::Float2(f32s::<2>(bytes)?)), + ShaderValue::Float3(_) => Some(ShaderValue::Float3(f32s::<3>(bytes)?)), + ShaderValue::Float4(_) => Some(ShaderValue::Float4(f32s::<4>(bytes)?)), + ShaderValue::Int(_) => Some(ShaderValue::Int(i32::from_le_bytes( + bytes[..4].try_into().ok()?, + ))), + ShaderValue::Int2(_) => Some(ShaderValue::Int2(i32s::<2>(bytes)?)), + ShaderValue::Int3(_) => Some(ShaderValue::Int3(i32s::<3>(bytes)?)), + ShaderValue::Int4(_) => Some(ShaderValue::Int4(i32s::<4>(bytes)?)), + ShaderValue::UInt(_) => Some(ShaderValue::UInt(u32::from_le_bytes( + bytes[..4].try_into().ok()?, + ))), + ShaderValue::Mat4(_) => Some(ShaderValue::Mat4(f32s::<16>(bytes)?)), + ShaderValue::Texture(_) | ShaderValue::Buffer(_) => None, + } + } +} diff --git a/crates/processing_wasm/src/lib.rs b/crates/processing_wasm/src/lib.rs index 120a9fe..0372ce4 100644 --- a/crates/processing_wasm/src/lib.rs +++ b/crates/processing_wasm/src/lib.rs @@ -744,7 +744,7 @@ pub fn js_material_set_float(mat_id: u64, name: &str, value: f32) -> Result<(), check(material_set( Entity::from_bits(mat_id), name, - material::MaterialValue::Float(value), + shader_value::ShaderValue::Float(value), )) } @@ -760,7 +760,7 @@ pub fn js_material_set_float4( check(material_set( Entity::from_bits(mat_id), name, - material::MaterialValue::Float4([r, g, b, a]), + shader_value::ShaderValue::Float4([r, g, b, a]), )) } diff --git a/examples/compute_readback.rs b/examples/compute_readback.rs new file mode 100644 index 0000000..f470274 --- /dev/null +++ b/examples/compute_readback.rs @@ -0,0 +1,90 @@ +use processing::prelude::*; + +fn main() { + match run() { + Ok(_) => { + eprintln!("Compute readback test passed!"); + exit(0).unwrap(); + } + Err(e) => { + eprintln!("Compute readback error: {:?}", e); + exit(1).unwrap(); + } + } +} + +fn run() -> error::Result<()> { + init(Config::default())?; + + let surface = surface_create_offscreen(1, 1, 1.0, TextureFormat::Rgba8Unorm)?; + let _graphics = graphics_create(surface, 1, 1, TextureFormat::Rgba8Unorm)?; + + let buf = buffer_create(16)?; + + let shader_src = r#" +@group(0) @binding(0) +var output: array; + +@compute @workgroup_size(1) +fn main() { + output[0] = 1u; + output[1] = 2u; + output[2] = 3u; + output[3] = 4u; +} +"#; + let shader = shader_create(shader_src)?; + let compute = compute_create(shader)?; + compute_set(compute, "output", shader_value::ShaderValue::Buffer(buf))?; + + compute_dispatch(compute, 1, 1, 1)?; + + let data = buffer_read(buf)?; + let values: Vec = data + .chunks_exact(4) + .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect(); + + assert_eq!(values, vec![1, 2, 3, 4], "Compute readback mismatch!"); + eprintln!("PASS"); + + let double_src = r#" +@group(0) @binding(0) +var data: array; + +@compute @workgroup_size(4) +fn main(@builtin(global_invocation_id) id: vec3) { + data[id.x] = data[id.x] * 2.0; +} +"#; + let buf2_data: Vec = [1.0f32, 2.0, 3.0, 4.0] + .iter() + .flat_map(|f| f.to_le_bytes()) + .collect(); + let buf2 = buffer_create_with_data(buf2_data)?; + let shader2 = shader_create(double_src)?; + let compute2 = compute_create(shader2)?; + compute_set(compute2, "data", shader_value::ShaderValue::Buffer(buf2))?; + compute_dispatch(compute2, 1, 1, 1)?; + + let data2 = buffer_read(buf2)?; + let floats: Vec = data2 + .chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect(); + assert_eq!( + floats, + vec![2.0, 4.0, 6.0, 8.0], + "In-place double mismatch!" + ); + eprintln!("PASS"); + + compute_destroy(compute)?; + compute_destroy(compute2)?; + shader_destroy(shader)?; + shader_destroy(shader2)?; + buffer_destroy(buf)?; + buffer_destroy(buf2)?; + + Ok(()) +} diff --git a/examples/custom_material.rs b/examples/custom_material.rs index b29feed..aea1b34 100644 --- a/examples/custom_material.rs +++ b/examples/custom_material.rs @@ -36,7 +36,7 @@ fn sketch() -> error::Result<()> { material_set( mat, "color", - material::MaterialValue::Float4([1.0, 0.2, 0.4, 1.0]), + shader_value::ShaderValue::Float4([1.0, 0.2, 0.4, 1.0]), )?; let mut angle = 0.0; diff --git a/examples/gltf_load.rs b/examples/gltf_load.rs index b261539..8d86958 100644 --- a/examples/gltf_load.rs +++ b/examples/gltf_load.rs @@ -2,8 +2,8 @@ use processing_glfw::GlfwContext; use bevy::math::Vec3; use processing::prelude::*; -use processing_render::material::MaterialValue; use processing_render::render::command::DrawCommand; +use processing_render::shader_value::ShaderValue; fn main() { match sketch() { @@ -50,11 +50,7 @@ fn sketch() -> error::Result<()> { let r = (t * 8.0).sin() * 0.5 + 0.5; let g = (t * 8.0 + 2.0).sin() * 0.5 + 0.5; let b = (t * 8.0 + 4.0).sin() * 0.5 + 0.5; - material_set( - duck_mat, - "base_color", - MaterialValue::Float4([r, g, b, 1.0]), - )?; + material_set(duck_mat, "base_color", ShaderValue::Float4([r, g, b, 1.0]))?; graphics_begin_draw(graphics)?; diff --git a/examples/lights.rs b/examples/lights.rs index b2d0f21..b111591 100644 --- a/examples/lights.rs +++ b/examples/lights.rs @@ -27,7 +27,7 @@ fn sketch() -> error::Result<()> { let graphics = graphics_create(surface, width, height, TextureFormat::Rgba16Float)?; let box_geo = geometry_box(100.0, 100.0, 100.0)?; let pbr_mat = material_create_pbr()?; - material_set(pbr_mat, "roughness", material::MaterialValue::Float(0.0))?; + material_set(pbr_mat, "roughness", shader_value::ShaderValue::Float(0.0))?; // We will only declare lights in `setup` // rather than calling some sort of `light()` method inside of `draw` diff --git a/examples/materials.rs b/examples/materials.rs index 5ec92d8..306ac6f 100644 --- a/examples/materials.rs +++ b/examples/materials.rs @@ -51,8 +51,12 @@ fn sketch() -> error::Result<()> { let roughness = col as f32 / (cols - 1) as f32; let metallic = row as f32 / (rows - 1) as f32; - material_set(mat, "roughness", material::MaterialValue::Float(roughness))?; - material_set(mat, "metallic", material::MaterialValue::Float(metallic))?; + material_set( + mat, + "roughness", + shader_value::ShaderValue::Float(roughness), + )?; + material_set(mat, "metallic", shader_value::ShaderValue::Float(metallic))?; materials.push(mat); } } diff --git a/examples/primitives_3d.rs b/examples/primitives_3d.rs index adadc6e..1de2bbc 100644 --- a/examples/primitives_3d.rs +++ b/examples/primitives_3d.rs @@ -24,7 +24,7 @@ fn sketch() -> error::Result<()> { light_create_directional(graphics, bevy::color::Color::srgb(0.9, 0.85, 0.8), 300.0)?; let pbr = material_create_pbr()?; - material_set(pbr, "roughness", material::MaterialValue::Float(0.35))?; + material_set(pbr, "roughness", shader_value::ShaderValue::Float(0.35))?; let mut t: f32 = 0.0;