wgpu_core/command/compute.rs

use crate::command::compute_command::{ArcComputeCommand, ComputeCommand};
use crate::device::DeviceError;
use crate::resource::Resource;
use crate::snatch::SnatchGuard;
use crate::track::TrackerIndex;
use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError,
        CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange,
    },
    device::{MissingDownlevelFlags, MissingFeatures},
    error::{ErrorFormatter, PrettyError},
    global::Global,
    hal_api::HalApi,
    hal_label, id,
    id::DeviceId,
    init_tracker::MemoryInitKind,
    resource::{self},
    storage::Storage,
    track::{Tracker, UsageConflict, UsageScope},
    validation::{check_buffer_usage, MissingBufferUsageError},
    Label,
};

use hal::CommandEncoder as _;
#[cfg(feature = "serde")]
use serde::Deserialize;
#[cfg(feature = "serde")]
use serde::Serialize;

use thiserror::Error;

use std::sync::Arc;
use std::{fmt, mem, str};

#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,
    timestamp_writes: Option<ComputePassTimestampWrites>,

    // Resource binding dedupe state.
    #[cfg_attr(feature = "serde", serde(skip))]
    current_bind_groups: BindGroupStateChange,
    #[cfg_attr(feature = "serde", serde(skip))]
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
        Self {
            base: BasePass::new(&desc.label),
            parent_id,
            timestamp_writes: desc.timestamp_writes.cloned(),

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    pub fn parent_id(&self) -> id::CommandEncoderId {
        self.parent_id
    }

    #[cfg(feature = "trace")]
    pub fn into_command(self) -> crate::device::trace::Command {
        crate::device::trace::Command::RunComputePass {
            base: self.base,
            timestamp_writes: self.timestamp_writes,
        }
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
            self.parent_id,
            self.base.commands.len(),
            self.base.dynamic_offsets.len()
        )
    }
}

/// Describes the writing of timestamp values in a compute pass.
#[repr(C)]
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ComputePassTimestampWrites {
    /// The query set to write the timestamps to.
    pub query_set: id::QuerySetId,
    /// The index of the query set at which a start timestamp of this pass is written, if any.
    pub beginning_of_pass_write_index: Option<u32>,
    /// The index of the query set at which an end timestamp of this pass is written, if any.
    pub end_of_pass_write_index: Option<u32>,
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<&'a ComputePassTimestampWrites>,
}
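
// Illustrative sketch only (not part of the API surface): building a descriptor
// that requests a start and an end timestamp. `query_set_id` is a placeholder
// for an `id::QuerySetId` obtained from query set creation.
//
//     let timestamp_writes = ComputePassTimestampWrites {
//         query_set: query_set_id,
//         beginning_of_pass_write_index: Some(0),
//         end_of_pass_write_index: Some(1),
//     };
//     let desc = ComputePassDescriptor {
//         label: Some(std::borrow::Cow::Borrowed("example pass")),
//         timestamp_writes: Some(&timestamp_writes),
//     };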

#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error("Incompatible bind group at index {index} in the current compute pipeline")]
    IncompatibleBindGroup { index: u32, diff: Vec<String> },
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Bind group at index {0:?} is invalid")]
    InvalidBindGroup(u32),
    #[error("Device {0:?} is invalid")]
    InvalidDevice(DeviceId),
    #[error("Bind group index {index} is greater than the device's requested `max_bind_groups` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error("Compute pipeline {0:?} is invalid")]
    InvalidPipeline(id::ComputePipelineId),
    #[error("QuerySet {0:?} is invalid")]
    InvalidQuerySet(id::QuerySetId),
    #[error("Indirect buffer {0:?} is invalid or destroyed")]
    InvalidIndirectBuffer(id::BufferId),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error("Buffer {0:?} is invalid or destroyed")]
    InvalidBuffer(id::BufferId),
    #[error(transparent)]
    ResourceUsageConflict(#[from] UsageConflict),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group, because the number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
}

impl PrettyError for ComputePassErrorInner {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        fmt.error(self);
        match *self {
            Self::InvalidPipeline(id) => {
                fmt.compute_pipeline_label(&id);
            }
            Self::InvalidIndirectBuffer(id) => {
                fmt.buffer_label(&id);
            }
            Self::Dispatch(DispatchError::IncompatibleBindGroup { ref diff, .. }) => {
                for d in diff {
                    fmt.note(&d);
                }
            }
            _ => {}
        };
    }
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
impl PrettyError for ComputePassError {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        // This error is a wrapper for the inner error,
        // but the scope has useful labels.
        fmt.error(self);
        self.scope.fmt_pretty(fmt);
    }
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

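/// Transient state kept while recording a single compute pass: the bind-group
/// binder, the currently set pipeline, the usage scope used to merge resource
/// state, and the current debug-group nesting depth.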
struct State<'a, A: HalApi> {
    binder: Binder<A>,
    pipeline: Option<id::ComputePipelineId>,
    scope: UsageScope<'a, A>,
    debug_scope_depth: u32,
}

impl<'a, A: HalApi> State<'a, A> {
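    /// Checks that a dispatch can be issued: a pipeline must be set, every
    /// bind group slot it requires must hold a compatible bind group, and
    /// late-bound buffer bindings must meet their minimum sizes.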
    fn is_ready(&self) -> Result<(), DispatchError> {
        let bind_mask = self.binder.invalid_mask();
        if bind_mask != 0 {
            //let (expected, provided) = self.binder.entries[index as usize].info();
            let index = bind_mask.trailing_zeros();

            return Err(DispatchError::IncompatibleBindGroup {
                index,
                diff: self.binder.bgl_diff(),
            });
        }
        if self.pipeline.is_none() {
            return Err(DispatchError::MissingPipeline);
        }
        self.binder.check_late_buffer_bindings()?;

        Ok(())
    }

    // `indirect_buffer` is the indirect buffer (if any) that is also part of
    // the usage scope.
    fn flush_states(
        &mut self,
        raw_encoder: &mut A::CommandEncoder,
        base_trackers: &mut Tracker<A>,
        bind_group_guard: &Storage<BindGroup<A>>,
        indirect_buffer: Option<TrackerIndex>,
        snatch_guard: &SnatchGuard,
    ) -> Result<(), UsageConflict> {
        for id in self.binder.list_active() {
            unsafe { self.scope.merge_bind_group(&bind_group_guard[id].used)? };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for id in self.binder.list_active() {
            unsafe {
                base_trackers.set_and_remove_from_usage_scope_sparse(
                    &mut self.scope,
                    &bind_group_guard[id].used,
                )
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            base_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        log::trace!("Encoding dispatch barriers");

        CommandBuffer::drain_barriers(raw_encoder, base_trackers, snatch_guard);
        Ok(())
    }
}

// Common routines between render/compute

impl Global {
    pub fn command_encoder_run_compute_pass<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        pass: &ComputePass,
    ) -> Result<(), ComputePassError> {
        // TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally.
        self.command_encoder_run_compute_pass_with_unresolved_commands::<A>(
            encoder_id,
            pass.base.as_ref(),
            pass.timestamp_writes.as_ref(),
        )
    }

    #[doc(hidden)]
    pub fn command_encoder_run_compute_pass_with_unresolved_commands<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ComputeCommand>,
        timestamp_writes: Option<&ComputePassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        let resolved_commands =
            ComputeCommand::resolve_compute_command_ids(A::hub(self), base.commands)?;

        self.command_encoder_run_compute_pass_impl::<A>(
            encoder_id,
            BasePassRef {
                label: base.label,
                commands: &resolved_commands,
                dynamic_offsets: base.dynamic_offsets,
                string_data: base.string_data,
                push_constant_data: base.push_constant_data,
            },
            timestamp_writes,
        )
    }

    fn command_encoder_run_compute_pass_impl<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ArcComputeCommand<A>>,
        timestamp_writes: Option<&ComputePassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let pass_scope = PassErrorScope::Pass(encoder_id);

        let hub = A::hub(self);

        let cmd_buf: Arc<CommandBuffer<A>> =
            CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
        let device = &cmd_buf.device;
        if !device.is_valid() {
            return Err(ComputePassErrorInner::InvalidDevice(
                cmd_buf.device.as_info().id(),
            ))
            .map_pass_err(pass_scope);
        }

        let mut cmd_buf_data = cmd_buf.data.lock();
        let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf_data.commands {
            list.push(crate::device::trace::Command::RunComputePass {
                base: BasePass {
                    label: base.label.map(str::to_string),
                    commands: base.commands.iter().map(Into::into).collect(),
                    dynamic_offsets: base.dynamic_offsets.to_vec(),
                    string_data: base.string_data.to_vec(),
                    push_constant_data: base.push_constant_data.to_vec(),
                },
                timestamp_writes: timestamp_writes.cloned(),
            });
        }

        let encoder = &mut cmd_buf_data.encoder;
        let status = &mut cmd_buf_data.status;
        let tracker = &mut cmd_buf_data.trackers;
        let buffer_memory_init_actions = &mut cmd_buf_data.buffer_memory_init_actions;
        let texture_memory_actions = &mut cmd_buf_data.texture_memory_actions;

        // We automatically keep extending command buffers over time, and because
        // we want to insert a command buffer _before_ what we're about to record,
        // we need to make sure to close the previous one.
        encoder.close().map_pass_err(pass_scope)?;
        // The status is set back to `Recording` if the pass is recorded without errors.
        *status = CommandEncoderStatus::Error;
        let raw = encoder.open().map_pass_err(pass_scope)?;

        let bind_group_guard = hub.bind_groups.read();
        let query_set_guard = hub.query_sets.read();

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: device.new_usage_scope(),
            debug_scope_depth: 0,
        };
        let mut temp_offsets = Vec::new();
        let mut dynamic_offset_count = 0;
        let mut string_offset = 0;
        let mut active_query = None;

        let timestamp_writes = if let Some(tw) = timestamp_writes {
            let query_set: &resource::QuerySet<A> = tracker
                .query_sets
                .add_single(&*query_set_guard, tw.query_set)
                .ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
                .map_pass_err(pass_scope)?;

            // Unlike in render passes we can't delay resetting the query sets since
            // there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            // Range should always be Some, both values being None should lead to a validation error.
            // But no point in erroring over that nuance here!
            if let Some(range) = range {
                unsafe {
                    raw.reset_queries(query_set.raw.as_ref().unwrap(), range);
                }
            }

            Some(hal::ComputePassTimestampWrites {
                query_set: query_set.raw.as_ref().unwrap(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

        let snatch_guard = device.snatchable_lock.read();

        let indices = &device.tracker_indices;
        tracker.buffers.set_size(indices.buffers.size());
        tracker.textures.set_size(indices.textures.size());
        tracker.bind_groups.set_size(indices.bind_groups.size());
        tracker
            .compute_pipelines
            .set_size(indices.compute_pipelines.size());
        tracker.query_sets.set_size(indices.query_sets.size());

        let discard_hal_labels = self
            .instance
            .flags
            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS);
        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label, self.instance.flags),
            timestamp_writes,
        };

        unsafe {
            raw.begin_compute_pass(&hal_desc);
        }

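        // Tracks the resource state transitions made by this pass. Barriers derived
        // from it are issued between dispatches (via `State::flush_states`) and, at
        // the end of recording, folded into the transition command buffer inserted
        // before the pass.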
        let mut intermediate_trackers = Tracker::<A>::new();

        // Immediate texture inits required because of prior discards. Need to
        // be inserted before texture reads.
        let mut pending_discard_init_fixups = SurfacesInDiscardState::new();

        // TODO: We should be draining the commands here, avoiding extra copies in the process.
        //       (A command encoder can't be executed twice!)
        for command in base.commands {
            match command {
                ArcComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group,
                } => {
                    let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());

                    let max_bind_groups = cmd_buf.limits.max_bind_groups;
                    if index >= &max_bind_groups {
                        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
                            index: *index,
                            max: max_bind_groups,
                        })
                        .map_pass_err(scope);
                    }

                    temp_offsets.clear();
                    temp_offsets.extend_from_slice(
                        &base.dynamic_offsets
                            [dynamic_offset_count..dynamic_offset_count + num_dynamic_offsets],
                    );
                    dynamic_offset_count += num_dynamic_offsets;

                    let bind_group = tracker.bind_groups.insert_single(bind_group.clone());
                    bind_group
                        .validate_dynamic_bindings(*index, &temp_offsets, &cmd_buf.limits)
                        .map_pass_err(scope)?;

                    buffer_memory_init_actions.extend(
                        bind_group.used_buffer_ranges.iter().filter_map(|action| {
                            action
                                .buffer
                                .initialization_status
                                .read()
                                .check_action(action)
                        }),
                    );

                    for action in bind_group.used_texture_ranges.iter() {
                        pending_discard_init_fixups
                            .extend(texture_memory_actions.register_init_action(action));
                    }

                    let pipeline_layout = state.binder.pipeline_layout.clone();
                    let entries =
                        state
                            .binder
                            .assign_group(*index as usize, bind_group, &temp_offsets);
                    if !entries.is_empty() && pipeline_layout.is_some() {
                        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
                        for (i, e) in entries.iter().enumerate() {
                            if let Some(group) = e.group.as_ref() {
                                let raw_bg = group
                                    .raw(&snatch_guard)
                                    .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
                                    .map_pass_err(scope)?;
                                unsafe {
                                    raw.set_bind_group(
                                        pipeline_layout,
                                        index + i as u32,
                                        raw_bg,
                                        &e.dynamic_offsets,
                                    );
                                }
                            }
                        }
                    }
                }
                ArcComputeCommand::SetPipeline(pipeline) => {
                    let pipeline_id = pipeline.as_info().id();
                    let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

                    state.pipeline = Some(pipeline_id);

                    tracker.compute_pipelines.insert_single(pipeline.clone());

                    unsafe {
                        raw.set_compute_pipeline(pipeline.raw());
                    }

                    // Rebind resources
                    if state.binder.pipeline_layout.is_none()
                        || !state
                            .binder
                            .pipeline_layout
                            .as_ref()
                            .unwrap()
                            .is_equal(&pipeline.layout)
                    {
                        let (start_index, entries) = state.binder.change_pipeline_layout(
                            &pipeline.layout,
                            &pipeline.late_sized_buffer_groups,
                        );
                        if !entries.is_empty() {
                            for (i, e) in entries.iter().enumerate() {
                                if let Some(group) = e.group.as_ref() {
                                    let raw_bg = group
                                        .raw(&snatch_guard)
                                        .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
                                        .map_pass_err(scope)?;
                                    unsafe {
                                        raw.set_bind_group(
                                            pipeline.layout.raw(),
                                            start_index as u32 + i as u32,
                                            raw_bg,
                                            &e.dynamic_offsets,
                                        );
                                    }
                                }
                            }
                        }

                        // Clear push constant ranges
                        let non_overlapping = super::bind::compute_nonoverlapping_ranges(
                            &pipeline.layout.push_constant_ranges,
                        );
                        for range in non_overlapping {
                            let offset = range.range.start;
                            let size_bytes = range.range.end - offset;
                            super::push_constant_clear(
                                offset,
                                size_bytes,
                                |clear_offset, clear_data| unsafe {
                                    raw.set_push_constants(
                                        pipeline.layout.raw(),
                                        wgt::ShaderStages::COMPUTE,
                                        clear_offset,
                                        clear_data,
                                    );
                                },
                            );
                        }
                    }
                }
                ArcComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;

                    let end_offset_bytes = offset + size_bytes;
                    let values_end_offset =
                        (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                    let data_slice =
                        &base.push_constant_data[(*values_offset as usize)..values_end_offset];

                    let pipeline_layout = state
                        .binder
                        .pipeline_layout
                        .as_ref()
                        //TODO: don't error here, lazily update the push constants
                        .ok_or(ComputePassErrorInner::Dispatch(
                            DispatchError::MissingPipeline,
                        ))
                        .map_pass_err(scope)?;

                    pipeline_layout
                        .validate_push_constant_ranges(
                            wgt::ShaderStages::COMPUTE,
                            *offset,
                            end_offset_bytes,
                        )
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_push_constants(
                            pipeline_layout.raw(),
                            wgt::ShaderStages::COMPUTE,
                            *offset,
                            data_slice,
                        );
                    }
                }
                ArcComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: false,
                        pipeline: state.pipeline,
                    };
                    state.is_ready().map_pass_err(scope)?;

                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            None,
                            &snatch_guard,
                        )
                        .map_pass_err(scope)?;

                    let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;

                    if groups[0] > groups_size_limit
                        || groups[1] > groups_size_limit
                        || groups[2] > groups_size_limit
                    {
                        return Err(ComputePassErrorInner::Dispatch(
                            DispatchError::InvalidGroupSize {
                                current: *groups,
                                limit: groups_size_limit,
                            },
                        ))
                        .map_pass_err(scope);
                    }

                    unsafe {
                        raw.dispatch(*groups);
                    }
                }
                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                    let buffer_id = buffer.as_info().id();
                    let scope = PassErrorScope::Dispatch {
                        indirect: true,
                        pipeline: state.pipeline,
                    };

                    state.is_ready().map_pass_err(scope)?;

                    device
                        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
                        .map_pass_err(scope)?;

                    state
                        .scope
                        .buffers
                        .insert_merge_single(buffer.clone(), hal::BufferUses::INDIRECT)
                        .map_pass_err(scope)?;
                    check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT)
                        .map_pass_err(scope)?;

                    let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
                    if end_offset > buffer.size {
                        return Err(ComputePassErrorInner::IndirectBufferOverrun {
                            offset: *offset,
                            end_offset,
                            buffer_size: buffer.size,
                        })
                        .map_pass_err(scope);
                    }

                    let buf_raw = buffer
                        .raw
                        .get(&snatch_guard)
                        .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;

                    let stride = 3 * 4; // 3 integers, x/y/z group size

                    buffer_memory_init_actions.extend(
                        buffer.initialization_status.read().create_action(
                            buffer,
                            *offset..(*offset + stride),
                            MemoryInitKind::NeedsInitializedMemory,
                        ),
                    );

                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            Some(buffer.as_info().tracker_index()),
                            &snatch_guard,
                        )
                        .map_pass_err(scope)?;
                    unsafe {
                        raw.dispatch_indirect(buf_raw, *offset);
                    }
                }
                ArcComputeCommand::PushDebugGroup { color: _, len } => {
                    state.debug_scope_depth += 1;
                    if !discard_hal_labels {
                        let label =
                            str::from_utf8(&base.string_data[string_offset..string_offset + len])
                                .unwrap();
                        unsafe {
                            raw.begin_debug_marker(label);
                        }
                    }
                    string_offset += len;
                }
                ArcComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;

                    if state.debug_scope_depth == 0 {
                        return Err(ComputePassErrorInner::InvalidPopDebugGroup)
                            .map_pass_err(scope);
                    }
                    state.debug_scope_depth -= 1;
                    if !discard_hal_labels {
                        unsafe {
                            raw.end_debug_marker();
                        }
                    }
                }
                ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                    if !discard_hal_labels {
                        let label =
                            str::from_utf8(&base.string_data[string_offset..string_offset + len])
                                .unwrap();
                        unsafe { raw.insert_debug_marker(label) }
                    }
                    string_offset += len;
                }
                ArcComputeCommand::WriteTimestamp {
                    query_set,
                    query_index,
                } => {
                    let query_set_id = query_set.as_info().id();
                    let scope = PassErrorScope::WriteTimestamp;

                    device
                        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
                        .map_pass_err(scope)?;

                    let query_set = tracker.query_sets.insert_single(query_set.clone());

                    query_set
                        .validate_and_write_timestamp(raw, query_set_id, *query_index, None)
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::BeginPipelineStatisticsQuery {
                    query_set,
                    query_index,
                } => {
                    let query_set_id = query_set.as_info().id();
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;

                    let query_set = tracker.query_sets.insert_single(query_set.clone());

                    query_set
                        .validate_and_begin_pipeline_statistics_query(
                            raw,
                            query_set_id,
                            *query_index,
                            None,
                            &mut active_query,
                        )
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;

                    end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            raw.end_compute_pass();
        }

        // We've successfully recorded the compute pass, bring the
        // command buffer out of the error state.
        *status = CommandEncoderStatus::Recording;

        // Stop the current command buffer.
        encoder.close().map_pass_err(pass_scope)?;

        // Create a new command buffer, which we will insert _before_ the body of the compute pass.
        //
        // Use that buffer to insert barriers and clear discarded images.
        let transit = encoder.open().map_pass_err(pass_scope)?;
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &mut tracker.textures,
            device,
            &snatch_guard,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            tracker,
            &intermediate_trackers,
            &snatch_guard,
        );
        // Close the command buffer, and swap it with the previous.
        encoder.close_and_swap().map_pass_err(pass_scope)?;

        Ok(())
    }
}

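/// Free functions that append commands to a [`ComputePass`]'s command stream.
///
/// Illustrative sketch only (marked `ignore` because the ids are placeholders
/// for handles created through the usual `Global` entry points):
///
/// ```ignore
/// let mut pass = ComputePass::new(encoder_id, &ComputePassDescriptor::default());
/// compute_commands::wgpu_compute_pass_set_pipeline(&mut pass, pipeline_id);
/// compute_commands::wgpu_compute_pass_set_bind_group(&mut pass, 0, bind_group_id, &[]);
/// compute_commands::wgpu_compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1);
/// // The recorded pass is later executed with
/// // `Global::command_encoder_run_compute_pass(encoder_id, &pass)`.
/// ```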
pub mod compute_commands {
    use super::{ComputeCommand, ComputePass};
    use crate::id;
    use std::convert::TryInto;
    use wgt::{BufferAddress, DynamicOffset};

    pub fn wgpu_compute_pass_set_bind_group(
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: id::BindGroupId,
        offsets: &[DynamicOffset],
    ) {
        let redundant = pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut pass.base.dynamic_offsets,
            offsets,
        );

        if redundant {
            return;
        }

        pass.base.commands.push(ComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group_id,
        });
    }

    pub fn wgpu_compute_pass_set_pipeline(
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) {
        if pass.current_pipeline.set_and_check_redundant(pipeline_id) {
            return;
        }

        pass.base
            .commands
            .push(ComputeCommand::SetPipeline(pipeline_id));
    }

    pub fn wgpu_compute_pass_set_push_constant(pass: &mut ComputePass, offset: u32, data: &[u8]) {
        assert_eq!(
            offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant offset must be aligned to 4 bytes."
        );
        assert_eq!(
            data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant size must be aligned to 4 bytes."
        );
        let value_offset = pass.base.push_constant_data.len().try_into().expect(
            "Ran out of push constant space. Don't set 4gb of push constants per ComputePass.",
        );

        pass.base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        pass.base.commands.push(ComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });
    }

    pub fn wgpu_compute_pass_dispatch_workgroups(
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
    }

    pub fn wgpu_compute_pass_dispatch_workgroups_indirect(
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
    }

    pub fn wgpu_compute_pass_push_debug_group(pass: &mut ComputePass, label: &str, color: u32) {
        let bytes = label.as_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });
    }

    pub fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
        pass.base.commands.push(ComputeCommand::PopDebugGroup);
    }

    pub fn wgpu_compute_pass_insert_debug_marker(pass: &mut ComputePass, label: &str, color: u32) {
        let bytes = label.as_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });
    }

    pub fn wgpu_compute_pass_write_timestamp(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base.commands.push(ComputeCommand::WriteTimestamp {
            query_set_id,
            query_index,
        });
    }

    pub fn wgpu_compute_pass_begin_pipeline_statistics_query(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::BeginPipelineStatisticsQuery {
                query_set_id,
                query_index,
            });
    }

    pub fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
        pass.base
            .commands
            .push(ComputeCommand::EndPipelineStatisticsQuery);
    }
}