1use crate::{
2 binding_model,
3 hal_api::HalApi,
4 hub::Hub,
5 id::{BindGroupLayoutId, PipelineLayoutId},
6 resource::{Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation},
7 snatch::SnatchGuard,
8 Label, DOWNLEVEL_ERROR_MESSAGE,
9};
10
11use arrayvec::ArrayVec;
12use hal::Device as _;
13use smallvec::SmallVec;
14use std::os::raw::c_char;
15use thiserror::Error;
16use wgt::{BufferAddress, DeviceLostReason, TextureFormat};
17
18use std::{iter, num::NonZeroU32, ptr};
19
// Submodules of the device layer.
pub mod any_device;
pub(crate) mod bgl;
pub mod global;
mod life;
pub mod queue;
pub mod resource;
#[cfg(any(feature = "trace", feature = "replay"))]
pub mod trace;
// Re-export the most commonly used items from the private/implementation modules.
pub use {life::WaitIdleError, resource::Device};
29
/// Number of shader stages that may be bound concurrently (mirrors the HAL limit).
pub const SHADER_STAGE_COUNT: usize = hal::MAX_CONCURRENT_SHADER_STAGES;
/// Size of the internal zero-filled buffer: 512 KiB (`512 << 10` bytes).
pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;

/// How long (in milliseconds) to wait for pending GPU work during cleanup.
const CLEANUP_WAIT_MS: u32 = 5000;

/// Label applied to implicitly created bind group layouts in the error state.
const IMPLICIT_BIND_GROUP_LAYOUT_ERROR_LABEL: &str = "Implicit BindGroupLayout in the Error State";
/// Error message used when a shader entry point cannot be resolved.
const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";

/// Device-creation descriptor, with this crate's `Label` type plugged in.
pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
41
/// Direction of a host (CPU) mapping of a buffer.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HostMap {
    /// Map the buffer for reading by the CPU.
    Read,
    /// Map the buffer for writing by the CPU.
    Write,
}
49
/// Per-attachment data for a render pass, generic over the attachment payload
/// `T` (e.g. a `TextureFormat` or a texture view).
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct AttachmentData<T> {
    // One optional entry per color attachment slot; `None` means the slot is unused.
    pub colors: ArrayVec<Option<T>, { hal::MAX_COLOR_ATTACHMENTS }>,
    // Resolve targets for multisampled color attachments.
    pub resolves: ArrayVec<T, { hal::MAX_COLOR_ATTACHMENTS }>,
    // Optional combined depth/stencil attachment.
    pub depth_stencil: Option<T>,
}
57impl<T: PartialEq> Eq for AttachmentData<T> {}
58impl<T> AttachmentData<T> {
59 pub(crate) fn map<U, F: Fn(&T) -> U>(&self, fun: F) -> AttachmentData<U> {
60 AttachmentData {
61 colors: self.colors.iter().map(|c| c.as_ref().map(&fun)).collect(),
62 resolves: self.resolves.iter().map(&fun).collect(),
63 depth_stencil: self.depth_stencil.as_ref().map(&fun),
64 }
65 }
66}
67
/// Which kind of object is being checked for render-pass compatibility.
/// Used only to make compatibility error messages more specific.
#[derive(Debug, Copy, Clone)]
pub enum RenderPassCompatibilityCheckType {
    RenderPipeline,
    RenderBundle,
}
73
/// The attachment configuration a render pass was created with, used to check
/// that pipelines and bundles executed inside it are compatible.
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct RenderPassContext {
    // Formats of the color, resolve, and depth/stencil attachments.
    pub attachments: AttachmentData<TextureFormat>,
    // Sample count shared by all attachments.
    pub sample_count: u32,
    // Multiview array-layer count, if multiview rendering is in use.
    pub multiview: Option<NonZeroU32>,
}
81#[derive(Clone, Debug, Error)]
82#[non_exhaustive]
83pub enum RenderPassCompatibilityError {
84 #[error(
85 "Incompatible color attachments at indices {indices:?}: the RenderPass uses textures with formats {expected:?} but the {ty:?} uses attachments with formats {actual:?}",
86 )]
87 IncompatibleColorAttachment {
88 indices: Vec<usize>,
89 expected: Vec<Option<TextureFormat>>,
90 actual: Vec<Option<TextureFormat>>,
91 ty: RenderPassCompatibilityCheckType,
92 },
93 #[error(
94 "Incompatible depth-stencil attachment format: the RenderPass uses a texture with format {expected:?} but the {ty:?} uses an attachment with format {actual:?}",
95 )]
96 IncompatibleDepthStencilAttachment {
97 expected: Option<TextureFormat>,
98 actual: Option<TextureFormat>,
99 ty: RenderPassCompatibilityCheckType,
100 },
101 #[error(
102 "Incompatible sample count: the RenderPass uses textures with sample count {expected:?} but the {ty:?} uses attachments with format {actual:?}",
103 )]
104 IncompatibleSampleCount {
105 expected: u32,
106 actual: u32,
107 ty: RenderPassCompatibilityCheckType,
108 },
109 #[error("Incompatible multiview setting: the RenderPass uses setting {expected:?} but the {ty:?} uses setting {actual:?}")]
110 IncompatibleMultiview {
111 expected: Option<NonZeroU32>,
112 actual: Option<NonZeroU32>,
113 ty: RenderPassCompatibilityCheckType,
114 },
115}
116
impl RenderPassContext {
    /// Checks that `other` (the context of a pipeline or bundle of kind `ty`)
    /// is compatible with this render pass's context.
    ///
    /// Comparisons run in a fixed priority order — color attachments, then
    /// depth/stencil, then sample count, then multiview — and the first
    /// mismatch found is returned as the error.
    pub(crate) fn check_compatible(
        &self,
        other: &Self,
        ty: RenderPassCompatibilityCheckType,
    ) -> Result<(), RenderPassCompatibilityError> {
        if self.attachments.colors != other.attachments.colors {
            // Collect the indices of only the mismatching color slots so the
            // error message can point at them.
            let indices = self
                .attachments
                .colors
                .iter()
                .zip(&other.attachments.colors)
                .enumerate()
                .filter_map(|(idx, (left, right))| (left != right).then_some(idx))
                .collect();
            return Err(RenderPassCompatibilityError::IncompatibleColorAttachment {
                indices,
                expected: self.attachments.colors.iter().cloned().collect(),
                actual: other.attachments.colors.iter().cloned().collect(),
                ty,
            });
        }
        if self.attachments.depth_stencil != other.attachments.depth_stencil {
            return Err(
                RenderPassCompatibilityError::IncompatibleDepthStencilAttachment {
                    expected: self.attachments.depth_stencil,
                    actual: other.attachments.depth_stencil,
                    ty,
                },
            );
        }
        if self.sample_count != other.sample_count {
            return Err(RenderPassCompatibilityError::IncompatibleSampleCount {
                expected: self.sample_count,
                actual: other.sample_count,
                ty,
            });
        }
        if self.multiview != other.multiview {
            return Err(RenderPassCompatibilityError::IncompatibleMultiview {
                expected: self.multiview,
                actual: other.multiview,
                ty,
            });
        }
        Ok(())
    }
}
166
/// A buffer-mapping callback paired with the result it should be invoked with.
pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);

/// User callbacks accumulated during device maintenance, to be fired later
/// (after relevant locks are released).
#[derive(Default)]
pub struct UserClosures {
    // Pending buffer-map callbacks and the status to report to each.
    pub mappings: Vec<BufferMapPendingClosure>,
    // "Submitted work done" callbacks from the queue.
    pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
    // Device-lost closures together with the reason/message to fire them with.
    pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
}
175
176impl UserClosures {
177 fn extend(&mut self, other: Self) {
178 self.mappings.extend(other.mappings);
179 self.submissions.extend(other.submissions);
180 self.device_lost_invocations
181 .extend(other.device_lost_invocations);
182 }
183
184 fn fire(self) {
185 for (mut operation, status) in self.mappings {
191 if let Some(callback) = operation.callback.take() {
192 callback.call(status);
193 }
194 }
195 for closure in self.submissions {
196 closure.call();
197 }
198 for invocation in self.device_lost_invocations {
199 invocation
200 .closure
201 .call(invocation.reason, invocation.message);
202 }
203 }
204}
205
/// Callback invoked when the device is lost, receiving the loss reason and a
/// human-readable message. `Send` when the platform supports cross-thread use.
#[cfg(send_sync)]
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + Send + 'static>;
/// Non-`Send` variant for targets where closures may not cross threads.
#[cfg(not(send_sync))]
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + 'static>;
210
/// A Rust-side device-lost closure, wrapped so we can enforce that it is
/// invoked (consumed) exactly once before being dropped.
pub struct DeviceLostClosureRust {
    pub callback: DeviceLostCallback,
    // Set to true once the callback has been invoked; checked in `Drop`.
    consumed: bool,
}

impl Drop for DeviceLostClosureRust {
    fn drop(&mut self) {
        // Dropping an un-invoked closure is a programming error: the embedder
        // would be left waiting for a notification that never arrives.
        if !self.consumed {
            panic!("DeviceLostClosureRust must be consumed before it is dropped.");
        }
    }
}
223
/// A C-ABI device-lost closure: a raw function pointer plus opaque user data.
#[repr(C)]
pub struct DeviceLostClosureC {
    // Invoked with `user_data`, the reason as a `u8`, and a NUL-terminated message.
    pub callback: unsafe extern "C" fn(user_data: *mut u8, reason: u8, message: *const c_char),
    pub user_data: *mut u8,
    // Set to true once the callback has been invoked; checked in `Drop`.
    consumed: bool,
}

// NOTE(review): `Send` here relies on the embedder guaranteeing that `callback`
// and `user_data` are safe to use from another thread — confirm this matches
// the documented FFI contract.
#[cfg(send_sync)]
unsafe impl Send for DeviceLostClosureC {}

impl Drop for DeviceLostClosureC {
    fn drop(&mut self) {
        // Same invariant as the Rust variant: the closure must have been fired.
        if !self.consumed {
            panic!("DeviceLostClosureC must be consumed before it is dropped.");
        }
    }
}
241
/// A device-lost closure in either its Rust or C form.
pub struct DeviceLostClosure {
    inner: DeviceLostClosureInner,
}

/// A device-lost closure together with the reason and message it should be
/// fired with.
pub struct DeviceLostInvocation {
    closure: DeviceLostClosure,
    reason: DeviceLostReason,
    message: String,
}

// Internal tag distinguishing the two closure flavors.
enum DeviceLostClosureInner {
    Rust { inner: DeviceLostClosureRust },
    C { inner: DeviceLostClosureC },
}
258
impl DeviceLostClosure {
    /// Wraps a Rust callback in a `DeviceLostClosure`.
    pub fn from_rust(callback: DeviceLostCallback) -> Self {
        let inner = DeviceLostClosureRust {
            callback,
            consumed: false,
        };
        Self {
            inner: DeviceLostClosureInner::Rust { inner },
        }
    }

    /// Wraps a C-ABI callback in a `DeviceLostClosure`.
    ///
    /// # Safety
    ///
    /// - `closure.callback` must be a function pointer that is safe to call
    ///   with `closure.user_data`, a reason byte, and a NUL-terminated message.
    /// - `closure.user_data` must remain valid until the callback is invoked.
    pub unsafe fn from_c(mut closure: DeviceLostClosureC) -> Self {
        // Take ownership of the callback/user_data pair in a fresh,
        // not-yet-consumed wrapper.
        let inner = DeviceLostClosureC {
            callback: closure.callback,
            user_data: closure.user_data,
            consumed: false,
        };

        // Mark the caller's value consumed so its `Drop` impl does not panic.
        closure.consumed = true;

        Self {
            inner: DeviceLostClosureInner::C { inner },
        }
    }

    /// Fires the closure with the given reason and message, consuming it.
    pub(crate) fn call(self, reason: DeviceLostReason, message: String) {
        match self.inner {
            DeviceLostClosureInner::Rust { mut inner } => {
                // Mark consumed before invoking, satisfying the `Drop` invariant.
                inner.consumed = true;

                (inner.callback)(reason, message)
            }
            DeviceLostClosureInner::C { mut inner } => unsafe {
                inner.consumed = true;

                // `CString::new` panics if `message` contains an interior NUL.
                // The CString keeps the message bytes alive for the call.
                let message = std::ffi::CString::new(message).unwrap();
                // SAFETY: upheld by the contract documented on `from_c`.
                (inner.callback)(inner.user_data, reason as u8, message.as_ptr())
            },
        }
    }
}
313
/// Maps `size` bytes of `buffer` starting at `offset` for host access of the
/// given `kind`, returning a pointer to the mapped range.
///
/// Also performs cache maintenance for non-coherent memory and zero-fills any
/// portion of the mapped range the buffer has not yet initialized.
/// `offset` and `size` must be `COPY_BUFFER_ALIGNMENT`-aligned (asserted).
fn map_buffer<A: HalApi>(
    raw: &A::Device,
    buffer: &Buffer<A>,
    offset: BufferAddress,
    size: BufferAddress,
    kind: HostMap,
    snatch_guard: &SnatchGuard,
) -> Result<ptr::NonNull<u8>, BufferAccessError> {
    let raw_buffer = buffer
        .raw(snatch_guard)
        .ok_or(BufferAccessError::Destroyed)?;
    let mapping = unsafe {
        raw.map_buffer(raw_buffer, offset..offset + size)
            .map_err(DeviceError::from)?
    };

    // Non-coherent memory needs explicit cache maintenance: invalidate before
    // the CPU reads; record the range so CPU writes get flushed on unmap.
    *buffer.sync_mapped_writes.lock() = match kind {
        HostMap::Read if !mapping.is_coherent => unsafe {
            raw.invalidate_mapped_ranges(raw_buffer, iter::once(offset..offset + size));
            None
        },
        HostMap::Write if !mapping.is_coherent => Some(offset..offset + size),
        _ => None,
    };

    assert_eq!(offset % wgt::COPY_BUFFER_ALIGNMENT, 0);
    assert_eq!(size % wgt::COPY_BUFFER_ALIGNMENT, 0);
    // If no deferred flush is queued for unmap, the zero-fill below must be
    // flushed immediately. NOTE(review): the `is_coherent` part of this
    // condition mirrors the original (comment-stripped) logic — confirm the
    // rationale against upstream before changing it.
    let zero_init_needs_flush_now =
        mapping.is_coherent && buffer.sync_mapped_writes.lock().is_none();
    let mapped = unsafe { std::slice::from_raw_parts_mut(mapping.ptr.as_ptr(), size as usize) };

    // Zero every not-yet-initialized range inside [offset, offset + size),
    // marking it initialized as we drain.
    for uninitialized in buffer
        .initialization_status
        .write()
        .drain(offset..(size + offset))
    {
        // `uninitialized` is in buffer coordinates; shift into the mapped slice.
        let fill_range =
            (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
        mapped[fill_range].fill(0);

        if zero_init_needs_flush_now {
            unsafe { raw.flush_mapped_ranges(raw_buffer, iter::once(uninitialized)) };
        }
    }

    Ok(mapping.ptr)
}
378
/// Error returned when an operation references a device id that is not valid.
#[derive(Clone, Debug, Error)]
#[error("Device is invalid")]
pub struct InvalidDevice;
382
/// Errors produced at the device level.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DeviceError {
    #[error("Parent device is invalid.")]
    Invalid,
    #[error("Parent device is lost")]
    Lost,
    #[error("Not enough memory left.")]
    OutOfMemory,
    #[error("Creation of a resource failed for a reason other than running out of memory.")]
    ResourceCreationFailed,
    #[error("QueueId is invalid")]
    InvalidQueueId,
    #[error("Attempt to use a resource with a different device from the one that created it")]
    WrongDevice,
}
399
400impl From<hal::DeviceError> for DeviceError {
401 fn from(error: hal::DeviceError) -> Self {
402 match error {
403 hal::DeviceError::Lost => DeviceError::Lost,
404 hal::DeviceError::OutOfMemory => DeviceError::OutOfMemory,
405 hal::DeviceError::ResourceCreationFailed => DeviceError::ResourceCreationFailed,
406 }
407 }
408}
409
/// Error raised when an operation requires device features that were not
/// enabled at device creation.
#[derive(Clone, Debug, Error)]
#[error("Features {0:?} are required but not enabled on the device")]
pub struct MissingFeatures(pub wgt::Features);

/// Error raised when an operation requires downlevel capabilities the adapter
/// does not support.
#[derive(Clone, Debug, Error)]
#[error(
    "Downlevel flags {0:?} are required but not supported on the device.\n{}",
    DOWNLEVEL_ERROR_MESSAGE
)]
pub struct MissingDownlevelFlags(pub wgt::DownlevelFlags);
420
/// Concrete ids reserved for an implicitly created pipeline layout and its
/// bind group layouts.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImplicitPipelineContext {
    pub root_id: PipelineLayoutId,
    pub group_ids: ArrayVec<BindGroupLayoutId, { hal::MAX_BIND_GROUPS }>,
}
427
/// Caller-supplied (possibly absent) ids to use for an implicitly created
/// pipeline layout and its bind group layouts.
pub struct ImplicitPipelineIds<'a> {
    pub root_id: Option<PipelineLayoutId>,
    pub group_ids: &'a [Option<BindGroupLayoutId>],
}
432
433impl ImplicitPipelineIds<'_> {
434 fn prepare<A: HalApi>(self, hub: &Hub<A>) -> ImplicitPipelineContext {
435 ImplicitPipelineContext {
436 root_id: hub.pipeline_layouts.prepare(self.root_id).into_id(),
437 group_ids: self
438 .group_ids
439 .iter()
440 .map(|id_in| hub.bind_group_layouts.prepare(*id_in).into_id())
441 .collect(),
442 }
443 }
444}