1use std::sync::Arc;
2
3use crate::{
4 binding_model::BindGroup,
5 hal_api::HalApi,
6 id,
7 pipeline::ComputePipeline,
8 resource::{Buffer, QuerySet},
9};
10
11use super::{ComputePassError, ComputePassErrorInner, PassErrorScope};
12
/// One command recorded in a compute pass, referring to resources by id.
///
/// This type derives `Copy`, so it cannot own variable-length data: variants
/// that need some (dynamic offsets, push constant values, debug labels) only
/// carry a count, length, or offset. NOTE(review): the storage those values
/// live in is not visible in this file — presumably the pass that owns the
/// command list; confirm.
///
/// With the `serde` feature enabled the type is also (de)serializable, so the
/// order of variants and fields is part of its serialized form.
/// [`ComputeCommand::resolve_compute_command_ids`] converts a slice of these
/// into [`ArcComputeCommand`]s.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ComputeCommand {
    /// Set the bind group `bind_group_id` at bind group slot `index`.
    SetBindGroup {
        index: u32,
        /// Number of dynamic offsets that go with this bind group; the offsets
        /// themselves are not stored in the command.
        num_dynamic_offsets: usize,
        bind_group_id: id::BindGroupId,
    },

    /// Set the active compute pipeline.
    SetPipeline(id::ComputePipelineId),

    /// Write push constant data.
    SetPushConstant {
        /// Presumably the offset in push constant storage to write at — confirm.
        offset: u32,

        /// Presumably the number of bytes to write — confirm.
        size_bytes: u32,

        /// Presumably where the values to write begin in separately stored
        /// push constant data — confirm.
        values_offset: u32,
    },

    /// Direct dispatch; the three values are presumably the workgroup counts
    /// in x, y and z.
    Dispatch([u32; 3]),

    /// Indirect dispatch whose arguments are presumably read from `buffer_id`
    /// starting at byte `offset`.
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: wgt::BufferAddress,
    },

    /// Open a debug group. `len` is presumably the length of a label stored
    /// outside the command — confirm.
    PushDebugGroup {
        color: u32,
        len: usize,
    },

    /// Close the most recently opened debug group.
    PopDebugGroup,

    /// Insert a debug marker. `len` is presumably the length of a label stored
    /// outside the command — confirm.
    InsertDebugMarker {
        color: u32,
        len: usize,
    },

    /// Write a timestamp into query `query_index` of `query_set_id`.
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },

    /// Begin a pipeline statistics query at `query_index` of `query_set_id`.
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },

    /// End a pipeline statistics query. The query set and index are not
    /// repeated in this variant.
    EndPipelineStatisticsQuery,
}
72
73impl ComputeCommand {
74 pub fn resolve_compute_command_ids<A: HalApi>(
79 hub: &crate::hub::Hub<A>,
80 commands: &[ComputeCommand],
81 ) -> Result<Vec<ArcComputeCommand<A>>, ComputePassError> {
82 let buffers_guard = hub.buffers.read();
83 let bind_group_guard = hub.bind_groups.read();
84 let query_set_guard = hub.query_sets.read();
85 let pipelines_guard = hub.compute_pipelines.read();
86
87 let resolved_commands: Vec<ArcComputeCommand<A>> = commands
88 .iter()
89 .map(|c| -> Result<ArcComputeCommand<A>, ComputePassError> {
90 Ok(match *c {
91 ComputeCommand::SetBindGroup {
92 index,
93 num_dynamic_offsets,
94 bind_group_id,
95 } => ArcComputeCommand::SetBindGroup {
96 index,
97 num_dynamic_offsets,
98 bind_group: bind_group_guard.get_owned(bind_group_id).map_err(|_| {
99 ComputePassError {
100 scope: PassErrorScope::SetBindGroup(bind_group_id),
101 inner: ComputePassErrorInner::InvalidBindGroup(index),
102 }
103 })?,
104 },
105
106 ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
107 pipelines_guard
108 .get_owned(pipeline_id)
109 .map_err(|_| ComputePassError {
110 scope: PassErrorScope::SetPipelineCompute(pipeline_id),
111 inner: ComputePassErrorInner::InvalidPipeline(pipeline_id),
112 })?,
113 ),
114
115 ComputeCommand::SetPushConstant {
116 offset,
117 size_bytes,
118 values_offset,
119 } => ArcComputeCommand::SetPushConstant {
120 offset,
121 size_bytes,
122 values_offset,
123 },
124
125 ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
126
127 ComputeCommand::DispatchIndirect { buffer_id, offset } => {
128 ArcComputeCommand::DispatchIndirect {
129 buffer: buffers_guard.get_owned(buffer_id).map_err(|_| {
130 ComputePassError {
131 scope: PassErrorScope::Dispatch {
132 indirect: true,
133 pipeline: None, },
135 inner: ComputePassErrorInner::InvalidBuffer(buffer_id),
136 }
137 })?,
138 offset,
139 }
140 }
141
142 ComputeCommand::PushDebugGroup { color, len } => {
143 ArcComputeCommand::PushDebugGroup { color, len }
144 }
145
146 ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
147
148 ComputeCommand::InsertDebugMarker { color, len } => {
149 ArcComputeCommand::InsertDebugMarker { color, len }
150 }
151
152 ComputeCommand::WriteTimestamp {
153 query_set_id,
154 query_index,
155 } => ArcComputeCommand::WriteTimestamp {
156 query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
157 ComputePassError {
158 scope: PassErrorScope::WriteTimestamp,
159 inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
160 }
161 })?,
162 query_index,
163 },
164
165 ComputeCommand::BeginPipelineStatisticsQuery {
166 query_set_id,
167 query_index,
168 } => ArcComputeCommand::BeginPipelineStatisticsQuery {
169 query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
170 ComputePassError {
171 scope: PassErrorScope::BeginPipelineStatisticsQuery,
172 inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
173 }
174 })?,
175 query_index,
176 },
177
178 ComputeCommand::EndPipelineStatisticsQuery => {
179 ArcComputeCommand::EndPipelineStatisticsQuery
180 }
181 })
182 })
183 .collect::<Result<Vec<_>, ComputePassError>>()?;
184 Ok(resolved_commands)
185 }
186}
187
/// The counterpart of [`ComputeCommand`] in which resources are held as
/// `Arc`s instead of ids.
///
/// Produced by [`ComputeCommand::resolve_compute_command_ids`]. When the
/// `trace` feature is enabled, a reference to it can be converted back into a
/// [`ComputeCommand`] via `From`. Cloning a value clones the contained `Arc`s.
/// Variants that hold no resource carry exactly the same fields as in
/// [`ComputeCommand`].
#[derive(Clone, Debug)]
pub enum ArcComputeCommand<A: HalApi> {
    /// Set `bind_group` at bind group slot `index`.
    SetBindGroup {
        index: u32,
        /// Number of dynamic offsets; the offsets themselves are not stored here.
        num_dynamic_offsets: usize,
        bind_group: Arc<BindGroup<A>>,
    },

    /// Set the active compute pipeline.
    SetPipeline(Arc<ComputePipeline<A>>),

    /// Write push constant data; see the same variant of [`ComputeCommand`].
    SetPushConstant {
        /// Presumably the offset in push constant storage to write at — confirm.
        offset: u32,

        /// Presumably the number of bytes to write — confirm.
        size_bytes: u32,

        /// Presumably where the values begin in separately stored push
        /// constant data — confirm.
        values_offset: u32,
    },

    /// Direct dispatch; presumably workgroup counts in x, y and z.
    Dispatch([u32; 3]),

    /// Indirect dispatch with arguments presumably read from `buffer` at `offset`.
    DispatchIndirect {
        buffer: Arc<Buffer<A>>,
        offset: wgt::BufferAddress,
    },

    /// Open a debug group; `len` presumably refers to an externally stored label.
    PushDebugGroup {
        color: u32,
        len: usize,
    },

    /// Close the most recently opened debug group.
    PopDebugGroup,

    /// Insert a debug marker; `len` presumably refers to an externally stored label.
    InsertDebugMarker {
        color: u32,
        len: usize,
    },

    /// Write a timestamp into query `query_index` of `query_set`.
    WriteTimestamp {
        query_set: Arc<QuerySet<A>>,
        query_index: u32,
    },

    /// Begin a pipeline statistics query at `query_index` of `query_set`.
    BeginPipelineStatisticsQuery {
        query_set: Arc<QuerySet<A>>,
        query_index: u32,
    },

    /// End a pipeline statistics query; carries no payload.
    EndPipelineStatisticsQuery,
}
247
/// Maps an `Arc`-holding command back to its id-based form (trace builds only).
///
/// Each resource handle is replaced by the id read from its info
/// (`as_info().id()`); all other fields are copied through unchanged.
#[cfg(feature = "trace")]
impl<A: HalApi> From<&ArcComputeCommand<A>> for ComputeCommand {
    fn from(value: &ArcComputeCommand<A>) -> Self {
        use crate::resource::Resource as _;

        // Match on the place behind the reference: `Copy` fields are bound by
        // value, `Arc` fields are borrowed with `ref` so nothing is moved.
        match *value {
            ArcComputeCommand::SetBindGroup {
                index,
                num_dynamic_offsets,
                ref bind_group,
            } => Self::SetBindGroup {
                index,
                num_dynamic_offsets,
                bind_group_id: bind_group.as_info().id(),
            },

            ArcComputeCommand::SetPipeline(ref pipeline) => {
                Self::SetPipeline(pipeline.as_info().id())
            }

            ArcComputeCommand::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            } => Self::SetPushConstant {
                offset,
                size_bytes,
                values_offset,
            },

            ArcComputeCommand::Dispatch(groups) => Self::Dispatch(groups),

            ArcComputeCommand::DispatchIndirect { ref buffer, offset } => {
                Self::DispatchIndirect {
                    buffer_id: buffer.as_info().id(),
                    offset,
                }
            }

            ArcComputeCommand::PushDebugGroup { color, len } => {
                Self::PushDebugGroup { color, len }
            }

            ArcComputeCommand::PopDebugGroup => Self::PopDebugGroup,

            ArcComputeCommand::InsertDebugMarker { color, len } => {
                Self::InsertDebugMarker { color, len }
            }

            ArcComputeCommand::WriteTimestamp {
                ref query_set,
                query_index,
            } => Self::WriteTimestamp {
                query_set_id: query_set.as_info().id(),
                query_index,
            },

            ArcComputeCommand::BeginPipelineStatisticsQuery {
                ref query_set,
                query_index,
            } => Self::BeginPipelineStatisticsQuery {
                query_set_id: query_set.as_info().id(),
                query_index,
            },

            ArcComputeCommand::EndPipelineStatisticsQuery => Self::EndPipelineStatisticsQuery,
        }
    }
}