// wgpu_core/device/mod.rs

use crate::{
    binding_model,
    hal_api::HalApi,
    hub::Hub,
    id::{BindGroupLayoutId, PipelineLayoutId},
    resource::{Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation},
    snatch::SnatchGuard,
    Label, DOWNLEVEL_ERROR_MESSAGE,
};

use arrayvec::ArrayVec;
use hal::Device as _;
use smallvec::SmallVec;
use std::os::raw::c_char;
use thiserror::Error;
use wgt::{BufferAddress, DeviceLostReason, TextureFormat};

use std::{iter, num::NonZeroU32, ptr};

pub mod any_device;
pub(crate) mod bgl;
pub mod global;
mod life;
pub mod queue;
pub mod resource;
#[cfg(any(feature = "trace", feature = "replay"))]
pub mod trace;
pub use {life::WaitIdleError, resource::Device};

pub const SHADER_STAGE_COUNT: usize = hal::MAX_CONCURRENT_SHADER_STAGES;
// Should be large enough for the largest possible texture row. This
// value is enough for a 16k texture with float4 format.
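// (As a rough check: 16,384 texels per row * 16 bytes per float4 texel
// is 256 KiB, so the 512 KiB below covers the widest row with headroom.)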
pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;

const CLEANUP_WAIT_MS: u32 = 5000;

const IMPLICIT_BIND_GROUP_LAYOUT_ERROR_LABEL: &str = "Implicit BindGroupLayout in the Error State";
const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";

pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;

#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HostMap {
    Read,
    Write,
}

#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct AttachmentData<T> {
    pub colors: ArrayVec<Option<T>, { hal::MAX_COLOR_ATTACHMENTS }>,
    pub resolves: ArrayVec<T, { hal::MAX_COLOR_ATTACHMENTS }>,
    pub depth_stencil: Option<T>,
}
impl<T: PartialEq> Eq for AttachmentData<T> {}
impl<T> AttachmentData<T> {
    pub(crate) fn map<U, F: Fn(&T) -> U>(&self, fun: F) -> AttachmentData<U> {
        AttachmentData {
            colors: self.colors.iter().map(|c| c.as_ref().map(&fun)).collect(),
            resolves: self.resolves.iter().map(&fun).collect(),
            depth_stencil: self.depth_stencil.as_ref().map(&fun),
        }
    }
}

#[derive(Debug, Copy, Clone)]
pub enum RenderPassCompatibilityCheckType {
    RenderPipeline,
    RenderBundle,
}

#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct RenderPassContext {
    pub attachments: AttachmentData<TextureFormat>,
    pub sample_count: u32,
    pub multiview: Option<NonZeroU32>,
}
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum RenderPassCompatibilityError {
    #[error(
        "Incompatible color attachments at indices {indices:?}: the RenderPass uses textures with formats {expected:?} but the {ty:?} uses attachments with formats {actual:?}",
    )]
    IncompatibleColorAttachment {
        indices: Vec<usize>,
        expected: Vec<Option<TextureFormat>>,
        actual: Vec<Option<TextureFormat>>,
        ty: RenderPassCompatibilityCheckType,
    },
    #[error(
        "Incompatible depth-stencil attachment format: the RenderPass uses a texture with format {expected:?} but the {ty:?} uses an attachment with format {actual:?}",
    )]
    IncompatibleDepthStencilAttachment {
        expected: Option<TextureFormat>,
        actual: Option<TextureFormat>,
        ty: RenderPassCompatibilityCheckType,
    },
    #[error(
        "Incompatible sample count: the RenderPass uses textures with sample count {expected:?} but the {ty:?} uses attachments with sample count {actual:?}",
    )]
    IncompatibleSampleCount {
        expected: u32,
        actual: u32,
        ty: RenderPassCompatibilityCheckType,
    },
    #[error("Incompatible multiview setting: the RenderPass uses setting {expected:?} but the {ty:?} uses setting {actual:?}")]
    IncompatibleMultiview {
        expected: Option<NonZeroU32>,
        actual: Option<NonZeroU32>,
        ty: RenderPassCompatibilityCheckType,
    },
}

impl RenderPassContext {
    // Assumes the renderpass only contains one subpass
    pub(crate) fn check_compatible(
        &self,
        other: &Self,
        ty: RenderPassCompatibilityCheckType,
    ) -> Result<(), RenderPassCompatibilityError> {
        if self.attachments.colors != other.attachments.colors {
            let indices = self
                .attachments
                .colors
                .iter()
                .zip(&other.attachments.colors)
                .enumerate()
                .filter_map(|(idx, (left, right))| (left != right).then_some(idx))
                .collect();
            return Err(RenderPassCompatibilityError::IncompatibleColorAttachment {
                indices,
                expected: self.attachments.colors.iter().cloned().collect(),
                actual: other.attachments.colors.iter().cloned().collect(),
                ty,
            });
        }
        if self.attachments.depth_stencil != other.attachments.depth_stencil {
            return Err(
                RenderPassCompatibilityError::IncompatibleDepthStencilAttachment {
                    expected: self.attachments.depth_stencil,
                    actual: other.attachments.depth_stencil,
                    ty,
                },
            );
        }
        if self.sample_count != other.sample_count {
            return Err(RenderPassCompatibilityError::IncompatibleSampleCount {
                expected: self.sample_count,
                actual: other.sample_count,
                ty,
            });
        }
        if self.multiview != other.multiview {
            return Err(RenderPassCompatibilityError::IncompatibleMultiview {
                expected: self.multiview,
                actual: other.multiview,
                ty,
            });
        }
        Ok(())
    }
}

pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);

#[derive(Default)]
pub struct UserClosures {
    pub mappings: Vec<BufferMapPendingClosure>,
    pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
    pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
}

impl UserClosures {
    fn extend(&mut self, other: Self) {
        self.mappings.extend(other.mappings);
        self.submissions.extend(other.submissions);
        self.device_lost_invocations
            .extend(other.device_lost_invocations);
    }

    fn fire(self) {
        // Note: this logic is specifically moved out of `handle_mapping()` in order to
        // have nothing locked by the time we execute the user's callback code.

        // Mappings _must_ be fired before submissions, as the spec requires all mapping
        // callbacks registered before an on_submitted_work_done callback to be fired
        // before that on_submitted_work_done callback.
        for (mut operation, status) in self.mappings {
            if let Some(callback) = operation.callback.take() {
                callback.call(status);
            }
        }
        for closure in self.submissions {
            closure.call();
        }
        for invocation in self.device_lost_invocations {
            invocation
                .closure
                .call(invocation.reason, invocation.message);
        }
    }
}

#[cfg(send_sync)]
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + Send + 'static>;
#[cfg(not(send_sync))]
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + 'static>;

pub struct DeviceLostClosureRust {
    pub callback: DeviceLostCallback,
    consumed: bool,
}

impl Drop for DeviceLostClosureRust {
    fn drop(&mut self) {
        if !self.consumed {
            panic!("DeviceLostClosureRust must be consumed before it is dropped.");
        }
    }
}

#[repr(C)]
pub struct DeviceLostClosureC {
    pub callback: unsafe extern "C" fn(user_data: *mut u8, reason: u8, message: *const c_char),
    pub user_data: *mut u8,
    consumed: bool,
}

#[cfg(send_sync)]
unsafe impl Send for DeviceLostClosureC {}

impl Drop for DeviceLostClosureC {
    fn drop(&mut self) {
        if !self.consumed {
            panic!("DeviceLostClosureC must be consumed before it is dropped.");
        }
    }
}

pub struct DeviceLostClosure {
    // We wrap this so that constructing the C variant of the inner enum can be
    // unsafe, which lets our `call` function stay safe.
    inner: DeviceLostClosureInner,
}

pub struct DeviceLostInvocation {
    closure: DeviceLostClosure,
    reason: DeviceLostReason,
    message: String,
}

enum DeviceLostClosureInner {
    Rust { inner: DeviceLostClosureRust },
    C { inner: DeviceLostClosureC },
}

impl DeviceLostClosure {
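    /// Wraps a plain Rust callback in a `DeviceLostClosure`.
    ///
    /// A minimal usage sketch (the callback body below is illustrative only,
    /// not part of wgpu-core):
    ///
    /// ```ignore
    /// let closure = DeviceLostClosure::from_rust(Box::new(|reason, message| {
    ///     eprintln!("device lost ({reason:?}): {message}");
    /// }));
    /// ```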
    pub fn from_rust(callback: DeviceLostCallback) -> Self {
        let inner = DeviceLostClosureRust {
            callback,
            consumed: false,
        };
        Self {
            inner: DeviceLostClosureInner::Rust { inner },
        }
    }

    /// # Safety
    ///
    /// - The callback pointer must be valid to call with the provided `user_data`
    ///   pointer.
    ///
    /// - Both pointers must point to `'static` data, as the callback may happen at
    ///   an unspecified time.
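    ///
    /// A sketch of a matching C-side callback (the function name and body are
    /// hypothetical; only its signature is dictated by `DeviceLostClosureC`):
    ///
    /// ```ignore
    /// unsafe extern "C" fn on_device_lost(user_data: *mut u8, reason: u8, message: *const c_char) {
    ///     // `message` is a null-terminated C string that is only guaranteed to
    ///     // live for the duration of this call; `user_data` must stay valid
    ///     // ('static) until the callback fires.
    ///     let _ = (user_data, reason, message);
    /// }
    /// ```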
    pub unsafe fn from_c(mut closure: DeviceLostClosureC) -> Self {
        // Build an inner with the values from closure, ensuring that
        // inner.consumed is false.
        let inner = DeviceLostClosureC {
            callback: closure.callback,
            user_data: closure.user_data,
            consumed: false,
        };

        // Mark the original closure as consumed, so we can safely drop it.
        closure.consumed = true;

        Self {
            inner: DeviceLostClosureInner::C { inner },
        }
    }

    pub(crate) fn call(self, reason: DeviceLostReason, message: String) {
        match self.inner {
            DeviceLostClosureInner::Rust { mut inner } => {
                inner.consumed = true;

                (inner.callback)(reason, message)
            }
            // SAFETY: the contract of the call to from_c says that this unsafe is sound.
            DeviceLostClosureInner::C { mut inner } => unsafe {
                inner.consumed = true;

                // Ensure message is structured as a null-terminated C string. It only
                // needs to live as long as the callback invocation.
                let message = std::ffi::CString::new(message).unwrap();
                (inner.callback)(inner.user_data, reason as u8, message.as_ptr())
            },
        }
    }
}

fn map_buffer<A: HalApi>(
    raw: &A::Device,
    buffer: &Buffer<A>,
    offset: BufferAddress,
    size: BufferAddress,
    kind: HostMap,
    snatch_guard: &SnatchGuard,
) -> Result<ptr::NonNull<u8>, BufferAccessError> {
    let raw_buffer = buffer
        .raw(snatch_guard)
        .ok_or(BufferAccessError::Destroyed)?;
    let mapping = unsafe {
        raw.map_buffer(raw_buffer, offset..offset + size)
            .map_err(DeviceError::from)?
    };

    *buffer.sync_mapped_writes.lock() = match kind {
        HostMap::Read if !mapping.is_coherent => unsafe {
            raw.invalidate_mapped_ranges(raw_buffer, iter::once(offset..offset + size));
            None
        },
        HostMap::Write if !mapping.is_coherent => Some(offset..offset + size),
        _ => None,
    };

    assert_eq!(offset % wgt::COPY_BUFFER_ALIGNMENT, 0);
    assert_eq!(size % wgt::COPY_BUFFER_ALIGNMENT, 0);
    // Zero out uninitialized parts of the mapping. (The spec dictates that all
    // resources behave as if they were initialized with zero.)
    //
    // If this is a read mapping, ideally we would use a `clear_buffer` command
    // before reading the data from the GPU (i.e. `invalidate_range`). However, this
    // would require us to kick off and wait for a command buffer, or piggy back
    // on an existing one (the latter is likely the only worthwhile option). As
    // reading uninitialized memory isn't a particularly important path to
    // support, we instead just initialize the memory here and make sure it is
    // GPU visible, so this happens at most once for every buffer region.
    //
    // If this is a write mapping, zeroing out the memory here is the only
    // reasonable approach, as all data is pushed to the GPU anyway.

    // No need to flush if it is flushed later anyway.
    let zero_init_needs_flush_now =
        mapping.is_coherent && buffer.sync_mapped_writes.lock().is_none();
    let mapped = unsafe { std::slice::from_raw_parts_mut(mapping.ptr.as_ptr(), size as usize) };

    for uninitialized in buffer
        .initialization_status
        .write()
        .drain(offset..(size + offset))
    {
        // The mapping's pointer is already offset, however we track the
        // uninitialized range relative to the buffer's start.
        let fill_range =
            (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
        mapped[fill_range].fill(0);

        if zero_init_needs_flush_now {
            unsafe { raw.flush_mapped_ranges(raw_buffer, iter::once(uninitialized)) };
        }
    }

    Ok(mapping.ptr)
}

#[derive(Clone, Debug, Error)]
#[error("Device is invalid")]
pub struct InvalidDevice;

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum DeviceError {
    #[error("Parent device is invalid.")]
    Invalid,
    #[error("Parent device is lost")]
    Lost,
    #[error("Not enough memory left.")]
    OutOfMemory,
    #[error("Creation of a resource failed for a reason other than running out of memory.")]
    ResourceCreationFailed,
    #[error("QueueId is invalid")]
    InvalidQueueId,
    #[error("Attempt to use a resource with a different device from the one that created it")]
    WrongDevice,
}

impl From<hal::DeviceError> for DeviceError {
    fn from(error: hal::DeviceError) -> Self {
        match error {
            hal::DeviceError::Lost => DeviceError::Lost,
            hal::DeviceError::OutOfMemory => DeviceError::OutOfMemory,
            hal::DeviceError::ResourceCreationFailed => DeviceError::ResourceCreationFailed,
        }
    }
}

#[derive(Clone, Debug, Error)]
#[error("Features {0:?} are required but not enabled on the device")]
pub struct MissingFeatures(pub wgt::Features);

#[derive(Clone, Debug, Error)]
#[error(
    "Downlevel flags {0:?} are required but not supported on the device.\n{}",
    DOWNLEVEL_ERROR_MESSAGE
)]
pub struct MissingDownlevelFlags(pub wgt::DownlevelFlags);

#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImplicitPipelineContext {
    pub root_id: PipelineLayoutId,
    pub group_ids: ArrayVec<BindGroupLayoutId, { hal::MAX_BIND_GROUPS }>,
}

pub struct ImplicitPipelineIds<'a> {
    pub root_id: Option<PipelineLayoutId>,
    pub group_ids: &'a [Option<BindGroupLayoutId>],
}

impl ImplicitPipelineIds<'_> {
    fn prepare<A: HalApi>(self, hub: &Hub<A>) -> ImplicitPipelineContext {
        ImplicitPipelineContext {
            root_id: hub.pipeline_layouts.prepare(self.root_id).into_id(),
            group_ids: self
                .group_ids
                .iter()
                .map(|id_in| hub.bind_group_layouts.prepare(*id_in).into_id())
                .collect(),
        }
    }
}