bevy_render/render_resource/
pipeline_cache.rs

1use crate::{
2    render_resource::*,
3    renderer::{RenderAdapter, RenderDevice},
4    Extract,
5};
6use bevy_asset::{AssetEvent, AssetId, Assets};
7use bevy_ecs::system::{Res, ResMut};
8use bevy_ecs::{event::EventReader, system::Resource};
9use bevy_tasks::Task;
10use bevy_utils::hashbrown::hash_map::EntryRef;
11use bevy_utils::{
12    default,
13    tracing::{debug, error},
14    HashMap, HashSet,
15};
16use naga::valid::Capabilities;
17use std::{
18    borrow::Cow,
19    future::Future,
20    hash::Hash,
21    mem,
22    ops::Deref,
23    sync::{Arc, Mutex, PoisonError},
24};
25use thiserror::Error;
26#[cfg(feature = "shader_format_spirv")]
27use wgpu::util::make_spirv;
28use wgpu::{
29    DownlevelFlags, Features, PipelineCompilationOptions,
30    VertexBufferLayout as RawVertexBufferLayout,
31};
32
33use crate::render_resource::resource_macros::*;
34
// Wrappers generated by `render_resource_wrapper!` so shader modules and
// pipeline layouts can be stored in the cache maps below; see
// `resource_macros` for the exact wrapper semantics.
render_resource_wrapper!(ErasedShaderModule, wgpu::ShaderModule);
render_resource_wrapper!(ErasedPipelineLayout, wgpu::PipelineLayout);
37
/// A descriptor for a [`Pipeline`].
///
/// Used to store a heterogeneous collection of render and compute pipeline descriptors together.
#[derive(Debug)]
pub enum PipelineDescriptor {
    /// A descriptor for a render (vertex/fragment) pipeline.
    RenderPipelineDescriptor(Box<RenderPipelineDescriptor>),
    /// A descriptor for a compute pipeline.
    ComputePipelineDescriptor(Box<ComputePipelineDescriptor>),
}
46
/// A pipeline defining the data layout and shader logic for a specific GPU task.
///
/// Used to store a heterogeneous collection of render and compute pipelines together.
#[derive(Debug)]
pub enum Pipeline {
    /// A created render (vertex/fragment) pipeline.
    RenderPipeline(RenderPipeline),
    /// A created compute pipeline.
    ComputePipeline(ComputePipeline),
}
55
/// Internal index of a cached pipeline (render or compute) in a [`PipelineCache`].
type CachedPipelineId = usize;

/// Index of a cached render pipeline in a [`PipelineCache`].
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub struct CachedRenderPipelineId(CachedPipelineId);

impl CachedRenderPipelineId {
    /// An invalid cached render pipeline index, often used to initialize a variable.
    pub const INVALID: Self = CachedRenderPipelineId(usize::MAX);

    /// Returns the raw index of this id into the cache's pipeline list.
    #[inline]
    pub fn id(&self) -> usize {
        self.0
    }
}
71
72/// Index of a cached compute pipeline in a [`PipelineCache`].
73#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
74pub struct CachedComputePipelineId(CachedPipelineId);
75
76impl CachedComputePipelineId {
77    /// An invalid cached compute pipeline index, often used to initialize a variable.
78    pub const INVALID: Self = CachedComputePipelineId(usize::MAX);
79
80    #[inline]
81    pub fn id(&self) -> usize {
82        self.0
83    }
84}
85
/// A pipeline stored in the [`PipelineCache`], together with its creation state.
pub struct CachedPipeline {
    /// The descriptor this pipeline was queued from.
    pub descriptor: PipelineDescriptor,
    /// Current creation state: queued, creating, ready, or failed.
    pub state: CachedPipelineState,
}
90
/// State of a cached pipeline inserted into a [`PipelineCache`].
#[derive(Debug)]
pub enum CachedPipelineState {
    /// The pipeline GPU object is queued for creation.
    Queued,
    /// The pipeline GPU object is being created. Holds the in-flight
    /// compilation task that will resolve to the pipeline or an error.
    Creating(Task<Result<Pipeline, PipelineCacheError>>),
    /// The pipeline GPU object was created successfully and is available (allocated on the GPU).
    Ok(Pipeline),
    /// An error occurred while trying to create the pipeline GPU object.
    Err(PipelineCacheError),
}
103
104impl CachedPipelineState {
105    /// Convenience method to "unwrap" a pipeline state into its underlying GPU object.
106    ///
107    /// # Returns
108    ///
109    /// The method returns the allocated pipeline GPU object.
110    ///
111    /// # Panics
112    ///
113    /// This method panics if the pipeline GPU object is not available, either because it is
114    /// pending creation or because an error occurred while attempting to create GPU object.
115    pub fn unwrap(&self) -> &Pipeline {
116        match self {
117            CachedPipelineState::Ok(pipeline) => pipeline,
118            CachedPipelineState::Queued => {
119                panic!("Pipeline has not been compiled yet. It is still in the 'Queued' state.")
120            }
121            CachedPipelineState::Creating(..) => {
122                panic!("Pipeline has not been compiled yet. It is still in the 'Creating' state.")
123            }
124            CachedPipelineState::Err(err) => panic!("{}", err),
125        }
126    }
127}
128
/// Per-shader bookkeeping for the [`ShaderCache`].
#[derive(Default)]
struct ShaderData {
    // Cached pipelines that were built against this shader; they are re-queued
    // when the shader changes.
    pipelines: HashSet<CachedPipelineId>,
    // Compiled shader modules, keyed by the exact set of shader defs used.
    processed_shaders: HashMap<Box<[ShaderDefVal]>, ErasedShaderModule>,
    // Imports of this shader that have been resolved to a loaded shader asset.
    resolved_imports: HashMap<ShaderImport, AssetId<Shader>>,
    // Shaders that directly import this shader.
    dependents: HashSet<AssetId<Shader>>,
}
136
/// Cache of composed and compiled shader modules, shared by all pipelines.
struct ShaderCache {
    // Per-shader compilation state and dependency tracking.
    data: HashMap<AssetId<Shader>, ShaderData>,
    // Loaded shader assets, keyed by asset id.
    shaders: HashMap<AssetId<Shader>, Shader>,
    // Maps an import path to the shader asset that provides it.
    import_path_shaders: HashMap<ShaderImport, AssetId<Shader>>,
    // Shaders whose imports are not yet loaded, keyed by the missing import.
    waiting_on_import: HashMap<ShaderImport, Vec<AssetId<Shader>>>,
    // naga_oil composer used to stitch imported shader modules together.
    composer: naga_oil::compose::Composer,
}
144
/// A typed key/value pair used to configure shader compilation (a "shader def").
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub enum ShaderDefVal {
    Bool(String, bool),
    Int(String, i32),
    UInt(String, u32),
}

impl From<&str> for ShaderDefVal {
    /// A bare `&str` key becomes a boolean def set to `true`.
    fn from(key: &str) -> Self {
        Self::Bool(key.to_owned(), true)
    }
}

impl From<String> for ShaderDefVal {
    /// A bare `String` key becomes a boolean def set to `true`.
    fn from(key: String) -> Self {
        Self::Bool(key, true)
    }
}

impl ShaderDefVal {
    /// Renders the def's value (not its key) as a string.
    pub fn value_as_string(&self) -> String {
        match self {
            Self::Bool(_, value) => value.to_string(),
            Self::Int(_, value) => value.to_string(),
            Self::UInt(_, value) => value.to_string(),
        }
    }
}
173
174impl ShaderCache {
    /// Creates a new shader cache for the given device/adapter pair.
    ///
    /// The naga capabilities and subgroup stages advertised to the composer are
    /// derived from the device's features and the adapter's downlevel flags, so
    /// composed shaders only rely on what the current GPU supports.
    fn new(render_device: &RenderDevice, render_adapter: &RenderAdapter) -> Self {
        let (capabilities, subgroup_stages) = get_capabilities(
            render_device.features(),
            render_adapter.get_downlevel_capabilities().flags,
        );

        // Debug builds validate composed modules; release builds skip
        // validation for faster shader processing.
        #[cfg(debug_assertions)]
        let composer = naga_oil::compose::Composer::default();
        #[cfg(not(debug_assertions))]
        let composer = naga_oil::compose::Composer::non_validating();

        let composer = composer.with_capabilities(capabilities, subgroup_stages);

        Self {
            composer,
            data: Default::default(),
            shaders: Default::default(),
            import_path_shaders: Default::default(),
            waiting_on_import: Default::default(),
        }
    }
196
197    fn add_import_to_composer(
198        composer: &mut naga_oil::compose::Composer,
199        import_path_shaders: &HashMap<ShaderImport, AssetId<Shader>>,
200        shaders: &HashMap<AssetId<Shader>, Shader>,
201        import: &ShaderImport,
202    ) -> Result<(), PipelineCacheError> {
203        if !composer.contains_module(&import.module_name()) {
204            if let Some(shader_handle) = import_path_shaders.get(import) {
205                if let Some(shader) = shaders.get(shader_handle) {
206                    for import in &shader.imports {
207                        Self::add_import_to_composer(
208                            composer,
209                            import_path_shaders,
210                            shaders,
211                            import,
212                        )?;
213                    }
214
215                    composer.add_composable_module(shader.into())?;
216                }
217            }
218            // if we fail to add a module the composer will tell us what is missing
219        }
220
221        Ok(())
222    }
223
    /// Returns the compiled shader module for `id` with the given shader defs,
    /// compiling and caching it on first use.
    ///
    /// Also records `pipeline` as a dependent of the shader so it can be
    /// re-queued when the shader changes. Fails with
    /// [`PipelineCacheError::ShaderNotLoaded`] if the shader asset is missing,
    /// or [`PipelineCacheError::ShaderImportNotYetAvailable`] if some of its
    /// asset-path imports have not resolved yet.
    #[allow(clippy::result_large_err)]
    fn get(
        &mut self,
        render_device: &RenderDevice,
        pipeline: CachedPipelineId,
        id: AssetId<Shader>,
        shader_defs: &[ShaderDefVal],
    ) -> Result<ErasedShaderModule, PipelineCacheError> {
        let shader = self
            .shaders
            .get(&id)
            .ok_or(PipelineCacheError::ShaderNotLoaded(id))?;
        let data = self.data.entry(id).or_default();
        // Compilation can only proceed once every asset-path import of this
        // shader has been resolved to a loaded shader asset.
        let n_asset_imports = shader
            .imports()
            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
            .count();
        let n_resolved_asset_imports = data
            .resolved_imports
            .keys()
            .filter(|import| matches!(import, ShaderImport::AssetPath(_)))
            .count();
        if n_asset_imports != n_resolved_asset_imports {
            return Err(PipelineCacheError::ShaderImportNotYetAvailable);
        }

        data.pipelines.insert(pipeline);

        // PERF: this shader_defs clone isn't great. use raw_entry_mut when it stabilizes
        let module = match data.processed_shaders.entry_ref(shader_defs) {
            EntryRef::Occupied(entry) => entry.into_mut(),
            EntryRef::Vacant(entry) => {
                // Cache miss: build the full def list (platform-specific defs
                // included), compose and compile the module, then cache it.
                let mut shader_defs = shader_defs.to_vec();
                #[cfg(all(feature = "webgl", target_arch = "wasm32", not(feature = "webgpu")))]
                {
                    shader_defs.push("NO_ARRAY_TEXTURES_SUPPORT".into());
                    shader_defs.push("NO_CUBE_ARRAY_TEXTURES_SUPPORT".into());
                    shader_defs.push("SIXTEEN_BYTE_ALIGNMENT".into());
                }

                if cfg!(feature = "ios_simulator") {
                    shader_defs.push("NO_CUBE_ARRAY_TEXTURES_SUPPORT".into());
                }

                shader_defs.push(ShaderDefVal::UInt(
                    String::from("AVAILABLE_STORAGE_BUFFER_BINDINGS"),
                    render_device.limits().max_storage_buffers_per_shader_stage,
                ));

                debug!(
                    "processing shader {:?}, with shader defs {:?}",
                    id, shader_defs
                );
                let shader_source = match &shader.source {
                    // SPIR-V sources bypass the composer entirely.
                    #[cfg(feature = "shader_format_spirv")]
                    Source::SpirV(data) => make_spirv(data),
                    #[cfg(not(feature = "shader_format_spirv"))]
                    Source::SpirV(_) => {
                        unimplemented!(
                            "Enable feature \"shader_format_spirv\" to use SPIR-V shaders"
                        )
                    }
                    _ => {
                        // Register all imports with the composer before
                        // composing this module.
                        for import in shader.imports() {
                            Self::add_import_to_composer(
                                &mut self.composer,
                                &self.import_path_shaders,
                                &self.shaders,
                                import,
                            )?;
                        }

                        // Convert our defs (plus the shader's own defaults)
                        // into naga_oil's representation.
                        let shader_defs = shader_defs
                            .into_iter()
                            .chain(shader.shader_defs.iter().cloned())
                            .map(|def| match def {
                                ShaderDefVal::Bool(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::Bool(v))
                                }
                                ShaderDefVal::Int(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::Int(v))
                                }
                                ShaderDefVal::UInt(k, v) => {
                                    (k, naga_oil::compose::ShaderDefValue::UInt(v))
                                }
                            })
                            .collect::<std::collections::HashMap<_, _>>();

                        let naga = self.composer.make_naga_module(
                            naga_oil::compose::NagaModuleDescriptor {
                                shader_defs,
                                ..shader.into()
                            },
                        )?;

                        wgpu::ShaderSource::Naga(Cow::Owned(naga))
                    }
                };

                let module_descriptor = ShaderModuleDescriptor {
                    label: None,
                    source: shader_source,
                };

                // Use a validation error scope so shader-module creation errors
                // can be surfaced as a PipelineCacheError instead of a panic.
                render_device
                    .wgpu_device()
                    .push_error_scope(wgpu::ErrorFilter::Validation);
                let shader_module = render_device.create_shader_module(module_descriptor);
                let error = render_device.wgpu_device().pop_error_scope();

                // `now_or_never` will return Some if the future is ready and None otherwise.
                // On native platforms, wgpu will yield the error immediately while on wasm it may take longer since the browser APIs are asynchronous.
                // So to keep the complexity of the ShaderCache low, we will only catch this error early on native platforms,
                // and on wasm the error will be handled by wgpu and crash the application.
                if let Some(Some(wgpu::Error::Validation { description, .. })) =
                    bevy_utils::futures::now_or_never(error)
                {
                    return Err(PipelineCacheError::CreateShaderModule(description));
                }

                entry.insert(ErasedShaderModule::new(shader_module))
            }
        };

        Ok(module.clone())
    }
350
351    fn clear(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
352        let mut shaders_to_clear = vec![id];
353        let mut pipelines_to_queue = Vec::new();
354        while let Some(handle) = shaders_to_clear.pop() {
355            if let Some(data) = self.data.get_mut(&handle) {
356                data.processed_shaders.clear();
357                pipelines_to_queue.extend(data.pipelines.iter().copied());
358                shaders_to_clear.extend(data.dependents.iter().copied());
359
360                if let Some(Shader { import_path, .. }) = self.shaders.get(&handle) {
361                    self.composer
362                        .remove_composable_module(&import_path.module_name());
363                }
364            }
365        }
366
367        pipelines_to_queue
368    }
369
    /// Inserts or replaces the shader asset `id`, rewiring import/dependent
    /// links, and returns the pipelines that must be re-queued.
    fn set_shader(&mut self, id: AssetId<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
        // Invalidate all compiled variants (and dependents) of any old version.
        let pipelines_to_queue = self.clear(id);
        let path = shader.import_path();
        self.import_path_shaders.insert(path.clone(), id);
        // Wake up any shaders that were waiting for this import path.
        if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
            for waiting_shader in waiting_shaders.drain(..) {
                // resolve waiting shader import
                let data = self.data.entry(waiting_shader).or_default();
                data.resolved_imports.insert(path.clone(), id);
                // add waiting shader as dependent of this shader
                let data = self.data.entry(id).or_default();
                data.dependents.insert(waiting_shader);
            }
        }

        for import in shader.imports() {
            if let Some(import_id) = self.import_path_shaders.get(import).copied() {
                // resolve import because it is currently available
                let data = self.data.entry(id).or_default();
                data.resolved_imports.insert(import.clone(), import_id);
                // add this shader as a dependent of the import
                let data = self.data.entry(import_id).or_default();
                data.dependents.insert(id);
            } else {
                // Import not loaded yet: park this shader until it arrives.
                let waiting = self.waiting_on_import.entry(import.clone()).or_default();
                waiting.push(id);
            }
        }

        self.shaders.insert(id, shader);
        pipelines_to_queue
    }
402
403    fn remove(&mut self, id: AssetId<Shader>) -> Vec<CachedPipelineId> {
404        let pipelines_to_queue = self.clear(id);
405        if let Some(shader) = self.shaders.remove(&id) {
406            self.import_path_shaders.remove(shader.import_path());
407        }
408
409        pipelines_to_queue
410    }
411}
412
// A pipeline layout is uniquely determined by its bind group layouts (by id)
// and its push constant ranges.
type LayoutCacheKey = (Vec<BindGroupLayoutId>, Vec<PushConstantRange>);
/// Cache of pipeline layouts, deduplicated by [`LayoutCacheKey`].
#[derive(Default)]
struct LayoutCache {
    layouts: HashMap<LayoutCacheKey, ErasedPipelineLayout>,
}
418
impl LayoutCache {
    /// Returns the pipeline layout for the given bind group layouts and push
    /// constant ranges, creating and caching it on first use.
    fn get(
        &mut self,
        render_device: &RenderDevice,
        bind_group_layouts: &[BindGroupLayout],
        push_constant_ranges: Vec<PushConstantRange>,
    ) -> ErasedPipelineLayout {
        // Layouts are keyed by bind group layout *ids*, not the layouts
        // themselves.
        let bind_group_ids = bind_group_layouts.iter().map(|l| l.id()).collect();
        self.layouts
            .entry((bind_group_ids, push_constant_ranges))
            .or_insert_with_key(|(_, push_constant_ranges)| {
                let bind_group_layouts = bind_group_layouts
                    .iter()
                    .map(|l| l.value())
                    .collect::<Vec<_>>();
                ErasedPipelineLayout::new(render_device.create_pipeline_layout(
                    &PipelineLayoutDescriptor {
                        bind_group_layouts: &bind_group_layouts,
                        push_constant_ranges,
                        ..default()
                    },
                ))
            })
            .clone()
    }
}
445
/// Cache for render and compute pipelines.
///
/// The cache stores existing render and compute pipelines allocated on the GPU, as well as
/// pending creation. Pipelines inserted into the cache are identified by a unique ID, which
/// can be used to retrieve the actual GPU object once it's ready. The creation of the GPU
/// pipeline object is deferred to the [`RenderSet::Render`] step, just before the render
/// graph starts being processed, as this requires access to the GPU.
///
/// Note that the cache does not perform automatic deduplication of identical pipelines. It is
/// up to the user not to insert the same pipeline twice to avoid wasting GPU resources.
///
/// [`RenderSet::Render`]: crate::RenderSet::Render
#[derive(Resource)]
pub struct PipelineCache {
    // Shared with pipeline compilation tasks, hence `Arc<Mutex<...>>`.
    layout_cache: Arc<Mutex<LayoutCache>>,
    // Shared with pipeline compilation tasks, hence `Arc<Mutex<...>>`.
    shader_cache: Arc<Mutex<ShaderCache>>,
    device: RenderDevice,
    // All pipelines ever queued, indexed by `CachedPipelineId`.
    pipelines: Vec<CachedPipeline>,
    // Ids of pipelines awaiting (re-)processing.
    waiting_pipelines: HashSet<CachedPipelineId>,
    // Pipelines queued since the last processing pass; behind a `Mutex` so the
    // `queue_*_pipeline` methods can take `&self`.
    new_pipelines: Mutex<Vec<CachedPipeline>>,
    /// If `true`, disables asynchronous pipeline compilation.
    /// This has no effect on MacOS, wasm, or without the `multi_threaded` feature.
    synchronous_pipeline_compilation: bool,
}
470
471impl PipelineCache {
    /// Returns an iterator over the pipelines in the pipeline cache.
    ///
    /// Yields every cached pipeline regardless of its creation state.
    pub fn pipelines(&self) -> impl Iterator<Item = &CachedPipeline> {
        self.pipelines.iter()
    }
476
    /// Returns an iterator of the IDs of all currently waiting pipelines.
    pub fn waiting_pipelines(&self) -> impl Iterator<Item = CachedPipelineId> + '_ {
        self.waiting_pipelines.iter().copied()
    }
481
    /// Create a new pipeline cache associated with the given render device.
    pub fn new(
        device: RenderDevice,
        render_adapter: RenderAdapter,
        synchronous_pipeline_compilation: bool,
    ) -> Self {
        Self {
            // The shader cache needs the device's features and the adapter's
            // capabilities to configure its composer.
            shader_cache: Arc::new(Mutex::new(ShaderCache::new(&device, &render_adapter))),
            device,
            layout_cache: default(),
            waiting_pipelines: default(),
            new_pipelines: default(),
            pipelines: default(),
            synchronous_pipeline_compilation,
        }
    }
498
    /// Get the state of a cached render pipeline.
    ///
    /// See [`PipelineCache::queue_render_pipeline()`].
    ///
    /// # Panics
    ///
    /// Panics if `id` does not refer to a pipeline in this cache (out of bounds).
    #[inline]
    pub fn get_render_pipeline_state(&self, id: CachedRenderPipelineId) -> &CachedPipelineState {
        &self.pipelines[id.0].state
    }
506
    /// Get the state of a cached compute pipeline.
    ///
    /// See [`PipelineCache::queue_compute_pipeline()`].
    ///
    /// # Panics
    ///
    /// Panics if `id` does not refer to a pipeline in this cache (out of bounds).
    #[inline]
    pub fn get_compute_pipeline_state(&self, id: CachedComputePipelineId) -> &CachedPipelineState {
        &self.pipelines[id.0].state
    }
514
    /// Get the render pipeline descriptor a cached render pipeline was inserted from.
    ///
    /// See [`PipelineCache::queue_render_pipeline()`].
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of bounds, or if it actually refers to a compute
    /// pipeline (render pipeline ids are only produced by
    /// [`PipelineCache::queue_render_pipeline()`]).
    #[inline]
    pub fn get_render_pipeline_descriptor(
        &self,
        id: CachedRenderPipelineId,
    ) -> &RenderPipelineDescriptor {
        match &self.pipelines[id.0].descriptor {
            PipelineDescriptor::RenderPipelineDescriptor(descriptor) => descriptor,
            PipelineDescriptor::ComputePipelineDescriptor(_) => unreachable!(),
        }
    }
528
    /// Get the compute pipeline descriptor a cached compute pipeline was inserted from.
    ///
    /// See [`PipelineCache::queue_compute_pipeline()`].
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of bounds, or if it actually refers to a render
    /// pipeline (compute pipeline ids are only produced by
    /// [`PipelineCache::queue_compute_pipeline()`]).
    #[inline]
    pub fn get_compute_pipeline_descriptor(
        &self,
        id: CachedComputePipelineId,
    ) -> &ComputePipelineDescriptor {
        match &self.pipelines[id.0].descriptor {
            PipelineDescriptor::RenderPipelineDescriptor(_) => unreachable!(),
            PipelineDescriptor::ComputePipelineDescriptor(descriptor) => descriptor,
        }
    }
542
543    /// Try to retrieve a render pipeline GPU object from a cached ID.
544    ///
545    /// # Returns
546    ///
547    /// This method returns a successfully created render pipeline if any, or `None` if the pipeline
548    /// was not created yet or if there was an error during creation. You can check the actual creation
549    /// state with [`PipelineCache::get_render_pipeline_state()`].
550    #[inline]
551    pub fn get_render_pipeline(&self, id: CachedRenderPipelineId) -> Option<&RenderPipeline> {
552        if let CachedPipelineState::Ok(Pipeline::RenderPipeline(pipeline)) =
553            &self.pipelines[id.0].state
554        {
555            Some(pipeline)
556        } else {
557            None
558        }
559    }
560
    /// Wait for a render pipeline to finish compiling.
    #[inline]
    pub fn block_on_render_pipeline(&mut self, id: CachedRenderPipelineId) {
        // If `id` is beyond the processed pipelines it must still be sitting in
        // the `new_pipelines` queue, so process the queue first to make the
        // index addressable.
        if self.pipelines.len() <= id.0 {
            self.process_queue();
        }

        let state = &mut self.pipelines[id.0].state;
        if let CachedPipelineState::Creating(task) = state {
            // Block the current thread until the compilation task resolves,
            // then record its outcome.
            *state = match bevy_tasks::block_on(task) {
                Ok(p) => CachedPipelineState::Ok(p),
                Err(e) => CachedPipelineState::Err(e),
            };
        }
    }
576
577    /// Try to retrieve a compute pipeline GPU object from a cached ID.
578    ///
579    /// # Returns
580    ///
581    /// This method returns a successfully created compute pipeline if any, or `None` if the pipeline
582    /// was not created yet or if there was an error during creation. You can check the actual creation
583    /// state with [`PipelineCache::get_compute_pipeline_state()`].
584    #[inline]
585    pub fn get_compute_pipeline(&self, id: CachedComputePipelineId) -> Option<&ComputePipeline> {
586        if let CachedPipelineState::Ok(Pipeline::ComputePipeline(pipeline)) =
587            &self.pipelines[id.0].state
588        {
589            Some(pipeline)
590        } else {
591            None
592        }
593    }
594
    /// Insert a render pipeline into the cache, and queue its creation.
    ///
    /// The pipeline is always inserted and queued for creation. There is no attempt to deduplicate it with
    /// an already cached pipeline.
    ///
    /// # Returns
    ///
    /// This method returns the unique render shader ID of the cached pipeline, which can be used to query
    /// the caching state with [`get_render_pipeline_state()`] and to retrieve the created GPU pipeline once
    /// it's ready with [`get_render_pipeline()`].
    ///
    /// [`get_render_pipeline_state()`]: PipelineCache::get_render_pipeline_state
    /// [`get_render_pipeline()`]: PipelineCache::get_render_pipeline
    pub fn queue_render_pipeline(
        &self,
        descriptor: RenderPipelineDescriptor,
    ) -> CachedRenderPipelineId {
        // Recover the guard even if a previous holder panicked; the queue
        // itself remains in a usable state.
        let mut new_pipelines = self
            .new_pipelines
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        // The eventual index once `new_pipelines` is drained into `pipelines`:
        // existing entries plus those already pending.
        let id = CachedRenderPipelineId(self.pipelines.len() + new_pipelines.len());
        new_pipelines.push(CachedPipeline {
            descriptor: PipelineDescriptor::RenderPipelineDescriptor(Box::new(descriptor)),
            state: CachedPipelineState::Queued,
        });
        id
    }
623
    /// Insert a compute pipeline into the cache, and queue its creation.
    ///
    /// The pipeline is always inserted and queued for creation. There is no attempt to deduplicate it with
    /// an already cached pipeline.
    ///
    /// # Returns
    ///
    /// This method returns the unique compute shader ID of the cached pipeline, which can be used to query
    /// the caching state with [`get_compute_pipeline_state()`] and to retrieve the created GPU pipeline once
    /// it's ready with [`get_compute_pipeline()`].
    ///
    /// [`get_compute_pipeline_state()`]: PipelineCache::get_compute_pipeline_state
    /// [`get_compute_pipeline()`]: PipelineCache::get_compute_pipeline
    pub fn queue_compute_pipeline(
        &self,
        descriptor: ComputePipelineDescriptor,
    ) -> CachedComputePipelineId {
        // Recover the guard even if a previous holder panicked; the queue
        // itself remains in a usable state.
        let mut new_pipelines = self
            .new_pipelines
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        // The eventual index once `new_pipelines` is drained into `pipelines`:
        // existing entries plus those already pending.
        let id = CachedComputePipelineId(self.pipelines.len() + new_pipelines.len());
        new_pipelines.push(CachedPipeline {
            descriptor: PipelineDescriptor::ComputePipelineDescriptor(Box::new(descriptor)),
            state: CachedPipelineState::Queued,
        });
        id
    }
652
653    fn set_shader(&mut self, id: AssetId<Shader>, shader: &Shader) {
654        let mut shader_cache = self.shader_cache.lock().unwrap();
655        let pipelines_to_queue = shader_cache.set_shader(id, shader.clone());
656        for cached_pipeline in pipelines_to_queue {
657            self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
658            self.waiting_pipelines.insert(cached_pipeline);
659        }
660    }
661
662    fn remove_shader(&mut self, shader: AssetId<Shader>) {
663        let mut shader_cache = self.shader_cache.lock().unwrap();
664        let pipelines_to_queue = shader_cache.remove(shader);
665        for cached_pipeline in pipelines_to_queue {
666            self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
667            self.waiting_pipelines.insert(cached_pipeline);
668        }
669    }
670
    /// Kicks off creation of the render pipeline `id` from `descriptor`.
    ///
    /// Builds a (possibly asynchronous, depending on
    /// `synchronous_pipeline_compilation`) task that compiles the vertex and
    /// optional fragment shader modules, resolves the pipeline layout, and
    /// creates the GPU pipeline object. Returns the resulting state
    /// (`Creating` with a task, `Ok`, or `Err`) as produced by
    /// `create_pipeline_task`.
    fn start_create_render_pipeline(
        &mut self,
        id: CachedPipelineId,
        descriptor: RenderPipelineDescriptor,
    ) -> CachedPipelineState {
        // Clone the shared handles so the async block is 'static.
        let device = self.device.clone();
        let shader_cache = self.shader_cache.clone();
        let layout_cache = self.layout_cache.clone();
        create_pipeline_task(
            async move {
                let mut shader_cache = shader_cache.lock().unwrap();
                let mut layout_cache = layout_cache.lock().unwrap();

                // Compile (or fetch cached) shader modules; failures (e.g.
                // imports not yet loaded) abort the task with the error.
                let vertex_module = match shader_cache.get(
                    &device,
                    id,
                    descriptor.vertex.shader.id(),
                    &descriptor.vertex.shader_defs,
                ) {
                    Ok(module) => module,
                    Err(err) => return Err(err),
                };

                let fragment_module = match &descriptor.fragment {
                    Some(fragment) => {
                        match shader_cache.get(
                            &device,
                            id,
                            fragment.shader.id(),
                            &fragment.shader_defs,
                        ) {
                            Ok(module) => Some(module),
                            Err(err) => return Err(err),
                        }
                    }
                    None => None,
                };

                // An empty layout/push-constant set means "let wgpu infer the
                // layout" (None); otherwise fetch/create a cached layout.
                let layout =
                    if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
                        None
                    } else {
                        Some(layout_cache.get(
                            &device,
                            &descriptor.layout,
                            descriptor.push_constant_ranges.to_vec(),
                        ))
                    };

                // Release both locks before the (potentially slow) GPU call so
                // other pipeline tasks are not blocked on them.
                drop((shader_cache, layout_cache));

                let vertex_buffer_layouts = descriptor
                    .vertex
                    .buffers
                    .iter()
                    .map(|layout| RawVertexBufferLayout {
                        array_stride: layout.array_stride,
                        attributes: &layout.attributes,
                        step_mode: layout.step_mode,
                    })
                    .collect::<Vec<_>>();

                let fragment_data = descriptor.fragment.as_ref().map(|fragment| {
                    (
                        // Safe: `fragment_module` is Some whenever
                        // `descriptor.fragment` is Some (computed above).
                        fragment_module.unwrap(),
                        fragment.entry_point.deref(),
                        fragment.targets.as_slice(),
                    )
                });

                // TODO: Expose this somehow
                let compilation_options = PipelineCompilationOptions {
                    constants: &std::collections::HashMap::new(),
                    zero_initialize_workgroup_memory: false,
                };

                // Shadow `descriptor` with the raw wgpu descriptor built from it.
                let descriptor = RawRenderPipelineDescriptor {
                    multiview: None,
                    depth_stencil: descriptor.depth_stencil.clone(),
                    label: descriptor.label.as_deref(),
                    layout: layout.as_deref(),
                    multisample: descriptor.multisample,
                    primitive: descriptor.primitive,
                    vertex: RawVertexState {
                        buffers: &vertex_buffer_layouts,
                        entry_point: descriptor.vertex.entry_point.deref(),
                        module: &vertex_module,
                        // TODO: Should this be the same as the fragment compilation options?
                        compilation_options: compilation_options.clone(),
                    },
                    fragment: fragment_data
                        .as_ref()
                        .map(|(module, entry_point, targets)| RawFragmentState {
                            entry_point,
                            module,
                            targets,
                            // TODO: Should this be the same as the vertex compilation options?
                            compilation_options,
                        }),
                };

                Ok(Pipeline::RenderPipeline(
                    device.create_render_pipeline(&descriptor),
                ))
            },
            self.synchronous_pipeline_compilation,
        )
    }
779
780    fn start_create_compute_pipeline(
781        &mut self,
782        id: CachedPipelineId,
783        descriptor: ComputePipelineDescriptor,
784    ) -> CachedPipelineState {
785        let device = self.device.clone();
786        let shader_cache = self.shader_cache.clone();
787        let layout_cache = self.layout_cache.clone();
788        create_pipeline_task(
789            async move {
790                let mut shader_cache = shader_cache.lock().unwrap();
791                let mut layout_cache = layout_cache.lock().unwrap();
792
793                let compute_module = match shader_cache.get(
794                    &device,
795                    id,
796                    descriptor.shader.id(),
797                    &descriptor.shader_defs,
798                ) {
799                    Ok(module) => module,
800                    Err(err) => return Err(err),
801                };
802
803                let layout =
804                    if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
805                        None
806                    } else {
807                        Some(layout_cache.get(
808                            &device,
809                            &descriptor.layout,
810                            descriptor.push_constant_ranges.to_vec(),
811                        ))
812                    };
813
814                drop((shader_cache, layout_cache));
815
816                let descriptor = RawComputePipelineDescriptor {
817                    label: descriptor.label.as_deref(),
818                    layout: layout.as_deref(),
819                    module: &compute_module,
820                    entry_point: &descriptor.entry_point,
821                    // TODO: Expose this somehow
822                    compilation_options: PipelineCompilationOptions {
823                        constants: &std::collections::HashMap::new(),
824                        zero_initialize_workgroup_memory: false,
825                    },
826                };
827
828                Ok(Pipeline::ComputePipeline(
829                    device.create_compute_pipeline(&descriptor),
830                ))
831            },
832            self.synchronous_pipeline_compilation,
833        )
834    }
835
836    /// Process the pipeline queue and create all pending pipelines if possible.
837    ///
838    /// This is generally called automatically during the [`RenderSet::Render`] step, but can
839    /// be called manually to force creation at a different time.
840    ///
841    /// [`RenderSet::Render`]: crate::RenderSet::Render
842    pub fn process_queue(&mut self) {
843        let mut waiting_pipelines = mem::take(&mut self.waiting_pipelines);
844        let mut pipelines = mem::take(&mut self.pipelines);
845
846        {
847            let mut new_pipelines = self
848                .new_pipelines
849                .lock()
850                .unwrap_or_else(PoisonError::into_inner);
851            for new_pipeline in new_pipelines.drain(..) {
852                let id = pipelines.len();
853                pipelines.push(new_pipeline);
854                waiting_pipelines.insert(id);
855            }
856        }
857
858        for id in waiting_pipelines {
859            self.process_pipeline(&mut pipelines[id], id);
860        }
861
862        self.pipelines = pipelines;
863    }
864
    /// Advance the creation state machine of a single cached pipeline.
    ///
    /// Any arm that does not `return` early falls through to the bottom of the
    /// function, which re-inserts `id` into `waiting_pipelines` so the
    /// pipeline is processed again on the next queue run.
    fn process_pipeline(&mut self, cached_pipeline: &mut CachedPipeline, id: usize) {
        match &mut cached_pipeline.state {
            // Not started yet: begin (possibly asynchronous) creation, then
            // fall through so the pipeline is polled again next frame.
            CachedPipelineState::Queued => {
                cached_pipeline.state = match &cached_pipeline.descriptor {
                    PipelineDescriptor::RenderPipelineDescriptor(descriptor) => {
                        self.start_create_render_pipeline(id, *descriptor.clone())
                    }
                    PipelineDescriptor::ComputePipelineDescriptor(descriptor) => {
                        self.start_create_compute_pipeline(id, *descriptor.clone())
                    }
                };
            }

            // Creation in flight: poll the task without blocking.
            CachedPipelineState::Creating(ref mut task) => {
                match bevy_utils::futures::check_ready(task) {
                    Some(Ok(pipeline)) => {
                        // Finished: record the pipeline and stop tracking it.
                        cached_pipeline.state = CachedPipelineState::Ok(pipeline);
                        return;
                    }
                    // On failure, fall through so the `Err` arm below decides
                    // on the next run whether the error is retryable.
                    Some(Err(err)) => cached_pipeline.state = CachedPipelineState::Err(err),
                    // Still pending: fall through and poll again next frame.
                    _ => (),
                }
            }

            CachedPipelineState::Err(err) => match err {
                // Retry
                PipelineCacheError::ShaderNotLoaded(_)
                | PipelineCacheError::ShaderImportNotYetAvailable => {
                    cached_pipeline.state = CachedPipelineState::Queued;
                }

                // Shader could not be processed ... retrying won't help
                PipelineCacheError::ProcessShaderError(err) => {
                    let error_detail =
                        err.emit_to_string(&self.shader_cache.lock().unwrap().composer);
                    error!("failed to process shader:\n{}", error_detail);
                    return;
                }
                PipelineCacheError::CreateShaderModule(description) => {
                    error!("failed to create shader module: {}", description);
                    return;
                }
            },

            // Already created: nothing to do.
            CachedPipelineState::Ok(_) => return,
        }

        // Retry
        self.waiting_pipelines.insert(id);
    }
915
    /// ECS system wrapper that calls [`PipelineCache::process_queue`] on the
    /// cache resource.
    pub(crate) fn process_pipeline_queue_system(mut cache: ResMut<Self>) {
        cache.process_queue();
    }
919
920    pub(crate) fn extract_shaders(
921        mut cache: ResMut<Self>,
922        shaders: Extract<Res<Assets<Shader>>>,
923        mut events: Extract<EventReader<AssetEvent<Shader>>>,
924    ) {
925        for event in events.read() {
926            #[allow(clippy::match_same_arms)]
927            match event {
928                // PERF: Instead of blocking waiting for the shader cache lock, try again next frame if the lock is currently held
929                AssetEvent::Added { id } | AssetEvent::Modified { id } => {
930                    if let Some(shader) = shaders.get(*id) {
931                        cache.set_shader(*id, shader);
932                    }
933                }
934                AssetEvent::Removed { id } => cache.remove_shader(*id),
935                AssetEvent::Unused { .. } => {}
936                AssetEvent::LoadedWithDependencies { .. } => {
937                    // TODO: handle this
938                }
939            }
940        }
941    }
942}
943
/// Start creating a pipeline: either block on the future (`sync`) or spawn it
/// on the async compute task pool and return a `Creating` state to poll later.
#[cfg(all(
    not(target_arch = "wasm32"),
    not(target_os = "macos"),
    feature = "multi_threaded"
))]
fn create_pipeline_task(
    task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
    sync: bool,
) -> CachedPipelineState {
    if sync {
        // Synchronous compilation: drive the future to completion right here.
        match futures_lite::future::block_on(task) {
            Ok(pipeline) => CachedPipelineState::Ok(pipeline),
            Err(err) => CachedPipelineState::Err(err),
        }
    } else {
        // Asynchronous compilation: the caller polls the spawned task via
        // `CachedPipelineState::Creating`.
        CachedPipelineState::Creating(bevy_tasks::AsyncComputeTaskPool::get().spawn(task))
    }
}
962
963#[cfg(any(
964    target_arch = "wasm32",
965    target_os = "macos",
966    not(feature = "multi_threaded")
967))]
968fn create_pipeline_task(
969    task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
970    _sync: bool,
971) -> CachedPipelineState {
972    match futures_lite::future::block_on(task) {
973        Ok(pipeline) => CachedPipelineState::Ok(pipeline),
974        Err(err) => CachedPipelineState::Err(err),
975    }
976}
977
/// Type of error returned by a [`PipelineCache`] when the creation of a GPU pipeline object failed.
#[derive(Error, Debug)]
pub enum PipelineCacheError {
    /// The shader asset referenced by the pipeline descriptor is not loaded
    /// yet; the pipeline is re-queued and retried on a later frame.
    #[error(
        "Pipeline could not be compiled because the following shader could not be loaded: {0:?}"
    )]
    ShaderNotLoaded(AssetId<Shader>),
    /// `naga_oil` failed to compose/process the shader source; permanent, the
    /// pipeline is not retried.
    #[error(transparent)]
    ProcessShaderError(#[from] naga_oil::compose::ComposerError),
    /// A shader import is not available yet; retried like
    /// [`Self::ShaderNotLoaded`].
    #[error("Shader import not yet available.")]
    ShaderImportNotYetAvailable,
    /// The shader module could not be created; permanent, the pipeline is not
    /// retried.
    #[error("Could not create shader module: {0}")]
    CreateShaderModule(String),
}
992
993// TODO: This needs to be kept up to date with the capabilities in the `create_validator` function in wgpu-core
994// https://github.com/gfx-rs/wgpu/blob/trunk/wgpu-core/src/device/mod.rs#L449
995// We use a modified version of the `create_validator` function because `naga_oil`'s composer stores the capabilities
996// and subgroup shader stages instead of a `Validator`.
997// We also can't use that function because `wgpu-core` isn't included in WebGPU builds.
998/// Get the device capabilities and subgroup support for use in `naga_oil`.
999fn get_capabilities(
1000    features: Features,
1001    downlevel: DownlevelFlags,
1002) -> (Capabilities, naga::valid::ShaderStages) {
1003    let mut capabilities = Capabilities::empty();
1004    capabilities.set(
1005        Capabilities::PUSH_CONSTANT,
1006        features.contains(Features::PUSH_CONSTANTS),
1007    );
1008    capabilities.set(
1009        Capabilities::FLOAT64,
1010        features.contains(Features::SHADER_F64),
1011    );
1012    capabilities.set(
1013        Capabilities::PRIMITIVE_INDEX,
1014        features.contains(Features::SHADER_PRIMITIVE_INDEX),
1015    );
1016    capabilities.set(
1017        Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
1018        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
1019    );
1020    capabilities.set(
1021        Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
1022        features.contains(Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
1023    );
1024    // TODO: This needs a proper wgpu feature
1025    capabilities.set(
1026        Capabilities::SAMPLER_NON_UNIFORM_INDEXING,
1027        features.contains(Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
1028    );
1029    capabilities.set(
1030        Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
1031        features.contains(Features::TEXTURE_FORMAT_16BIT_NORM),
1032    );
1033    capabilities.set(
1034        Capabilities::MULTIVIEW,
1035        features.contains(Features::MULTIVIEW),
1036    );
1037    capabilities.set(
1038        Capabilities::EARLY_DEPTH_TEST,
1039        features.contains(Features::SHADER_EARLY_DEPTH_TEST),
1040    );
1041    capabilities.set(
1042        Capabilities::SHADER_INT64,
1043        features.contains(Features::SHADER_INT64),
1044    );
1045    capabilities.set(
1046        Capabilities::MULTISAMPLED_SHADING,
1047        downlevel.contains(DownlevelFlags::MULTISAMPLED_SHADING),
1048    );
1049    capabilities.set(
1050        Capabilities::DUAL_SOURCE_BLENDING,
1051        features.contains(Features::DUAL_SOURCE_BLENDING),
1052    );
1053    capabilities.set(
1054        Capabilities::CUBE_ARRAY_TEXTURES,
1055        downlevel.contains(DownlevelFlags::CUBE_ARRAY_TEXTURES),
1056    );
1057    capabilities.set(
1058        Capabilities::SUBGROUP,
1059        features.intersects(Features::SUBGROUP | Features::SUBGROUP_VERTEX),
1060    );
1061    capabilities.set(
1062        Capabilities::SUBGROUP_BARRIER,
1063        features.intersects(Features::SUBGROUP_BARRIER),
1064    );
1065
1066    let mut subgroup_stages = naga::valid::ShaderStages::empty();
1067    subgroup_stages.set(
1068        naga::valid::ShaderStages::COMPUTE | naga::valid::ShaderStages::FRAGMENT,
1069        features.contains(Features::SUBGROUP),
1070    );
1071    subgroup_stages.set(
1072        naga::valid::ShaderStages::VERTEX,
1073        features.contains(Features::SUBGROUP_VERTEX),
1074    );
1075
1076    (capabilities, subgroup_stages)
1077}