use crate::command::compute_command::{ArcComputeCommand, ComputeCommand};
use crate::device::DeviceError;
use crate::resource::Resource;
use crate::snatch::SnatchGuard;
use crate::track::TrackerIndex;
use crate::{
    binding_model::{
        BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
    },
    command::{
        bind::Binder,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError,
        CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange,
    },
    device::{MissingDownlevelFlags, MissingFeatures},
    error::{ErrorFormatter, PrettyError},
    global::Global,
    hal_api::HalApi,
    hal_label, id,
    id::DeviceId,
    init_tracker::MemoryInitKind,
    resource::{self},
    storage::Storage,
    track::{Tracker, UsageConflict, UsageScope},
    validation::{check_buffer_usage, MissingBufferUsageError},
    Label,
};

use hal::CommandEncoder as _;
#[cfg(feature = "serde")]
use serde::Deserialize;
#[cfg(feature = "serde")]
use serde::Serialize;

use thiserror::Error;

use std::sync::Arc;
use std::{fmt, mem, str};

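/// A recorded compute pass: its command list, the encoder it was opened on,
/// optional timestamp writes, and bookkeeping used to skip redundant bind
/// group and pipeline changes while recording.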
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,
    timestamp_writes: Option<ComputePassTimestampWrites>,

    #[cfg_attr(feature = "serde", serde(skip))]
    current_bind_groups: BindGroupStateChange,
    #[cfg_attr(feature = "serde", serde(skip))]
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
        Self {
            base: BasePass::new(&desc.label),
            parent_id,
            timestamp_writes: desc.timestamp_writes.cloned(),

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    pub fn parent_id(&self) -> id::CommandEncoderId {
        self.parent_id
    }

    #[cfg(feature = "trace")]
    pub fn into_command(self) -> crate::device::trace::Command {
        crate::device::trace::Command::RunComputePass {
            base: self.base,
            timestamp_writes: self.timestamp_writes,
        }
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
            self.parent_id,
            self.base.commands.len(),
            self.base.dynamic_offsets.len()
        )
    }
}

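/// Describes the writing of timestamp values in a compute pass.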
#[repr(C)]
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ComputePassTimestampWrites {
    pub query_set: id::QuerySetId,
    pub beginning_of_pass_write_index: Option<u32>,
    pub end_of_pass_write_index: Option<u32>,
}

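/// Describes a compute pass: an optional debug label and, optionally, where
/// timestamp values should be written.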
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
    pub timestamp_writes: Option<&'a ComputePassTimestampWrites>,
}

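/// Errors that can occur while validating a dispatch.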
#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error("Incompatible bind group at index {index} in the current compute pipeline")]
    IncompatibleBindGroup { index: u32, diff: Vec<String> },
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

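/// Error encountered while recording or replaying a compute pass.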
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Device(#[from] DeviceError),
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Bind group at index {0:?} is invalid")]
    InvalidBindGroup(u32),
    #[error("Device {0:?} is invalid")]
    InvalidDevice(DeviceId),
    #[error("Bind group index {index} is greater than the device's requested `max_bind_groups` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error("Compute pipeline {0:?} is invalid")]
    InvalidPipeline(id::ComputePipelineId),
    #[error("QuerySet {0:?} is invalid")]
    InvalidQuerySet(id::QuerySetId),
    #[error("Indirect buffer {0:?} is invalid or destroyed")]
    InvalidIndirectBuffer(id::BufferId),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error("Buffer {0:?} is invalid or destroyed")]
    InvalidBuffer(id::BufferId),
    #[error(transparent)]
    ResourceUsageConflict(#[from] UsageConflict),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group, because number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
}

impl PrettyError for ComputePassErrorInner {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        fmt.error(self);
        match *self {
            Self::InvalidPipeline(id) => {
                fmt.compute_pipeline_label(&id);
            }
            Self::InvalidIndirectBuffer(id) => {
                fmt.buffer_label(&id);
            }
            Self::Dispatch(DispatchError::IncompatibleBindGroup { ref diff, .. }) => {
                for d in diff {
                    fmt.note(&d);
                }
            }
            _ => {}
        };
    }
}

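/// A [`ComputePassErrorInner`] together with the scope in which it occurred.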
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    pub(super) inner: ComputePassErrorInner,
}
impl PrettyError for ComputePassError {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        fmt.error(self);
        self.scope.fmt_pretty(fmt);
    }
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

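/// Per-pass state tracked while replaying commands: the bound groups and
/// pipeline, the usage scope accumulated for the next dispatch, and the
/// debug group nesting depth.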
struct State<'a, A: HalApi> {
    binder: Binder<A>,
    pipeline: Option<id::ComputePipelineId>,
    scope: UsageScope<'a, A>,
    debug_scope_depth: u32,
}

impl<'a, A: HalApi> State<'a, A> {
    fn is_ready(&self) -> Result<(), DispatchError> {
        let bind_mask = self.binder.invalid_mask();
        if bind_mask != 0 {
            let index = bind_mask.trailing_zeros();

            return Err(DispatchError::IncompatibleBindGroup {
                index,
                diff: self.binder.bgl_diff(),
            });
        }
        if self.pipeline.is_none() {
            return Err(DispatchError::MissingPipeline);
        }
        self.binder.check_late_buffer_bindings()?;

        Ok(())
    }

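    /// Merge the resource usages of every currently bound group (plus an
    /// optional indirect buffer) into the pass usage scope, move them into the
    /// command-buffer trackers, and record the barriers needed before the next
    /// dispatch.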
    fn flush_states(
        &mut self,
        raw_encoder: &mut A::CommandEncoder,
        base_trackers: &mut Tracker<A>,
        bind_group_guard: &Storage<BindGroup<A>>,
        indirect_buffer: Option<TrackerIndex>,
        snatch_guard: &SnatchGuard,
    ) -> Result<(), UsageConflict> {
        for id in self.binder.list_active() {
            unsafe { self.scope.merge_bind_group(&bind_group_guard[id].used)? };
        }

        for id in self.binder.list_active() {
            unsafe {
                base_trackers.set_and_remove_from_usage_scope_sparse(
                    &mut self.scope,
                    &bind_group_guard[id].used,
                )
            }
        }

        unsafe {
            base_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        log::trace!("Encoding dispatch barriers");

        CommandBuffer::drain_barriers(raw_encoder, base_trackers, snatch_guard);
        Ok(())
    }
}

impl Global {
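    /// Replays the commands recorded in `pass` on the command encoder
    /// identified by `encoder_id`, resolving resource ids into `Arc`s first.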
    pub fn command_encoder_run_compute_pass<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        pass: &ComputePass,
    ) -> Result<(), ComputePassError> {
        self.command_encoder_run_compute_pass_with_unresolved_commands::<A>(
            encoder_id,
            pass.base.as_ref(),
            pass.timestamp_writes.as_ref(),
        )
    }

    #[doc(hidden)]
    pub fn command_encoder_run_compute_pass_with_unresolved_commands<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ComputeCommand>,
        timestamp_writes: Option<&ComputePassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        let resolved_commands =
            ComputeCommand::resolve_compute_command_ids(A::hub(self), base.commands)?;

        self.command_encoder_run_compute_pass_impl::<A>(
            encoder_id,
            BasePassRef {
                label: base.label,
                commands: &resolved_commands,
                dynamic_offsets: base.dynamic_offsets,
                string_data: base.string_data,
                push_constant_data: base.push_constant_data,
            },
            timestamp_writes,
        )
    }

    fn command_encoder_run_compute_pass_impl<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ArcComputeCommand<A>>,
        timestamp_writes: Option<&ComputePassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let pass_scope = PassErrorScope::Pass(encoder_id);

        let hub = A::hub(self);

        let cmd_buf: Arc<CommandBuffer<A>> =
            CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
        let device = &cmd_buf.device;
        if !device.is_valid() {
            return Err(ComputePassErrorInner::InvalidDevice(
                cmd_buf.device.as_info().id(),
            ))
            .map_pass_err(pass_scope);
        }

        let mut cmd_buf_data = cmd_buf.data.lock();
        let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf_data.commands {
            list.push(crate::device::trace::Command::RunComputePass {
                base: BasePass {
                    label: base.label.map(str::to_string),
                    commands: base.commands.iter().map(Into::into).collect(),
                    dynamic_offsets: base.dynamic_offsets.to_vec(),
                    string_data: base.string_data.to_vec(),
                    push_constant_data: base.push_constant_data.to_vec(),
                },
                timestamp_writes: timestamp_writes.cloned(),
            });
        }

        let encoder = &mut cmd_buf_data.encoder;
        let status = &mut cmd_buf_data.status;
        let tracker = &mut cmd_buf_data.trackers;
        let buffer_memory_init_actions = &mut cmd_buf_data.buffer_memory_init_actions;
        let texture_memory_actions = &mut cmd_buf_data.texture_memory_actions;

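        // Close whatever the encoder was recording so the pass gets a raw
        // encoder of its own; the status stays `Error` until recording
        // finishes successfully and is set back to `Recording` below.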
        encoder.close().map_pass_err(pass_scope)?;
        *status = CommandEncoderStatus::Error;
        let raw = encoder.open().map_pass_err(pass_scope)?;

        let bind_group_guard = hub.bind_groups.read();
        let query_set_guard = hub.query_sets.read();

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: device.new_usage_scope(),
            debug_scope_depth: 0,
        };
        let mut temp_offsets = Vec::new();
        let mut dynamic_offset_count = 0;
        let mut string_offset = 0;
        let mut active_query = None;

        let timestamp_writes = if let Some(tw) = timestamp_writes {
            let query_set: &resource::QuerySet<A> = tracker
                .query_sets
                .add_single(&*query_set_guard, tw.query_set)
                .ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
                .map_pass_err(pass_scope)?;

            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
            if let Some(range) = range {
                unsafe {
                    raw.reset_queries(query_set.raw.as_ref().unwrap(), range);
                }
            }

            Some(hal::ComputePassTimestampWrites {
                query_set: query_set.raw.as_ref().unwrap(),
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

        let snatch_guard = device.snatchable_lock.read();

        let indices = &device.tracker_indices;
        tracker.buffers.set_size(indices.buffers.size());
        tracker.textures.set_size(indices.textures.size());
        tracker.bind_groups.set_size(indices.bind_groups.size());
        tracker
            .compute_pipelines
            .set_size(indices.compute_pipelines.size());
        tracker.query_sets.set_size(indices.query_sets.size());

        let discard_hal_labels = self
            .instance
            .flags
            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS);
        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label, self.instance.flags),
            timestamp_writes,
        };

        unsafe {
            raw.begin_compute_pass(&hal_desc);
        }

        let mut intermediate_trackers = Tracker::<A>::new();

        let mut pending_discard_init_fixups = SurfacesInDiscardState::new();

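        // Replay the recorded commands, validating each one against the
        // current pass state and translating it into raw HAL calls.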
        for command in base.commands {
            match command {
                ArcComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group,
                } => {
                    let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());

                    let max_bind_groups = cmd_buf.limits.max_bind_groups;
                    if index >= &max_bind_groups {
                        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
                            index: *index,
                            max: max_bind_groups,
                        })
                        .map_pass_err(scope);
                    }

                    temp_offsets.clear();
                    temp_offsets.extend_from_slice(
                        &base.dynamic_offsets
                            [dynamic_offset_count..dynamic_offset_count + num_dynamic_offsets],
                    );
                    dynamic_offset_count += num_dynamic_offsets;

                    let bind_group = tracker.bind_groups.insert_single(bind_group.clone());
                    bind_group
                        .validate_dynamic_bindings(*index, &temp_offsets, &cmd_buf.limits)
                        .map_pass_err(scope)?;

                    buffer_memory_init_actions.extend(
                        bind_group.used_buffer_ranges.iter().filter_map(|action| {
                            action
                                .buffer
                                .initialization_status
                                .read()
                                .check_action(action)
                        }),
                    );

                    for action in bind_group.used_texture_ranges.iter() {
                        pending_discard_init_fixups
                            .extend(texture_memory_actions.register_init_action(action));
                    }

                    let pipeline_layout = state.binder.pipeline_layout.clone();
                    let entries =
                        state
                            .binder
                            .assign_group(*index as usize, bind_group, &temp_offsets);
                    if !entries.is_empty() && pipeline_layout.is_some() {
                        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
                        for (i, e) in entries.iter().enumerate() {
                            if let Some(group) = e.group.as_ref() {
                                let raw_bg = group
                                    .raw(&snatch_guard)
                                    .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
                                    .map_pass_err(scope)?;
                                unsafe {
                                    raw.set_bind_group(
                                        pipeline_layout,
                                        index + i as u32,
                                        raw_bg,
                                        &e.dynamic_offsets,
                                    );
                                }
                            }
                        }
                    }
                }
                ArcComputeCommand::SetPipeline(pipeline) => {
                    let pipeline_id = pipeline.as_info().id();
                    let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

                    state.pipeline = Some(pipeline_id);

                    tracker.compute_pipelines.insert_single(pipeline.clone());

                    unsafe {
                        raw.set_compute_pipeline(pipeline.raw());
                    }

                    if state.binder.pipeline_layout.is_none()
                        || !state
                            .binder
                            .pipeline_layout
                            .as_ref()
                            .unwrap()
                            .is_equal(&pipeline.layout)
                    {
                        let (start_index, entries) = state.binder.change_pipeline_layout(
                            &pipeline.layout,
                            &pipeline.late_sized_buffer_groups,
                        );
                        if !entries.is_empty() {
                            for (i, e) in entries.iter().enumerate() {
                                if let Some(group) = e.group.as_ref() {
                                    let raw_bg = group
                                        .raw(&snatch_guard)
                                        .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
                                        .map_pass_err(scope)?;
                                    unsafe {
                                        raw.set_bind_group(
                                            pipeline.layout.raw(),
                                            start_index as u32 + i as u32,
                                            raw_bg,
                                            &e.dynamic_offsets,
                                        );
                                    }
                                }
                            }
                        }

                        let non_overlapping = super::bind::compute_nonoverlapping_ranges(
                            &pipeline.layout.push_constant_ranges,
                        );
                        for range in non_overlapping {
                            let offset = range.range.start;
                            let size_bytes = range.range.end - offset;
                            super::push_constant_clear(
                                offset,
                                size_bytes,
                                |clear_offset, clear_data| unsafe {
                                    raw.set_push_constants(
                                        pipeline.layout.raw(),
                                        wgt::ShaderStages::COMPUTE,
                                        clear_offset,
                                        clear_data,
                                    );
                                },
                            );
                        }
                    }
                }
                ArcComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;

                    let end_offset_bytes = offset + size_bytes;
                    let values_end_offset =
                        (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                    let data_slice =
                        &base.push_constant_data[(*values_offset as usize)..values_end_offset];

                    let pipeline_layout = state
                        .binder
                        .pipeline_layout
                        .as_ref()
                        .ok_or(ComputePassErrorInner::Dispatch(
                            DispatchError::MissingPipeline,
                        ))
                        .map_pass_err(scope)?;

                    pipeline_layout
                        .validate_push_constant_ranges(
                            wgt::ShaderStages::COMPUTE,
                            *offset,
                            end_offset_bytes,
                        )
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_push_constants(
                            pipeline_layout.raw(),
                            wgt::ShaderStages::COMPUTE,
                            *offset,
                            data_slice,
                        );
                    }
                }
                ArcComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: false,
                        pipeline: state.pipeline,
                    };
                    state.is_ready().map_pass_err(scope)?;

                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            None,
                            &snatch_guard,
                        )
                        .map_pass_err(scope)?;

                    let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;

                    if groups[0] > groups_size_limit
                        || groups[1] > groups_size_limit
                        || groups[2] > groups_size_limit
                    {
                        return Err(ComputePassErrorInner::Dispatch(
                            DispatchError::InvalidGroupSize {
                                current: *groups,
                                limit: groups_size_limit,
                            },
                        ))
                        .map_pass_err(scope);
                    }

                    unsafe {
                        raw.dispatch(*groups);
                    }
                }
                ArcComputeCommand::DispatchIndirect { buffer, offset } => {
                    let buffer_id = buffer.as_info().id();
                    let scope = PassErrorScope::Dispatch {
                        indirect: true,
                        pipeline: state.pipeline,
                    };

                    state.is_ready().map_pass_err(scope)?;

                    device
                        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
                        .map_pass_err(scope)?;

                    state
                        .scope
                        .buffers
                        .insert_merge_single(buffer.clone(), hal::BufferUses::INDIRECT)
                        .map_pass_err(scope)?;
                    check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT)
                        .map_pass_err(scope)?;

                    let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
                    if end_offset > buffer.size {
                        return Err(ComputePassErrorInner::IndirectBufferOverrun {
                            offset: *offset,
                            end_offset,
                            buffer_size: buffer.size,
                        })
                        .map_pass_err(scope);
                    }

                    let buf_raw = buffer
                        .raw
                        .get(&snatch_guard)
                        .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;

                    let stride = 3 * 4; // 3 u32s: the x/y/z workgroup counts
                    buffer_memory_init_actions.extend(
                        buffer.initialization_status.read().create_action(
                            buffer,
                            *offset..(*offset + stride),
                            MemoryInitKind::NeedsInitializedMemory,
                        ),
                    );

                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            Some(buffer.as_info().tracker_index()),
                            &snatch_guard,
                        )
                        .map_pass_err(scope)?;
                    unsafe {
                        raw.dispatch_indirect(buf_raw, *offset);
                    }
                }
                ArcComputeCommand::PushDebugGroup { color: _, len } => {
                    state.debug_scope_depth += 1;
                    if !discard_hal_labels {
                        let label =
                            str::from_utf8(&base.string_data[string_offset..string_offset + len])
                                .unwrap();
                        unsafe {
                            raw.begin_debug_marker(label);
                        }
                    }
                    string_offset += len;
                }
                ArcComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;

                    if state.debug_scope_depth == 0 {
                        return Err(ComputePassErrorInner::InvalidPopDebugGroup)
                            .map_pass_err(scope);
                    }
                    state.debug_scope_depth -= 1;
                    if !discard_hal_labels {
                        unsafe {
                            raw.end_debug_marker();
                        }
                    }
                }
                ArcComputeCommand::InsertDebugMarker { color: _, len } => {
                    if !discard_hal_labels {
                        let label =
                            str::from_utf8(&base.string_data[string_offset..string_offset + len])
                                .unwrap();
                        unsafe { raw.insert_debug_marker(label) }
                    }
                    string_offset += len;
                }
                ArcComputeCommand::WriteTimestamp {
                    query_set,
                    query_index,
                } => {
                    let query_set_id = query_set.as_info().id();
                    let scope = PassErrorScope::WriteTimestamp;

                    device
                        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
                        .map_pass_err(scope)?;

                    let query_set = tracker.query_sets.insert_single(query_set.clone());

                    query_set
                        .validate_and_write_timestamp(raw, query_set_id, *query_index, None)
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::BeginPipelineStatisticsQuery {
                    query_set,
                    query_index,
                } => {
                    let query_set_id = query_set.as_info().id();
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;

                    let query_set = tracker.query_sets.insert_single(query_set.clone());

                    query_set
                        .validate_and_begin_pipeline_statistics_query(
                            raw,
                            query_set_id,
                            *query_index,
                            None,
                            &mut active_query,
                        )
                        .map_pass_err(scope)?;
                }
                ArcComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;

                    end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            raw.end_compute_pass();
        }

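        // The pass recorded without errors, so normal recording may resume.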
        *status = CommandEncoderStatus::Recording;

        encoder.close().map_pass_err(pass_scope)?;

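        // Record the memory-init fix-ups and the barriers derived from the
        // intermediate trackers in a separate encoder; `close_and_swap` then
        // places that work before the pass commands recorded above.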
        let transit = encoder.open().map_pass_err(pass_scope)?;
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &mut tracker.textures,
            device,
            &snatch_guard,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            tracker,
            &intermediate_trackers,
            &snatch_guard,
        );
        encoder.close_and_swap().map_pass_err(pass_scope)?;

        Ok(())
    }
}

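/// Free functions for recording commands into a [`ComputePass`] before it is
/// replayed via `Global::command_encoder_run_compute_pass`.
///
/// A minimal recording sketch; the encoder, pipeline, and bind group ids are
/// assumed to have been created elsewhere, and the exact re-export paths may
/// differ:
///
/// ```ignore
/// use wgpu_core::command::{compute_commands::*, ComputePass, ComputePassDescriptor};
///
/// let mut pass = ComputePass::new(encoder_id, &ComputePassDescriptor::default());
/// wgpu_compute_pass_set_pipeline(&mut pass, pipeline_id);
/// wgpu_compute_pass_set_bind_group(&mut pass, 0, bind_group_id, &[]);
/// wgpu_compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1);
/// ```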
pub mod compute_commands {
    use super::{ComputeCommand, ComputePass};
    use crate::id;
    use std::convert::TryInto;
    use wgt::{BufferAddress, DynamicOffset};

    pub fn wgpu_compute_pass_set_bind_group(
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: id::BindGroupId,
        offsets: &[DynamicOffset],
    ) {
        let redundant = pass.current_bind_groups.set_and_check_redundant(
            bind_group_id,
            index,
            &mut pass.base.dynamic_offsets,
            offsets,
        );

        if redundant {
            return;
        }

        pass.base.commands.push(ComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offsets.len(),
            bind_group_id,
        });
    }

    pub fn wgpu_compute_pass_set_pipeline(
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) {
        if pass.current_pipeline.set_and_check_redundant(pipeline_id) {
            return;
        }

        pass.base
            .commands
            .push(ComputeCommand::SetPipeline(pipeline_id));
    }

    pub fn wgpu_compute_pass_set_push_constant(pass: &mut ComputePass, offset: u32, data: &[u8]) {
        assert_eq!(
            offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant offset must be aligned to 4 bytes."
        );
        assert_eq!(
            data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant size must be aligned to 4 bytes."
        );
        let value_offset = pass.base.push_constant_data.len().try_into().expect(
            "Ran out of push constant space. Don't set 4gb of push constants per ComputePass.",
        );

        pass.base.push_constant_data.extend(
            data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        pass.base.commands.push(ComputeCommand::SetPushConstant {
            offset,
            size_bytes: data.len() as u32,
            values_offset: value_offset,
        });
    }

    pub fn wgpu_compute_pass_dispatch_workgroups(
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
    }

    pub fn wgpu_compute_pass_dispatch_workgroups_indirect(
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
    }

    pub fn wgpu_compute_pass_push_debug_group(pass: &mut ComputePass, label: &str, color: u32) {
        let bytes = label.as_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });
    }

    pub fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
        pass.base.commands.push(ComputeCommand::PopDebugGroup);
    }

    pub fn wgpu_compute_pass_insert_debug_marker(pass: &mut ComputePass, label: &str, color: u32) {
        let bytes = label.as_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });
    }

    pub fn wgpu_compute_pass_write_timestamp(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base.commands.push(ComputeCommand::WriteTimestamp {
            query_set_id,
            query_index,
        });
    }

    pub fn wgpu_compute_pass_begin_pipeline_statistics_query(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::BeginPipelineStatisticsQuery {
                query_set_id,
                query_index,
            });
    }

    pub fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
        pass.base
            .commands
            .push(ComputeCommand::EndPipelineStatisticsQuery);
    }
}