// wgpu_core/command/compute_command.rs

1use std::sync::Arc;
2
3use crate::{
4    binding_model::BindGroup,
5    hal_api::HalApi,
6    id,
7    pipeline::ComputePipeline,
8    resource::{Buffer, QuerySet},
9};
10
11use super::{ComputePassError, ComputePassErrorInner, PassErrorScope};
12
13#[derive(Clone, Copy, Debug)]
14#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
15pub enum ComputeCommand {
16    SetBindGroup {
17        index: u32,
18        num_dynamic_offsets: usize,
19        bind_group_id: id::BindGroupId,
20    },
21
22    SetPipeline(id::ComputePipelineId),
23
24    /// Set a range of push constants to values stored in `push_constant_data`.
25    SetPushConstant {
26        /// The byte offset within the push constant storage to write to. This
27        /// must be a multiple of four.
28        offset: u32,
29
30        /// The number of bytes to write. This must be a multiple of four.
31        size_bytes: u32,
32
33        /// Index in `push_constant_data` of the start of the data
34        /// to be written.
35        ///
36        /// Note: this is not a byte offset like `offset`. Rather, it is the
37        /// index of the first `u32` element in `push_constant_data` to read.
38        values_offset: u32,
39    },
40
41    Dispatch([u32; 3]),
42
43    DispatchIndirect {
44        buffer_id: id::BufferId,
45        offset: wgt::BufferAddress,
46    },
47
48    PushDebugGroup {
49        color: u32,
50        len: usize,
51    },
52
53    PopDebugGroup,
54
55    InsertDebugMarker {
56        color: u32,
57        len: usize,
58    },
59
60    WriteTimestamp {
61        query_set_id: id::QuerySetId,
62        query_index: u32,
63    },
64
65    BeginPipelineStatisticsQuery {
66        query_set_id: id::QuerySetId,
67        query_index: u32,
68    },
69
70    EndPipelineStatisticsQuery,
71}
72
73impl ComputeCommand {
74    /// Resolves all ids in a list of commands into the corresponding resource Arc.
75    ///
76    // TODO: Once resolving is done on-the-fly during recording, this function should be only needed with the replay feature:
77    // #[cfg(feature = "replay")]
78    pub fn resolve_compute_command_ids<A: HalApi>(
79        hub: &crate::hub::Hub<A>,
80        commands: &[ComputeCommand],
81    ) -> Result<Vec<ArcComputeCommand<A>>, ComputePassError> {
82        let buffers_guard = hub.buffers.read();
83        let bind_group_guard = hub.bind_groups.read();
84        let query_set_guard = hub.query_sets.read();
85        let pipelines_guard = hub.compute_pipelines.read();
86
87        let resolved_commands: Vec<ArcComputeCommand<A>> = commands
88            .iter()
89            .map(|c| -> Result<ArcComputeCommand<A>, ComputePassError> {
90                Ok(match *c {
91                    ComputeCommand::SetBindGroup {
92                        index,
93                        num_dynamic_offsets,
94                        bind_group_id,
95                    } => ArcComputeCommand::SetBindGroup {
96                        index,
97                        num_dynamic_offsets,
98                        bind_group: bind_group_guard.get_owned(bind_group_id).map_err(|_| {
99                            ComputePassError {
100                                scope: PassErrorScope::SetBindGroup(bind_group_id),
101                                inner: ComputePassErrorInner::InvalidBindGroup(index),
102                            }
103                        })?,
104                    },
105
106                    ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
107                        pipelines_guard
108                            .get_owned(pipeline_id)
109                            .map_err(|_| ComputePassError {
110                                scope: PassErrorScope::SetPipelineCompute(pipeline_id),
111                                inner: ComputePassErrorInner::InvalidPipeline(pipeline_id),
112                            })?,
113                    ),
114
115                    ComputeCommand::SetPushConstant {
116                        offset,
117                        size_bytes,
118                        values_offset,
119                    } => ArcComputeCommand::SetPushConstant {
120                        offset,
121                        size_bytes,
122                        values_offset,
123                    },
124
125                    ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
126
127                    ComputeCommand::DispatchIndirect { buffer_id, offset } => {
128                        ArcComputeCommand::DispatchIndirect {
129                            buffer: buffers_guard.get_owned(buffer_id).map_err(|_| {
130                                ComputePassError {
131                                    scope: PassErrorScope::Dispatch {
132                                        indirect: true,
133                                        pipeline: None, // TODO: not used right now, but once we do the resolve during recording we can use this again.
134                                    },
135                                    inner: ComputePassErrorInner::InvalidBuffer(buffer_id),
136                                }
137                            })?,
138                            offset,
139                        }
140                    }
141
142                    ComputeCommand::PushDebugGroup { color, len } => {
143                        ArcComputeCommand::PushDebugGroup { color, len }
144                    }
145
146                    ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
147
148                    ComputeCommand::InsertDebugMarker { color, len } => {
149                        ArcComputeCommand::InsertDebugMarker { color, len }
150                    }
151
152                    ComputeCommand::WriteTimestamp {
153                        query_set_id,
154                        query_index,
155                    } => ArcComputeCommand::WriteTimestamp {
156                        query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
157                            ComputePassError {
158                                scope: PassErrorScope::WriteTimestamp,
159                                inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
160                            }
161                        })?,
162                        query_index,
163                    },
164
165                    ComputeCommand::BeginPipelineStatisticsQuery {
166                        query_set_id,
167                        query_index,
168                    } => ArcComputeCommand::BeginPipelineStatisticsQuery {
169                        query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
170                            ComputePassError {
171                                scope: PassErrorScope::BeginPipelineStatisticsQuery,
172                                inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
173                            }
174                        })?,
175                        query_index,
176                    },
177
178                    ComputeCommand::EndPipelineStatisticsQuery => {
179                        ArcComputeCommand::EndPipelineStatisticsQuery
180                    }
181                })
182            })
183            .collect::<Result<Vec<_>, ComputePassError>>()?;
184        Ok(resolved_commands)
185    }
186}
187
188/// Equivalent to `ComputeCommand` but the Ids resolved into resource Arcs.
189#[derive(Clone, Debug)]
190pub enum ArcComputeCommand<A: HalApi> {
191    SetBindGroup {
192        index: u32,
193        num_dynamic_offsets: usize,
194        bind_group: Arc<BindGroup<A>>,
195    },
196
197    SetPipeline(Arc<ComputePipeline<A>>),
198
199    /// Set a range of push constants to values stored in `push_constant_data`.
200    SetPushConstant {
201        /// The byte offset within the push constant storage to write to. This
202        /// must be a multiple of four.
203        offset: u32,
204
205        /// The number of bytes to write. This must be a multiple of four.
206        size_bytes: u32,
207
208        /// Index in `push_constant_data` of the start of the data
209        /// to be written.
210        ///
211        /// Note: this is not a byte offset like `offset`. Rather, it is the
212        /// index of the first `u32` element in `push_constant_data` to read.
213        values_offset: u32,
214    },
215
216    Dispatch([u32; 3]),
217
218    DispatchIndirect {
219        buffer: Arc<Buffer<A>>,
220        offset: wgt::BufferAddress,
221    },
222
223    PushDebugGroup {
224        color: u32,
225        len: usize,
226    },
227
228    PopDebugGroup,
229
230    InsertDebugMarker {
231        color: u32,
232        len: usize,
233    },
234
235    WriteTimestamp {
236        query_set: Arc<QuerySet<A>>,
237        query_index: u32,
238    },
239
240    BeginPipelineStatisticsQuery {
241        query_set: Arc<QuerySet<A>>,
242        query_index: u32,
243    },
244
245    EndPipelineStatisticsQuery,
246}
247
#[cfg(feature = "trace")]
impl<A: HalApi> From<&ArcComputeCommand<A>> for ComputeCommand {
    /// Turns a resolved command back into its id-based form (only compiled
    /// with the `trace` feature).
    ///
    /// Every resource `Arc` is replaced by the id reported by that resource's
    /// `as_info().id()`. All other fields are copied unchanged.
    fn from(value: &ArcComputeCommand<A>) -> Self {
        // Needed so that `as_info()` resolves on the resource types below.
        use crate::resource::Resource as _;

        // Matching on `*value` copies the `Copy` fields directly and borrows
        // the `Arc` fields via `ref`, so nothing is moved out of `value`.
        match *value {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                ref bind_group,
            } => ComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group_id: bind_group.as_info().id(),
            },

            ArcComputeCommand::SetPipeline(ref pipeline) => {
                let pipeline_id = pipeline.as_info().id();
                ComputeCommand::SetPipeline(pipeline_id)
            }

            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => ComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            },

            ArcComputeCommand::Dispatch(groups) => ComputeCommand::Dispatch(groups),

            ArcComputeCommand::DispatchIndirect { ref buffer, offset } => {
                let buffer_id = buffer.as_info().id();
                ComputeCommand::DispatchIndirect { buffer_id, offset }
            }

            ArcComputeCommand::PushDebugGroup { color, len } => {
                ComputeCommand::PushDebugGroup { color, len }
            }

            ArcComputeCommand::PopDebugGroup => ComputeCommand::PopDebugGroup,

            ArcComputeCommand::InsertDebugMarker { color, len } => {
                ComputeCommand::InsertDebugMarker { color, len }
            }

            ArcComputeCommand::WriteTimestamp {
                ref query_set,
                query_index,
            } => {
                let query_set_id = query_set.as_info().id();
                ComputeCommand::WriteTimestamp {
                    query_set_id,
                    query_index,
                }
            }

            ArcComputeCommand::BeginPipelineStatisticsQuery {
                ref query_set,
                query_index,
            } => {
                let query_set_id = query_set.as_info().id();
                ComputeCommand::BeginPipelineStatisticsQuery {
                    query_set_id,
                    query_index,
                }
            }

            ArcComputeCommand::EndPipelineStatisticsQuery => {
                ComputeCommand::EndPipelineStatisticsQuery
            }
        }
    }
}