bevy_render/render_resource/shader.rs

1use super::ShaderDefVal;
2use crate::define_atomic_id;
3use bevy_asset::{io::Reader, Asset, AssetLoader, AssetPath, Handle, LoadContext};
4use bevy_reflect::TypePath;
5use bevy_utils::tracing::error;
6use futures_lite::AsyncReadExt;
7use std::{borrow::Cow, marker::Copy};
8use thiserror::Error;
9
// Declares the `ShaderId` type using this crate's `define_atomic_id!` macro (presumably a
// unique id drawn from an atomic counter — see the macro's definition to confirm).
define_atomic_id!(ShaderId);
11
/// Errors reported by naga while parsing or validating a shader.
#[derive(Error, Debug)]
pub enum ShaderReflectError {
    /// naga's WGSL front end failed to parse the source.
    #[error(transparent)]
    WgslParse(#[from] naga::front::wgsl::ParseError),
    /// naga's GLSL front end reported one or more errors.
    /// Only present with the `shader_format_glsl` feature.
    #[cfg(feature = "shader_format_glsl")]
    #[error("GLSL Parse Error: {0:?}")]
    GlslParse(Vec<naga::front::glsl::Error>),
    /// naga's SPIR-V front end failed to parse the module.
    /// Only present with the `shader_format_spirv` feature.
    #[cfg(feature = "shader_format_spirv")]
    #[error(transparent)]
    SpirVParse(#[from] naga::front::spv::Error),
    /// naga's validator rejected the module; the error is wrapped in naga's `WithSpan`.
    #[error(transparent)]
    Validation(#[from] naga::WithSpan<naga::valid::ValidationError>),
}
/// A shader, as defined by its [`Source`] (which, for GLSL, also carries the
/// [`ShaderStage`](naga::ShaderStage) it targets).
/// This is an "unprocessed" shader. It can contain preprocessor directives.
#[derive(Asset, TypePath, Debug, Clone)]
pub struct Shader {
    /// The path this shader was created with; passed to the composer as its file path.
    pub path: String,
    /// The shader's code: WGSL or GLSL text, or SPIR-V bytes.
    pub source: Source,
    /// The name other shaders use to import this one: the import path reported by the
    /// preprocessor, or this shader's `path` when the source declares none.
    pub import_path: ShaderImport,
    /// The modules this shader imports, as reported by the preprocessor.
    pub imports: Vec<ShaderImport>,
    /// Extra imports not specified in the source string.
    pub additional_imports: Vec<naga_oil::compose::ImportDefinition>,
    /// Any shader defs that will be included when this module is used.
    pub shader_defs: Vec<ShaderDefVal>,
    /// Strong handles to our dependencies, stored to stop them from being
    /// immediately dropped if we are the only user.
    pub file_dependencies: Vec<Handle<Shader>>,
}
41
42impl Shader {
43    fn preprocess(source: &str, path: &str) -> (ShaderImport, Vec<ShaderImport>) {
44        let (import_path, imports, _) = naga_oil::compose::get_preprocessor_data(source);
45
46        let import_path = import_path
47            .map(ShaderImport::Custom)
48            .unwrap_or_else(|| ShaderImport::AssetPath(path.to_owned()));
49
50        let imports = imports
51            .into_iter()
52            .map(|import| {
53                if import.import.starts_with('\"') {
54                    let import = import
55                        .import
56                        .chars()
57                        .skip(1)
58                        .take_while(|c| *c != '\"')
59                        .collect();
60                    ShaderImport::AssetPath(import)
61                } else {
62                    ShaderImport::Custom(import.import)
63                }
64            })
65            .collect();
66
67        (import_path, imports)
68    }
69
70    pub fn from_wgsl(source: impl Into<Cow<'static, str>>, path: impl Into<String>) -> Shader {
71        let source = source.into();
72        let path = path.into();
73        let (import_path, imports) = Shader::preprocess(&source, &path);
74        Shader {
75            path,
76            imports,
77            import_path,
78            source: Source::Wgsl(source),
79            additional_imports: Default::default(),
80            shader_defs: Default::default(),
81            file_dependencies: Default::default(),
82        }
83    }
84
85    pub fn from_wgsl_with_defs(
86        source: impl Into<Cow<'static, str>>,
87        path: impl Into<String>,
88        shader_defs: Vec<ShaderDefVal>,
89    ) -> Shader {
90        Self {
91            shader_defs,
92            ..Self::from_wgsl(source, path)
93        }
94    }
95
96    pub fn from_glsl(
97        source: impl Into<Cow<'static, str>>,
98        stage: naga::ShaderStage,
99        path: impl Into<String>,
100    ) -> Shader {
101        let source = source.into();
102        let path = path.into();
103        let (import_path, imports) = Shader::preprocess(&source, &path);
104        Shader {
105            path,
106            imports,
107            import_path,
108            source: Source::Glsl(source, stage),
109            additional_imports: Default::default(),
110            shader_defs: Default::default(),
111            file_dependencies: Default::default(),
112        }
113    }
114
115    pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>, path: impl Into<String>) -> Shader {
116        let path = path.into();
117        Shader {
118            path: path.clone(),
119            imports: Vec::new(),
120            import_path: ShaderImport::AssetPath(path),
121            source: Source::SpirV(source.into()),
122            additional_imports: Default::default(),
123            shader_defs: Default::default(),
124            file_dependencies: Default::default(),
125        }
126    }
127
128    pub fn set_import_path<P: Into<String>>(&mut self, import_path: P) {
129        self.import_path = ShaderImport::Custom(import_path.into());
130    }
131
132    #[must_use]
133    pub fn with_import_path<P: Into<String>>(mut self, import_path: P) -> Self {
134        self.set_import_path(import_path);
135        self
136    }
137
138    #[inline]
139    pub fn import_path(&self) -> &ShaderImport {
140        &self.import_path
141    }
142
143    pub fn imports(&self) -> impl ExactSizeIterator<Item = &ShaderImport> {
144        self.imports.iter()
145    }
146}
147
148impl<'a> From<&'a Shader> for naga_oil::compose::ComposableModuleDescriptor<'a> {
149    fn from(shader: &'a Shader) -> Self {
150        let shader_defs = shader
151            .shader_defs
152            .iter()
153            .map(|def| match def {
154                ShaderDefVal::Bool(name, b) => {
155                    (name.clone(), naga_oil::compose::ShaderDefValue::Bool(*b))
156                }
157                ShaderDefVal::Int(name, i) => {
158                    (name.clone(), naga_oil::compose::ShaderDefValue::Int(*i))
159                }
160                ShaderDefVal::UInt(name, i) => {
161                    (name.clone(), naga_oil::compose::ShaderDefValue::UInt(*i))
162                }
163            })
164            .collect();
165
166        let as_name = match &shader.import_path {
167            ShaderImport::AssetPath(asset_path) => Some(format!("\"{asset_path}\"")),
168            ShaderImport::Custom(_) => None,
169        };
170
171        naga_oil::compose::ComposableModuleDescriptor {
172            source: shader.source.as_str(),
173            file_path: &shader.path,
174            language: (&shader.source).into(),
175            additional_imports: &shader.additional_imports,
176            shader_defs,
177            as_name,
178        }
179    }
180}
181
182impl<'a> From<&'a Shader> for naga_oil::compose::NagaModuleDescriptor<'a> {
183    fn from(shader: &'a Shader) -> Self {
184        naga_oil::compose::NagaModuleDescriptor {
185            source: shader.source.as_str(),
186            file_path: &shader.path,
187            shader_type: (&shader.source).into(),
188            ..Default::default()
189        }
190    }
191}
192
/// The code of a [`Shader`], in one of the supported formats.
#[derive(Debug, Clone)]
pub enum Source {
    /// WGSL source text.
    Wgsl(Cow<'static, str>),
    /// GLSL source text, together with the pipeline stage it is written for.
    Glsl(Cow<'static, str>, naga::ShaderStage),
    /// Raw SPIR-V bytes. Composing SPIR-V is not implemented (`Source::as_str` panics).
    SpirV(Cow<'static, [u8]>),
    // TODO: consider the following
    // PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
    // NagaModule(Module) ... Module impls Serialize/Deserialize
}
202
203impl Source {
204    pub fn as_str(&self) -> &str {
205        match self {
206            Source::Wgsl(s) | Source::Glsl(s, _) => s,
207            Source::SpirV(_) => panic!("spirv not yet implemented"),
208        }
209    }
210}
211
212impl From<&Source> for naga_oil::compose::ShaderLanguage {
213    fn from(value: &Source) -> Self {
214        match value {
215            Source::Wgsl(_) => naga_oil::compose::ShaderLanguage::Wgsl,
216            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
217            Source::Glsl(_, _) => naga_oil::compose::ShaderLanguage::Glsl,
218            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
219            Source::Glsl(_, _) => panic!(
220                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
221            ),
222            Source::SpirV(_) => panic!("spirv not yet implemented"),
223        }
224    }
225}
226
227impl From<&Source> for naga_oil::compose::ShaderType {
228    fn from(value: &Source) -> Self {
229        match value {
230            Source::Wgsl(_) => naga_oil::compose::ShaderType::Wgsl,
231            #[cfg(any(feature = "shader_format_glsl", target_arch = "wasm32"))]
232            Source::Glsl(_, shader_stage) => match shader_stage {
233                naga::ShaderStage::Vertex => naga_oil::compose::ShaderType::GlslVertex,
234                naga::ShaderStage::Fragment => naga_oil::compose::ShaderType::GlslFragment,
235                naga::ShaderStage::Compute => panic!("glsl compute not yet implemented"),
236            },
237            #[cfg(all(not(feature = "shader_format_glsl"), not(target_arch = "wasm32")))]
238            Source::Glsl(_, _) => panic!(
239                "GLSL is not supported in this configuration; use the feature `shader_format_glsl`"
240            ),
241            Source::SpirV(_) => panic!("spirv not yet implemented"),
242        }
243    }
244}
245
246#[derive(Default)]
247pub struct ShaderLoader;
248
249#[non_exhaustive]
250#[derive(Debug, Error)]
251pub enum ShaderLoaderError {
252    #[error("Could not load shader: {0}")]
253    Io(#[from] std::io::Error),
254    #[error("Could not parse shader: {0}")]
255    Parse(#[from] std::string::FromUtf8Error),
256}
257
258impl AssetLoader for ShaderLoader {
259    type Asset = Shader;
260    type Settings = ();
261    type Error = ShaderLoaderError;
262    async fn load<'a>(
263        &'a self,
264        reader: &'a mut Reader<'_>,
265        _settings: &'a Self::Settings,
266        load_context: &'a mut LoadContext<'_>,
267    ) -> Result<Shader, Self::Error> {
268        let ext = load_context.path().extension().unwrap().to_str().unwrap();
269        let path = load_context.asset_path().to_string();
270        // On windows, the path will inconsistently use \ or /.
271        // TODO: remove this once AssetPath forces cross-platform "slash" consistency. See #10511
272        let path = path.replace(std::path::MAIN_SEPARATOR, "/");
273        let mut bytes = Vec::new();
274        reader.read_to_end(&mut bytes).await?;
275        let mut shader = match ext {
276            "spv" => Shader::from_spirv(bytes, load_context.path().to_string_lossy()),
277            "wgsl" => Shader::from_wgsl(String::from_utf8(bytes)?, path),
278            "vert" => Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Vertex, path),
279            "frag" => {
280                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Fragment, path)
281            }
282            "comp" => {
283                Shader::from_glsl(String::from_utf8(bytes)?, naga::ShaderStage::Compute, path)
284            }
285            _ => panic!("unhandled extension: {ext}"),
286        };
287
288        // collect and store file dependencies
289        for import in &shader.imports {
290            if let ShaderImport::AssetPath(asset_path) = import {
291                shader.file_dependencies.push(load_context.load(asset_path));
292            }
293        }
294        Ok(shader)
295    }
296
297    fn extensions(&self) -> &[&str] {
298        &["spv", "wgsl", "vert", "frag", "comp"]
299    }
300}
301
/// How a shader module is identified when it is imported.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum ShaderImport {
    /// A module identified by its asset path (imports written as a quoted string).
    AssetPath(String),
    /// A module identified by a custom, non-path name.
    Custom(String),
}

impl ShaderImport {
    /// Returns this import's module name.
    ///
    /// Custom names are borrowed as-is. Asset paths are wrapped in double quotes, which
    /// requires a fresh allocation.
    pub fn module_name(&self) -> Cow<'_, String> {
        match self {
            Self::Custom(name) => Cow::Borrowed(name),
            Self::AssetPath(path) => Cow::Owned(format!("\"{path}\"")),
        }
    }
}
316
317/// A reference to a shader asset.
318pub enum ShaderRef {
319    /// Use the "default" shader for the current context.
320    Default,
321    /// A handle to a shader stored in the [`Assets<Shader>`](bevy_asset::Assets) resource
322    Handle(Handle<Shader>),
323    /// An asset path leading to a shader
324    Path(AssetPath<'static>),
325}
326
327impl From<Handle<Shader>> for ShaderRef {
328    fn from(handle: Handle<Shader>) -> Self {
329        Self::Handle(handle)
330    }
331}
332
333impl From<AssetPath<'static>> for ShaderRef {
334    fn from(path: AssetPath<'static>) -> Self {
335        Self::Path(path)
336    }
337}
338
339impl From<&'static str> for ShaderRef {
340    fn from(path: &'static str) -> Self {
341        Self::Path(AssetPath::from(path))
342    }
343}