// naga_oil/compose/mod.rs

1use indexmap::IndexMap;
2/// the compose module allows construction of shaders from modules (which are themselves shaders).
3///
4/// it does this by treating shaders as modules, and
5/// - building each module independently to naga IR
6/// - creating "header" files for each supported language, which are used to build dependent modules/shaders
7/// - making final shaders by combining the shader IR with the IR for imported modules
8///
9/// for multiple small shaders with large common imports, this can be faster than parsing the full source for each shader, and it allows for constructing shaders in a cleaner modular manner with better scope control.
10///
11/// ## imports
12///
/// shaders can be added to the composer as modules. this makes their types, constants, variables and functions available to modules/shaders that import them. note that importing a module will affect the final shader's global state if the module defines global variables with bindings.
14///
15/// modules must include a `#define_import_path` directive that names the module.
16///
17/// ```ignore
18/// #define_import_path my_module
19///
20/// fn my_func() -> f32 {
21///     return 1.0;
22/// }
23/// ```
24///
25/// shaders can then import the module with an `#import` directive (with an optional `as` name). at point of use, imported items must be qualified:
26///
27/// ```ignore
28/// #import my_module
29/// #import my_other_module as Mod2
30///
31/// fn main() -> f32 {
32///     let x = my_module::my_func();
33///     let y = Mod2::my_other_func();
34///     return x*y;
35/// }
36/// ```
37///
38/// or import a comma-separated list of individual items with a `#from` directive. at point of use, imported items must be prefixed with `::` :
39///
40/// ```ignore
41/// #from my_module import my_func, my_const
42///
43/// fn main() -> f32 {
44///     return ::my_func(::my_const);
45/// }
46/// ```
47///
48/// imports can be nested - modules may import other modules, but not recursively. when a new module is added, all its `#import`s must already have been added.
49/// the same module can be imported multiple times by different modules in the import tree.
50/// there is no overlap of namespaces, so the same function names (or type, constant, or variable names) may be used in different modules.
51///
/// note: when importing an item with the `#from` directive, the final shader will include the required dependencies (bindings, globals, consts, other functions) of the imported item, but will not include the rest of the imported module. it will, however, still include everything from any modules imported by the imported module. this is probably not desired in general and may be fixed in a future version. currently, for a more complete culling of unused dependencies, the `prune` module can be used.
53///
54/// ## overriding functions
55///
56/// virtual functions can be declared with the `virtual` keyword:
57/// ```ignore
58/// virtual fn point_light(world_position: vec3<f32>) -> vec3<f32> { ... }
59/// ```
60/// virtual functions defined in imported modules can then be overridden using the `override` keyword:
61///
62/// ```ignore
63/// #import bevy_pbr::lighting as Lighting
64///
65/// override fn Lighting::point_light (world_position: vec3<f32>) -> vec3<f32> {
66///     let original = Lighting::point_light(world_position);
67///     let quantized = vec3<u32>(original * 3.0);
68///     return vec3<f32>(quantized) / 3.0;
69/// }
70/// ```
71///
72/// override function definitions cause *all* calls to the original function in the entire shader scope to be replaced by calls to the new function, with the exception of calls within the override function itself.
73///
74/// the function signature of the override must match the base function.
75///
76/// overrides can be specified at any point in the final shader's import tree.
77///
78/// multiple overrides can be applied to the same function. for example, given :
79/// - a module `a` containing a function `f`,
80/// - a module `b` that imports `a`, and containing an `override a::f` function,
81/// - a module `c` that imports `a` and `b`, and containing an `override a::f` function,
82/// then b and c both specify an override for `a::f`.
/// the `override fn a::f` declared in module `b` may call `a::f` within its body.
/// the `override fn a::f` declared in module `c` may call `a::f` within its body, but the call will be redirected to `b::f`.
/// any other calls to `a::f` (within modules `a` or `b`, or anywhere else) will end up redirected to `c::f`.
86/// in this way a chain or stack of overrides can be applied.
87///
88/// different overrides of the same function can be specified in different import branches. the final stack will be ordered based on the first occurrence of the override in the import tree (using a depth first search).
89///
90/// note that imports into a module/shader are processed in order, but are processed before the body of the current shader/module regardless of where they occur in that module, so there is no way to import a module containing an override and inject a call into the override stack prior to that imported override. you can instead create two modules each containing an override and import them into a parent module/shader to order them as required.
91/// override functions can currently only be defined in wgsl.
92///
93/// if the `override_any` crate feature is enabled, then the `virtual` keyword is not required for the function being overridden.
94///
95/// ## languages
96///
/// modules can be written in GLSL or WGSL. shaders with entry points can be imported as modules (provided they have a `#define_import_path` directive). entry points are available to call from imported modules either via their name (for WGSL) or via `module::main` (for GLSL).
98///
/// final shaders can also be written in GLSL or WGSL. for GLSL, users must specify whether the shader is a vertex shader or fragment shader via the `ShaderType` argument (GLSL compute shaders are not supported).
100///
101/// ## preprocessing
102///
103/// when generating a final shader or adding a composable module, a set of `shader_def` string/value pairs must be provided. The value can be a bool (`ShaderDefValue::Bool`), an i32 (`ShaderDefValue::Int`) or a u32 (`ShaderDefValue::UInt`).
104///
105/// these allow conditional compilation of parts of modules and the final shader. conditional compilation is performed with `#if` / `#ifdef` / `#ifndef`, `#else` and `#endif` preprocessor directives:
106///
107/// ```ignore
108/// fn get_number() -> f32 {
109///     #ifdef BIG_NUMBER
110///         return 999.0;
111///     #else
112///         return 0.999;
113///     #endif
114/// }
115/// ```
/// the `#ifdef` directive matches when the def name exists in the provided set of shader defs (regardless of its value). the `#ifndef` directive is the reverse.
117///
118/// the `#if` directive requires a def name, an operator, and a value for comparison:
119/// - the def name must be a provided `shader_def` name.
120/// - the operator must be one of `==`, `!=`, `>=`, `>`, `<`, `<=`
/// - the value must be an integer literal if comparing to a `ShaderDefValue::Int`, or `true` or `false` if comparing to a `ShaderDefValue::Bool`.
122///
/// shader defs can also be used in the shader source with `#SHADER_DEF` or `#{SHADER_DEF}`, and will be replaced by their value.
124///
125/// ## error reporting
126///
127/// codespan reporting for errors is available using the error `emit_to_string` method. this requires validation to be enabled, which is true by default. `Composer::non_validating()` produces a non-validating composer that is not able to give accurate error reporting.
128///
129use naga::{
130    valid::{Capabilities, ShaderStages},
131    EntryPoint,
132};
133use regex::Regex;
134use std::collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
135use tracing::{debug, trace};
136
137use crate::{
138    compose::preprocess::{PreprocessOutput, PreprocessorMetaData},
139    derive::DerivedModule,
140    redirect::Redirector,
141};
142
143pub use self::error::{ComposerError, ComposerErrorInner, ErrSource};
144use self::preprocess::Preprocessor;
145
146pub mod comment_strip_iter;
147pub mod error;
148pub mod parse_imports;
149pub mod preprocess;
150mod test;
151pub mod tokenizer;
152
153#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
154pub enum ShaderLanguage {
155    #[default]
156    Wgsl,
157    #[cfg(feature = "glsl")]
158    Glsl,
159}
160
161#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)]
162pub enum ShaderType {
163    #[default]
164    Wgsl,
165    #[cfg(feature = "glsl")]
166    GlslVertex,
167    #[cfg(feature = "glsl")]
168    GlslFragment,
169}
170
171impl From<ShaderType> for ShaderLanguage {
172    fn from(ty: ShaderType) -> Self {
173        match ty {
174            ShaderType::Wgsl => ShaderLanguage::Wgsl,
175            #[cfg(feature = "glsl")]
176            ShaderType::GlslVertex | ShaderType::GlslFragment => ShaderLanguage::Glsl,
177        }
178    }
179}
180
181#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
182pub enum ShaderDefValue {
183    Bool(bool),
184    Int(i32),
185    UInt(u32),
186}
187
188impl Default for ShaderDefValue {
189    fn default() -> Self {
190        ShaderDefValue::Bool(true)
191    }
192}
193
194impl ShaderDefValue {
195    fn value_as_string(&self) -> String {
196        match self {
197            ShaderDefValue::Bool(val) => val.to_string(),
198            ShaderDefValue::Int(val) => val.to_string(),
199            ShaderDefValue::UInt(val) => val.to_string(),
200        }
201    }
202}
203
204#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
205pub struct OwnedShaderDefs(BTreeMap<String, ShaderDefValue>);
206
207#[derive(Clone, PartialEq, Eq, Hash, Debug)]
208struct ModuleKey(OwnedShaderDefs);
209
210impl ModuleKey {
211    fn from_members(key: &HashMap<String, ShaderDefValue>, universe: &[String]) -> Self {
212        let mut acc = OwnedShaderDefs::default();
213        for item in universe {
214            if let Some(value) = key.get(item) {
215                acc.0.insert(item.to_owned(), *value);
216            }
217        }
218        ModuleKey(acc)
219    }
220}
221
222// a module built with a specific set of shader_defs
223#[derive(Default, Debug)]
224pub struct ComposableModule {
225    // module decoration, prefixed to all items from this module in the final source
226    pub decorated_name: String,
227    // module names required as imports, optionally with a list of items to import
228    pub imports: Vec<ImportDefinition>,
229    // types exported
230    pub owned_types: HashSet<String>,
231    // constants exported
232    pub owned_constants: HashSet<String>,
233    // vars exported
234    pub owned_vars: HashSet<String>,
235    // functions exported
236    pub owned_functions: HashSet<String>,
237    // local functions that can be overridden
238    pub virtual_functions: HashSet<String>,
239    // overriding functions defined in this module
240    // target function -> Vec<replacement functions>
241    pub override_functions: IndexMap<String, Vec<String>>,
242    // naga module, built against headers for any imports
243    module_ir: naga::Module,
244    // headers in different shader languages, used for building modules/shaders that import this module
245    // headers contain types, constants, global vars and empty function definitions -
246    // just enough to convert source strings that want to import this module into naga IR
247    // headers: HashMap<ShaderLanguage, String>,
248    header_ir: naga::Module,
249    // character offset of the start of the owned module string
250    start_offset: usize,
251}
252
253// data used to build a ComposableModule
254#[derive(Debug)]
255pub struct ComposableModuleDefinition {
256    pub name: String,
257    // shader text (with auto bindings replaced - we do this on module add as we only want to do it once to avoid burning slots)
258    pub sanitized_source: String,
259    // language
260    pub language: ShaderLanguage,
261    // source path for error display
262    pub file_path: String,
263    // shader def values bound to this module
264    pub shader_defs: HashMap<String, ShaderDefValue>,
265    // list of shader_defs that can affect this module
266    effective_defs: Vec<String>,
267    // full list of possible imports (regardless of shader_def configuration)
268    all_imports: HashSet<String>,
269    // additional imports to add (as though they were included in the source after any other imports)
270    additional_imports: Vec<ImportDefinition>,
271    // built composable modules for a given set of shader defs
272    modules: HashMap<ModuleKey, ComposableModule>,
273    // used in spans when this module is included
274    module_index: usize,
275    // preprocessor meta data
276    // metadata: PreprocessorMetaData,
277}
278
279impl ComposableModuleDefinition {
280    fn get_module(
281        &self,
282        shader_defs: &HashMap<String, ShaderDefValue>,
283    ) -> Option<&ComposableModule> {
284        self.modules
285            .get(&ModuleKey::from_members(shader_defs, &self.effective_defs))
286    }
287
288    fn insert_module(
289        &mut self,
290        shader_defs: &HashMap<String, ShaderDefValue>,
291        module: ComposableModule,
292    ) -> &ComposableModule {
293        match self
294            .modules
295            .entry(ModuleKey::from_members(shader_defs, &self.effective_defs))
296        {
297            Entry::Occupied(_) => panic!("entry already populated"),
298            Entry::Vacant(v) => v.insert(module),
299        }
300    }
301}
302
303#[derive(Debug, Clone, Default, PartialEq, Eq)]
304pub struct ImportDefinition {
305    pub import: String,
306    pub items: Vec<String>,
307}
308
309#[derive(Debug, Clone)]
310pub struct ImportDefWithOffset {
311    definition: ImportDefinition,
312    offset: usize,
313}
314
315/// module composer.
316/// stores any modules that can be imported into a shader
317/// and builds the final shader
318#[derive(Debug)]
319pub struct Composer {
320    pub validate: bool,
321    pub module_sets: HashMap<String, ComposableModuleDefinition>,
322    pub module_index: HashMap<usize, String>,
323    pub capabilities: naga::valid::Capabilities,
324    /// The shader stages that the subgroup operations are valid for.
325    /// Used when creating a validator for the module.
326    /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515
327    /// for how to set this for proper subgroup ops support.
328    pub subgroup_stages: ShaderStages,
329    preprocessor: Preprocessor,
330    check_decoration_regex: Regex,
331    undecorate_regex: Regex,
332    virtual_fn_regex: Regex,
333    override_fn_regex: Regex,
334    undecorate_override_regex: Regex,
335    auto_binding_regex: Regex,
336    auto_binding_index: u32,
337}
338
339// shift for module index
340// 21 gives
341//   max size for shader of 2m characters
342//   max 2048 modules
343const SPAN_SHIFT: usize = 21;
344
345impl Default for Composer {
346    fn default() -> Self {
347        Self {
348            validate: true,
349            capabilities: Default::default(),
350            subgroup_stages: ShaderStages::empty(),
351            module_sets: Default::default(),
352            module_index: Default::default(),
353            preprocessor: Preprocessor::default(),
354            check_decoration_regex: Regex::new(
355                format!(
356                    "({}|{})",
357                    regex_syntax::escape(DECORATION_PRE),
358                    regex_syntax::escape(DECORATION_OVERRIDE_PRE)
359                )
360                .as_str(),
361            )
362            .unwrap(),
363            undecorate_regex: Regex::new(
364                format!(
365                    r"(\x1B\[\d+\w)?([\w\d_]+){}([A-Z0-9]*){}",
366                    regex_syntax::escape(DECORATION_PRE),
367                    regex_syntax::escape(DECORATION_POST)
368                )
369                .as_str(),
370            )
371            .unwrap(),
372            virtual_fn_regex: Regex::new(
373                r"(?P<lead>[\s]*virtual\s+fn\s+)(?P<function>[^\s]+)(?P<trail>\s*)\(",
374            )
375            .unwrap(),
376            override_fn_regex: Regex::new(
377                format!(
378                    r"(override\s+fn\s+)([^\s]+){}([\w\d]+){}(\s*)\(",
379                    regex_syntax::escape(DECORATION_PRE),
380                    regex_syntax::escape(DECORATION_POST)
381                )
382                .as_str(),
383            )
384            .unwrap(),
385            undecorate_override_regex: Regex::new(
386                format!(
387                    "{}([A-Z0-9]*){}",
388                    regex_syntax::escape(DECORATION_OVERRIDE_PRE),
389                    regex_syntax::escape(DECORATION_POST)
390                )
391                .as_str(),
392            )
393            .unwrap(),
394            auto_binding_regex: Regex::new(r"@binding\(auto\)").unwrap(),
395            auto_binding_index: 0,
396        }
397    }
398}
399
400const DECORATION_PRE: &str = "X_naga_oil_mod_X";
401const DECORATION_POST: &str = "X";
402
403// must be same length as DECORATION_PRE for spans to work
404const DECORATION_OVERRIDE_PRE: &str = "X_naga_oil_vrt_X";
405
406struct IrBuildResult {
407    module: naga::Module,
408    start_offset: usize,
409    override_functions: IndexMap<String, Vec<String>>,
410}
411
412impl Composer {
413    pub fn decorated_name(module_name: Option<&str>, item_name: &str) -> String {
414        match module_name {
415            Some(module_name) => format!("{}{}", item_name, Self::decorate(module_name)),
416            None => item_name.to_owned(),
417        }
418    }
419
420    fn decorate(module: &str) -> String {
421        let encoded = data_encoding::BASE32_NOPAD.encode(module.as_bytes());
422        format!("{DECORATION_PRE}{encoded}{DECORATION_POST}")
423    }
424
425    fn decode(from: &str) -> String {
426        String::from_utf8(data_encoding::BASE32_NOPAD.decode(from.as_bytes()).unwrap()).unwrap()
427    }
428
429    /// This creates a validator that properly detects subgroup support.
430    fn create_validator(&self) -> naga::valid::Validator {
431        let subgroup_operations = if self.capabilities.contains(Capabilities::SUBGROUP) {
432            use naga::valid::SubgroupOperationSet as S;
433            S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
434        } else {
435            naga::valid::SubgroupOperationSet::empty()
436        };
437        let mut validator =
438            naga::valid::Validator::new(naga::valid::ValidationFlags::all(), self.capabilities);
439        validator.subgroup_stages(self.subgroup_stages);
440        validator.subgroup_operations(subgroup_operations);
441        validator
442    }
443
444    fn undecorate(&self, string: &str) -> String {
445        let undecor = self
446            .undecorate_regex
447            .replace_all(string, |caps: &regex::Captures| {
448                format!(
449                    "{}{}::{}",
450                    caps.get(1).map(|cc| cc.as_str()).unwrap_or(""),
451                    Self::decode(caps.get(3).unwrap().as_str()),
452                    caps.get(2).unwrap().as_str()
453                )
454            });
455
456        let undecor =
457            self.undecorate_override_regex
458                .replace_all(&undecor, |caps: &regex::Captures| {
459                    format!(
460                        "override fn {}::",
461                        Self::decode(caps.get(1).unwrap().as_str())
462                    )
463                });
464
465        undecor.to_string()
466    }
467
468    fn sanitize_and_set_auto_bindings(&mut self, source: &str) -> String {
469        let mut substituted_source = source.replace("\r\n", "\n").replace('\r', "\n");
470        if !substituted_source.ends_with('\n') {
471            substituted_source.push('\n');
472        }
473
474        // replace @binding(auto) with an incrementing index
475        struct AutoBindingReplacer<'a> {
476            auto: &'a mut u32,
477        }
478
479        impl<'a> regex::Replacer for AutoBindingReplacer<'a> {
480            fn replace_append(&mut self, _: &regex::Captures<'_>, dst: &mut String) {
481                dst.push_str(&format!("@binding({})", self.auto));
482                *self.auto += 1;
483            }
484        }
485
486        let substituted_source = self.auto_binding_regex.replace_all(
487            &substituted_source,
488            AutoBindingReplacer {
489                auto: &mut self.auto_binding_index,
490            },
491        );
492
493        substituted_source.into_owned()
494    }
495
496    fn naga_to_string(
497        &self,
498        naga_module: &mut naga::Module,
499        language: ShaderLanguage,
500        #[allow(unused)] header_for: &str, // Only used when GLSL is enabled
501    ) -> Result<String, ComposerErrorInner> {
502        // TODO: cache headers again
503        let info = self
504            .create_validator()
505            .validate(naga_module)
506            .map_err(ComposerErrorInner::HeaderValidationError)?;
507
508        match language {
509            ShaderLanguage::Wgsl => naga::back::wgsl::write_string(
510                naga_module,
511                &info,
512                naga::back::wgsl::WriterFlags::EXPLICIT_TYPES,
513            )
514            .map_err(ComposerErrorInner::WgslBackError),
515            #[cfg(feature = "glsl")]
516            ShaderLanguage::Glsl => {
517                let vec4 = naga_module.types.insert(
518                    naga::Type {
519                        name: None,
520                        inner: naga::TypeInner::Vector {
521                            size: naga::VectorSize::Quad,
522                            scalar: naga::Scalar::F32,
523                        },
524                    },
525                    naga::Span::UNDEFINED,
526                );
527                // add a dummy entry point for glsl headers
528                let dummy_entry_point = "dummy_module_entry_point".to_owned();
529                let func = naga::Function {
530                    name: Some(dummy_entry_point.clone()),
531                    arguments: Default::default(),
532                    result: Some(naga::FunctionResult {
533                        ty: vec4,
534                        binding: Some(naga::Binding::BuiltIn(naga::BuiltIn::Position {
535                            invariant: false,
536                        })),
537                    }),
538                    local_variables: Default::default(),
539                    expressions: Default::default(),
540                    named_expressions: Default::default(),
541                    body: Default::default(),
542                };
543                let ep = EntryPoint {
544                    name: dummy_entry_point.clone(),
545                    stage: naga::ShaderStage::Vertex,
546                    function: func,
547                    early_depth_test: None,
548                    workgroup_size: [0, 0, 0],
549                };
550
551                naga_module.entry_points.push(ep);
552
553                let info = self
554                    .create_validator()
555                    .validate(naga_module)
556                    .map_err(ComposerErrorInner::HeaderValidationError)?;
557
558                let mut string = String::new();
559                let options = naga::back::glsl::Options {
560                    version: naga::back::glsl::Version::Desktop(450),
561                    writer_flags: naga::back::glsl::WriterFlags::INCLUDE_UNUSED_ITEMS,
562                    ..Default::default()
563                };
564                let pipeline_options = naga::back::glsl::PipelineOptions {
565                    shader_stage: naga::ShaderStage::Vertex,
566                    entry_point: dummy_entry_point,
567                    multiview: None,
568                };
569                let mut writer = naga::back::glsl::Writer::new(
570                    &mut string,
571                    naga_module,
572                    &info,
573                    &options,
574                    &pipeline_options,
575                    naga::proc::BoundsCheckPolicies::default(),
576                )
577                .map_err(ComposerErrorInner::GlslBackError)?;
578
579                writer.write().map_err(ComposerErrorInner::GlslBackError)?;
580
581                // strip version decl and main() impl
582                let lines: Vec<_> = string.lines().collect();
583                let string = lines[1..lines.len() - 3].join("\n");
584                trace!("glsl header for {}:\n\"\n{:?}\n\"", header_for, string);
585
586                Ok(string)
587            }
588        }
589    }
590
591    // build naga module for a given shader_def configuration. builds a minimal self-contained module built against headers for imports
592    fn create_module_ir(
593        &self,
594        name: &str,
595        source: String,
596        language: ShaderLanguage,
597        imports: &[ImportDefinition],
598        shader_defs: &HashMap<String, ShaderDefValue>,
599    ) -> Result<IrBuildResult, ComposerError> {
600        debug!("creating IR for {} with defs: {:?}", name, shader_defs);
601
602        let mut module_string = match language {
603            ShaderLanguage::Wgsl => String::new(),
604            #[cfg(feature = "glsl")]
605            ShaderLanguage::Glsl => String::from("#version 450\n"),
606        };
607
608        let mut override_functions: IndexMap<String, Vec<String>> = IndexMap::default();
609        let mut added_imports: HashSet<String> = HashSet::new();
610        let mut header_module = DerivedModule::default();
611
612        for import in imports {
613            if added_imports.contains(&import.import) {
614                continue;
615            }
616            // add to header module
617            self.add_import(
618                &mut header_module,
619                import,
620                shader_defs,
621                true,
622                &mut added_imports,
623            );
624
625            // // we must have ensured these exist with Composer::ensure_imports()
626            trace!("looking for {}", import.import);
627            let import_module_set = self.module_sets.get(&import.import).unwrap();
628            trace!("with defs {:?}", shader_defs);
629            let module = import_module_set.get_module(shader_defs).unwrap();
630            trace!("ok");
631
632            // gather overrides
633            if !module.override_functions.is_empty() {
634                for (original, replacements) in &module.override_functions {
635                    match override_functions.entry(original.clone()) {
636                        indexmap::map::Entry::Occupied(o) => {
637                            let existing = o.into_mut();
638                            let new_replacements: Vec<_> = replacements
639                                .iter()
640                                .filter(|rep| !existing.contains(rep))
641                                .cloned()
642                                .collect();
643                            existing.extend(new_replacements);
644                        }
645                        indexmap::map::Entry::Vacant(v) => {
646                            v.insert(replacements.clone());
647                        }
648                    }
649                }
650            }
651        }
652
653        let composed_header = self
654            .naga_to_string(&mut header_module.into(), language, name)
655            .map_err(|inner| ComposerError {
656                inner,
657                source: ErrSource::Module {
658                    name: name.to_owned(),
659                    offset: 0,
660                    defs: shader_defs.clone(),
661                },
662            })?;
663        module_string.push_str(&composed_header);
664
665        let start_offset = module_string.len();
666
667        module_string.push_str(&source);
668
669        trace!(
670            "parsing {}: {}, header len {}, total len {}",
671            name,
672            module_string,
673            start_offset,
674            module_string.len()
675        );
676        let module = match language {
677            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(&module_string).map_err(|e| {
678                debug!("full err'd source file: \n---\n{}\n---", module_string);
679                ComposerError {
680                    inner: ComposerErrorInner::WgslParseError(e),
681                    source: ErrSource::Module {
682                        name: name.to_owned(),
683                        offset: start_offset,
684                        defs: shader_defs.clone(),
685                    },
686                }
687            })?,
688            #[cfg(feature = "glsl")]
689            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
690                .parse(
691                    &naga::front::glsl::Options {
692                        stage: naga::ShaderStage::Vertex,
693                        defines: Default::default(),
694                    },
695                    &module_string,
696                )
697                .map_err(|e| {
698                    debug!("full err'd source file: \n---\n{}\n---", module_string);
699                    ComposerError {
700                        inner: ComposerErrorInner::GlslParseError(e),
701                        source: ErrSource::Module {
702                            name: name.to_owned(),
703                            offset: start_offset,
704                            defs: shader_defs.clone(),
705                        },
706                    }
707                })?,
708        };
709
710        Ok(IrBuildResult {
711            module,
712            start_offset,
713            override_functions,
714        })
715    }
716
717    // check that identifiers exported by a module do not get modified in string export
718    fn validate_identifiers(
719        source_ir: &naga::Module,
720        lang: ShaderLanguage,
721        header: &str,
722        module_decoration: &str,
723        owned_types: &HashSet<String>,
724    ) -> Result<(), ComposerErrorInner> {
725        // TODO: remove this once glsl front support is complete
726        #[cfg(feature = "glsl")]
727        if lang == ShaderLanguage::Glsl {
728            return Ok(());
729        }
730
731        let recompiled = match lang {
732            ShaderLanguage::Wgsl => naga::front::wgsl::parse_str(header).unwrap(),
733            #[cfg(feature = "glsl")]
734            ShaderLanguage::Glsl => naga::front::glsl::Frontend::default()
735                .parse(
736                    &naga::front::glsl::Options {
737                        stage: naga::ShaderStage::Vertex,
738                        defines: Default::default(),
739                    },
740                    &format!("{}\n{}", header, "void main() {}"),
741                )
742                .map_err(|e| {
743                    debug!("full err'd source file: \n---\n{header}\n---");
744                    ComposerErrorInner::GlslParseError(e)
745                })?,
746        };
747
748        let recompiled_types: IndexMap<_, _> = recompiled
749            .types
750            .iter()
751            .flat_map(|(h, ty)| ty.name.as_deref().map(|name| (name, h)))
752            .collect();
753        for (h, ty) in source_ir.types.iter() {
754            if let Some(name) = &ty.name {
755                let decorated_type_name = format!("{name}{module_decoration}");
756                if !owned_types.contains(&decorated_type_name) {
757                    continue;
758                }
759                match recompiled_types.get(decorated_type_name.as_str()) {
760                    Some(recompiled_h) => {
761                        if let naga::TypeInner::Struct { members, .. } = &ty.inner {
762                            let recompiled_ty = recompiled.types.get_handle(*recompiled_h).unwrap();
763                            let naga::TypeInner::Struct {
764                                members: recompiled_members,
765                                ..
766                            } = &recompiled_ty.inner
767                            else {
768                                panic!();
769                            };
770                            for (member, recompiled_member) in
771                                members.iter().zip(recompiled_members)
772                            {
773                                if member.name != recompiled_member.name {
774                                    return Err(ComposerErrorInner::InvalidIdentifier {
775                                        original: member.name.clone().unwrap_or_default(),
776                                        at: source_ir.types.get_span(h),
777                                    });
778                                }
779                            }
780                        }
781                    }
782                    None => {
783                        return Err(ComposerErrorInner::InvalidIdentifier {
784                            original: name.clone(),
785                            at: source_ir.types.get_span(h),
786                        })
787                    }
788                }
789            }
790        }
791
792        let recompiled_consts: HashSet<_> = recompiled
793            .constants
794            .iter()
795            .flat_map(|(_, c)| c.name.as_deref())
796            .filter(|name| name.ends_with(module_decoration))
797            .collect();
798        for (h, c) in source_ir.constants.iter() {
799            if let Some(name) = &c.name {
800                if name.ends_with(module_decoration) && !recompiled_consts.contains(name.as_str()) {
801                    return Err(ComposerErrorInner::InvalidIdentifier {
802                        original: name.clone(),
803                        at: source_ir.constants.get_span(h),
804                    });
805                }
806            }
807        }
808
809        let recompiled_globals: HashSet<_> = recompiled
810            .global_variables
811            .iter()
812            .flat_map(|(_, c)| c.name.as_deref())
813            .filter(|name| name.ends_with(module_decoration))
814            .collect();
815        for (h, gv) in source_ir.global_variables.iter() {
816            if let Some(name) = &gv.name {
817                if name.ends_with(module_decoration) && !recompiled_globals.contains(name.as_str())
818                {
819                    return Err(ComposerErrorInner::InvalidIdentifier {
820                        original: name.clone(),
821                        at: source_ir.global_variables.get_span(h),
822                    });
823                }
824            }
825        }
826
827        let recompiled_fns: HashSet<_> = recompiled
828            .functions
829            .iter()
830            .flat_map(|(_, c)| c.name.as_deref())
831            .filter(|name| name.ends_with(module_decoration))
832            .collect();
833        for (h, f) in source_ir.functions.iter() {
834            if let Some(name) = &f.name {
835                if name.ends_with(module_decoration) && !recompiled_fns.contains(name.as_str()) {
836                    return Err(ComposerErrorInner::InvalidIdentifier {
837                        original: name.clone(),
838                        at: source_ir.functions.get_span(h),
839                    });
840                }
841            }
842        }
843
844        Ok(())
845    }
846
847    // build a ComposableModule from a ComposableModuleDefinition, for a given set of shader defs
848    // - build the naga IR (against headers)
849    // - record any types/vars/constants/functions that are defined within this module
850    // - build headers for each supported language
851    #[allow(clippy::too_many_arguments)]
852    fn create_composable_module(
853        &mut self,
854        module_definition: &ComposableModuleDefinition,
855        module_decoration: String,
856        shader_defs: &HashMap<String, ShaderDefValue>,
857        create_headers: bool,
858        demote_entrypoints: bool,
859        source: &str,
860        imports: Vec<ImportDefWithOffset>,
861    ) -> Result<ComposableModule, ComposerError> {
862        let mut imports: Vec<_> = imports
863            .into_iter()
864            .map(|import_with_offset| import_with_offset.definition)
865            .collect();
866        imports.extend(module_definition.additional_imports.to_vec());
867
868        trace!(
869            "create composable module {}: source len {}",
870            module_definition.name,
871            source.len()
872        );
873
874        // record virtual/overridable functions
875        let mut virtual_functions: HashSet<String> = Default::default();
876        let source = self
877            .virtual_fn_regex
878            .replace_all(source, |cap: &regex::Captures| {
879                let target_function = cap.get(2).unwrap().as_str().to_owned();
880
881                let replacement_str = format!(
882                    "{}fn {}{}(",
883                    " ".repeat(cap.get(1).unwrap().range().len() - 3),
884                    target_function,
885                    " ".repeat(cap.get(3).unwrap().range().len()),
886                );
887
888                virtual_functions.insert(target_function);
889
890                replacement_str
891            });
892
893        // record and rename override functions
894        let mut local_override_functions: IndexMap<String, String> = Default::default();
895
896        #[cfg(not(feature = "override_any"))]
897        let mut override_error = None;
898
899        let source =
900            self.override_fn_regex
901                .replace_all(&source, |cap: &regex::Captures| {
902                    let target_module = cap.get(3).unwrap().as_str().to_owned();
903                    let target_function = cap.get(2).unwrap().as_str().to_owned();
904
905                    #[cfg(not(feature = "override_any"))]
906                    {
907                        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
908                            ComposerError {
909                                inner,
910                                source: ErrSource::Module {
911                                    name: module_definition.name.to_owned(),
912                                    offset: 0,
913                                    defs: shader_defs.clone(),
914                                },
915                            }
916                        };
917
918                        // ensure overrides are applied to virtual functions
919                        let raw_module_name = Self::decode(&target_module);
920                        let module_set = self.module_sets.get(&raw_module_name);
921
922                        match module_set {
923                            None => {
924                                // TODO this should be unreachable?
925                                let pos = cap.get(3).unwrap().start();
926                                override_error = Some(wrap_err(
927                                    ComposerErrorInner::ImportNotFound(raw_module_name, pos),
928                                ));
929                            }
930                            Some(module_set) => {
931                                let module = module_set.get_module(shader_defs).unwrap();
932                                if !module.virtual_functions.contains(&target_function) {
933                                    let pos = cap.get(2).unwrap().start();
934                                    override_error =
935                                        Some(wrap_err(ComposerErrorInner::OverrideNotVirtual {
936                                            name: target_function.clone(),
937                                            pos,
938                                        }));
939                                }
940                            }
941                        }
942                    }
943
944                    let base_name = format!(
945                        "{}{}{}{}",
946                        target_function.as_str(),
947                        DECORATION_PRE,
948                        target_module.as_str(),
949                        DECORATION_POST,
950                    );
951                    let rename = format!(
952                        "{}{}{}{}",
953                        target_function.as_str(),
954                        DECORATION_OVERRIDE_PRE,
955                        target_module.as_str(),
956                        DECORATION_POST,
957                    );
958
959                    let replacement_str = format!(
960                        "{}fn {}{}(",
961                        " ".repeat(cap.get(1).unwrap().range().len() - 3),
962                        rename,
963                        " ".repeat(cap.get(4).unwrap().range().len()),
964                    );
965
966                    local_override_functions.insert(rename, base_name);
967
968                    replacement_str
969                })
970                .to_string();
971
972        #[cfg(not(feature = "override_any"))]
973        if let Some(err) = override_error {
974            return Err(err);
975        }
976
977        trace!("local overrides: {:?}", local_override_functions);
978        trace!(
979            "create composable module {}: source len {}",
980            module_definition.name,
981            source.len()
982        );
983
984        let IrBuildResult {
985            module: mut source_ir,
986            start_offset,
987            mut override_functions,
988        } = self.create_module_ir(
989            &module_definition.name,
990            source,
991            module_definition.language,
992            &imports,
993            shader_defs,
994        )?;
995
996        // from here on errors need to be reported using the modified source with start_offset
997        let wrap_err = |inner: ComposerErrorInner| -> ComposerError {
998            ComposerError {
999                inner,
1000                source: ErrSource::Module {
1001                    name: module_definition.name.to_owned(),
1002                    offset: start_offset,
1003                    defs: shader_defs.clone(),
1004                },
1005            }
1006        };
1007
1008        // add our local override to the total set of overrides for the given function
1009        for (rename, base_name) in &local_override_functions {
1010            override_functions
1011                .entry(base_name.clone())
1012                .or_default()
1013                .push(format!("{rename}{module_decoration}"));
1014        }
1015
1016        // rename and record owned items (except types which can't be mutably accessed)
1017        let mut owned_constants = IndexMap::new();
1018        for (h, c) in source_ir.constants.iter_mut() {
1019            if let Some(name) = c.name.as_mut() {
1020                if !name.contains(DECORATION_PRE) {
1021                    *name = format!("{name}{module_decoration}");
1022                    owned_constants.insert(name.clone(), h);
1023                }
1024            }
1025        }
1026
1027        // These are naga/wgpu's pipeline override constants, not naga_oil's overrides
1028        let mut owned_pipeline_overrides = IndexMap::new();
1029        for (h, po) in source_ir.overrides.iter_mut() {
1030            if let Some(name) = po.name.as_mut() {
1031                if !name.contains(DECORATION_PRE) {
1032                    *name = format!("{name}{module_decoration}");
1033                    owned_pipeline_overrides.insert(name.clone(), h);
1034                }
1035            }
1036        }
1037
1038        let mut owned_vars = IndexMap::new();
1039        for (h, gv) in source_ir.global_variables.iter_mut() {
1040            if let Some(name) = gv.name.as_mut() {
1041                if !name.contains(DECORATION_PRE) {
1042                    *name = format!("{name}{module_decoration}");
1043
1044                    owned_vars.insert(name.clone(), h);
1045                }
1046            }
1047        }
1048
1049        let mut owned_functions = IndexMap::new();
1050        for (h_f, f) in source_ir.functions.iter_mut() {
1051            if let Some(name) = f.name.as_mut() {
1052                if !name.contains(DECORATION_PRE) {
1053                    *name = format!("{name}{module_decoration}");
1054
1055                    // create dummy header function
1056                    let header_function = naga::Function {
1057                        name: Some(name.clone()),
1058                        arguments: f.arguments.to_vec(),
1059                        result: f.result.clone(),
1060                        local_variables: Default::default(),
1061                        expressions: Default::default(),
1062                        named_expressions: Default::default(),
1063                        body: Default::default(),
1064                    };
1065
1066                    // record owned function
1067                    owned_functions.insert(name.clone(), (Some(h_f), header_function));
1068                }
1069            }
1070        }
1071
1072        if demote_entrypoints {
1073            // make normal functions out of the source entry points
1074            for ep in &mut source_ir.entry_points {
1075                ep.function.name = Some(format!(
1076                    "{}{}",
1077                    ep.function.name.as_deref().unwrap_or("main"),
1078                    module_decoration,
1079                ));
1080                let header_function = naga::Function {
1081                    name: ep.function.name.clone(),
1082                    arguments: ep
1083                        .function
1084                        .arguments
1085                        .iter()
1086                        .cloned()
1087                        .map(|arg| naga::FunctionArgument {
1088                            name: arg.name,
1089                            ty: arg.ty,
1090                            binding: None,
1091                        })
1092                        .collect(),
1093                    result: ep.function.result.clone().map(|res| naga::FunctionResult {
1094                        ty: res.ty,
1095                        binding: None,
1096                    }),
1097                    local_variables: Default::default(),
1098                    expressions: Default::default(),
1099                    named_expressions: Default::default(),
1100                    body: Default::default(),
1101                };
1102
1103                owned_functions.insert(ep.function.name.clone().unwrap(), (None, header_function));
1104            }
1105        };
1106
1107        let mut module_builder = DerivedModule::default();
1108        let mut header_builder = DerivedModule::default();
1109        module_builder.set_shader_source(&source_ir, 0);
1110        header_builder.set_shader_source(&source_ir, 0);
1111
1112        let mut owned_types = HashSet::new();
1113        for (h, ty) in source_ir.types.iter() {
1114            if let Some(name) = &ty.name {
1115                // we need to exclude autogenerated struct names, i.e. those that begin with "__"
1116                // "__" is a reserved prefix for naga so user variables cannot use it.
1117                if !name.contains(DECORATION_PRE) && !name.starts_with("__") {
1118                    let name = format!("{name}{module_decoration}");
1119                    owned_types.insert(name.clone());
1120                    // copy and rename types
1121                    module_builder.rename_type(&h, Some(name.clone()));
1122                    header_builder.rename_type(&h, Some(name));
1123                    continue;
1124                }
1125            }
1126
1127            // copy all required types
1128            module_builder.import_type(&h);
1129        }
1130
1131        // copy owned types into header and module
1132        for h in owned_constants.values() {
1133            header_builder.import_const(h);
1134            module_builder.import_const(h);
1135        }
1136
1137        for h in owned_pipeline_overrides.values() {
1138            header_builder.import_pipeline_override(h);
1139            module_builder.import_pipeline_override(h);
1140        }
1141
1142        for h in owned_vars.values() {
1143            header_builder.import_global(h);
1144            module_builder.import_global(h);
1145        }
1146
1147        // only stubs of owned functions into the header
1148        for (h_f, f) in owned_functions.values() {
1149            let span = h_f
1150                .map(|h_f| source_ir.functions.get_span(h_f))
1151                .unwrap_or(naga::Span::UNDEFINED);
1152            header_builder.import_function(f, span); // header stub function
1153        }
1154        // all functions into the module (note source_ir only contains stubs for imported functions)
1155        for (h_f, f) in source_ir.functions.iter() {
1156            let span = source_ir.functions.get_span(h_f);
1157            module_builder.import_function(f, span);
1158        }
1159        // // including entry points as vanilla functions if required
1160        if demote_entrypoints {
1161            for ep in &source_ir.entry_points {
1162                let mut f = ep.function.clone();
1163                f.arguments = f
1164                    .arguments
1165                    .into_iter()
1166                    .map(|arg| naga::FunctionArgument {
1167                        name: arg.name,
1168                        ty: arg.ty,
1169                        binding: None,
1170                    })
1171                    .collect();
1172                f.result = f.result.map(|res| naga::FunctionResult {
1173                    ty: res.ty,
1174                    binding: None,
1175                });
1176
1177                module_builder.import_function(&f, naga::Span::UNDEFINED);
1178                // todo figure out how to get span info for entrypoints
1179            }
1180        }
1181
1182        let module_ir = module_builder.into_module_with_entrypoints();
1183        let mut header_ir: naga::Module = header_builder.into();
1184
1185        if self.validate && create_headers {
1186            // check that identifiers haven't been renamed
1187            #[allow(clippy::single_element_loop)]
1188            for language in [
1189                ShaderLanguage::Wgsl,
1190                #[cfg(feature = "glsl")]
1191                ShaderLanguage::Glsl,
1192            ] {
1193                let header = self
1194                    .naga_to_string(&mut header_ir, language, &module_definition.name)
1195                    .map_err(wrap_err)?;
1196                Self::validate_identifiers(
1197                    &source_ir,
1198                    language,
1199                    &header,
1200                    &module_decoration,
1201                    &owned_types,
1202                )
1203                .map_err(wrap_err)?;
1204            }
1205        }
1206
1207        let composable_module = ComposableModule {
1208            decorated_name: module_decoration,
1209            imports,
1210            owned_types,
1211            owned_constants: owned_constants.into_keys().collect(),
1212            owned_vars: owned_vars.into_keys().collect(),
1213            owned_functions: owned_functions.into_keys().collect(),
1214            virtual_functions,
1215            override_functions,
1216            module_ir,
1217            header_ir,
1218            start_offset,
1219        };
1220
1221        Ok(composable_module)
1222    }
1223
1224    // shunt all data owned by a composable into a derived module
1225    fn add_composable_data<'a>(
1226        derived: &mut DerivedModule<'a>,
1227        composable: &'a ComposableModule,
1228        items: Option<&Vec<String>>,
1229        span_offset: usize,
1230        header: bool,
1231    ) {
1232        let items: Option<HashSet<String>> = items.map(|items| {
1233            items
1234                .iter()
1235                .map(|item| format!("{}{}", item, composable.decorated_name))
1236                .collect()
1237        });
1238        let items = items.as_ref();
1239
1240        let source_ir = match header {
1241            true => &composable.header_ir,
1242            false => &composable.module_ir,
1243        };
1244
1245        derived.set_shader_source(source_ir, span_offset);
1246
1247        for (h, ty) in source_ir.types.iter() {
1248            if let Some(name) = &ty.name {
1249                if composable.owned_types.contains(name)
1250                    && items.map_or(true, |items| items.contains(name))
1251                {
1252                    derived.import_type(&h);
1253                }
1254            }
1255        }
1256
1257        for (h, c) in source_ir.constants.iter() {
1258            if let Some(name) = &c.name {
1259                if composable.owned_constants.contains(name)
1260                    && items.map_or(true, |items| items.contains(name))
1261                {
1262                    derived.import_const(&h);
1263                }
1264            }
1265        }
1266
1267        for (h, po) in source_ir.overrides.iter() {
1268            if let Some(name) = &po.name {
1269                if composable.owned_functions.contains(name)
1270                    && items.map_or(true, |items| items.contains(name))
1271                {
1272                    derived.import_pipeline_override(&h);
1273                }
1274            }
1275        }
1276
1277        for (h, v) in source_ir.global_variables.iter() {
1278            if let Some(name) = &v.name {
1279                if composable.owned_vars.contains(name)
1280                    && items.map_or(true, |items| items.contains(name))
1281                {
1282                    derived.import_global(&h);
1283                }
1284            }
1285        }
1286
1287        for (h_f, f) in source_ir.functions.iter() {
1288            if let Some(name) = &f.name {
1289                if composable.owned_functions.contains(name)
1290                    && (items.map_or(true, |items| items.contains(name))
1291                        || composable
1292                            .override_functions
1293                            .values()
1294                            .any(|v| v.contains(name)))
1295                {
1296                    let span = composable.module_ir.functions.get_span(h_f);
1297                    derived.import_function_if_new(f, span);
1298                }
1299            }
1300        }
1301
1302        derived.clear_shader_source();
1303    }
1304
1305    // add an import (and recursive imports) into a derived module
1306    fn add_import<'a>(
1307        &'a self,
1308        derived: &mut DerivedModule<'a>,
1309        import: &ImportDefinition,
1310        shader_defs: &HashMap<String, ShaderDefValue>,
1311        header: bool,
1312        already_added: &mut HashSet<String>,
1313    ) {
1314        if already_added.contains(&import.import) {
1315            trace!("skipping {}, already added", import.import);
1316            return;
1317        }
1318
1319        let import_module_set = self.module_sets.get(&import.import).unwrap();
1320        let module = import_module_set.get_module(shader_defs).unwrap();
1321
1322        for import in &module.imports {
1323            self.add_import(derived, import, shader_defs, header, already_added);
1324        }
1325
1326        Self::add_composable_data(
1327            derived,
1328            module,
1329            Some(&import.items),
1330            import_module_set.module_index << SPAN_SHIFT,
1331            header,
1332        );
1333    }
1334
1335    fn ensure_import(
1336        &mut self,
1337        module_set: &ComposableModuleDefinition,
1338        shader_defs: &HashMap<String, ShaderDefValue>,
1339    ) -> Result<ComposableModule, ComposerError> {
1340        let PreprocessOutput {
1341            preprocessed_source,
1342            imports,
1343        } = self
1344            .preprocessor
1345            .preprocess(&module_set.sanitized_source, shader_defs, self.validate)
1346            .map_err(|inner| ComposerError {
1347                inner,
1348                source: ErrSource::Module {
1349                    name: module_set.name.to_owned(),
1350                    offset: 0,
1351                    defs: shader_defs.clone(),
1352                },
1353            })?;
1354
1355        self.ensure_imports(imports.iter().map(|import| &import.definition), shader_defs)?;
1356        self.ensure_imports(&module_set.additional_imports, shader_defs)?;
1357
1358        self.create_composable_module(
1359            module_set,
1360            Self::decorate(&module_set.name),
1361            shader_defs,
1362            true,
1363            true,
1364            &preprocessed_source,
1365            imports,
1366        )
1367    }
1368
1369    // build required ComposableModules for a given set of shader_defs
1370    fn ensure_imports<'a>(
1371        &mut self,
1372        imports: impl IntoIterator<Item = &'a ImportDefinition>,
1373        shader_defs: &HashMap<String, ShaderDefValue>,
1374    ) -> Result<(), ComposerError> {
1375        for ImportDefinition { import, .. } in imports.into_iter() {
1376            // we've already ensured imports exist when they were added
1377            let module_set = self.module_sets.get(import).unwrap();
1378            if module_set.get_module(shader_defs).is_some() {
1379                continue;
1380            }
1381
1382            // we need to build the module
1383            // take the set so we can recurse without borrowing
1384            let (set_key, mut module_set) = self.module_sets.remove_entry(import).unwrap();
1385
1386            match self.ensure_import(&module_set, shader_defs) {
1387                Ok(module) => {
1388                    module_set.insert_module(shader_defs, module);
1389                    self.module_sets.insert(set_key, module_set);
1390                }
1391                Err(e) => {
1392                    self.module_sets.insert(set_key, module_set);
1393                    return Err(e);
1394                }
1395            }
1396        }
1397
1398        Ok(())
1399    }
1400}
1401
/// Describes a composable module to register with [`Composer::add_composable_module`].
#[derive(Default)]
pub struct ComposableModuleDescriptor<'a> {
    /// Shader source of the module. It must not contain naga_oil's decoration strings.
    pub source: &'a str,
    /// Path reported in errors raised while the module is being added.
    pub file_path: &'a str,
    /// Language `source` is written in.
    pub language: ShaderLanguage,
    /// Name to register the module under; when `None`, the name declared in `source`
    /// (its `#define_import_path`) is used. One of the two must be present.
    pub as_name: Option<String>,
    /// Imports applied to the module in addition to those declared in `source`.
    pub additional_imports: &'a [ImportDefinition],
    /// Shader defs for the module; these are extended with the shader defs of the modules
    /// it imports.
    pub shader_defs: HashMap<String, ShaderDefValue>,
}
1411
/// Describes a top-level shader to be composed into a naga module.
/// NOTE(review): consumed by code outside this view (presumably `Composer::make_naga_module`);
/// the field notes below are inferred from the parallel `ComposableModuleDescriptor` — confirm.
#[derive(Default)]
pub struct NagaModuleDescriptor<'a> {
    /// Shader source.
    pub source: &'a str,
    /// Path of the shader source, presumably used when reporting errors.
    pub file_path: &'a str,
    /// Kind of shader `source` contains.
    pub shader_type: ShaderType,
    /// Shader defs to compose the shader with.
    pub shader_defs: HashMap<String, ShaderDefValue>,
    /// Imports presumably applied in addition to those declared in `source`.
    pub additional_imports: &'a [ImportDefinition],
}
1420
1421// public api
1422impl Composer {
1423    /// create a non-validating composer.
1424    /// validation errors in the final shader will not be caught, and errors resulting from their
1425    /// use will have bad span data, so codespan reporting will fail.
1426    /// use default() to create a validating composer.
1427    pub fn non_validating() -> Self {
1428        Self {
1429            validate: false,
1430            ..Default::default()
1431        }
1432    }
1433
1434    /// specify capabilities to be used for naga module generation.
1435    /// purges any existing modules
1436    /// See https://github.com/gfx-rs/wgpu/blob/d9c054c645af0ea9ef81617c3e762fbf0f3fecda/wgpu-core/src/device/mod.rs#L515
1437    /// for how to set the subgroup_stages value.
1438    pub fn with_capabilities(
1439        self,
1440        capabilities: naga::valid::Capabilities,
1441        subgroup_stages: naga::valid::ShaderStages,
1442    ) -> Self {
1443        Self {
1444            capabilities,
1445            validate: self.validate,
1446            subgroup_stages,
1447            ..Default::default()
1448        }
1449    }
1450
1451    /// check if a module with the given name has been added
1452    pub fn contains_module(&self, module_name: &str) -> bool {
1453        self.module_sets.contains_key(module_name)
1454    }
1455
1456    /// add a composable module to the composer.
1457    /// all modules imported by this module must already have been added
1458    pub fn add_composable_module(
1459        &mut self,
1460        desc: ComposableModuleDescriptor,
1461    ) -> Result<&ComposableModuleDefinition, ComposerError> {
1462        let ComposableModuleDescriptor {
1463            source,
1464            file_path,
1465            language,
1466            as_name,
1467            additional_imports,
1468            mut shader_defs,
1469        } = desc;
1470
1471        // reject a module containing the DECORATION strings
1472        if let Some(decor) = self.check_decoration_regex.find(source) {
1473            return Err(ComposerError {
1474                inner: ComposerErrorInner::DecorationInSource(decor.range()),
1475                source: ErrSource::Constructing {
1476                    path: file_path.to_owned(),
1477                    source: source.to_owned(),
1478                    offset: 0,
1479                },
1480            });
1481        }
1482
1483        let substituted_source = self.sanitize_and_set_auto_bindings(source);
1484
1485        let PreprocessorMetaData {
1486            name: module_name,
1487            mut imports,
1488            mut effective_defs,
1489            ..
1490        } = self
1491            .preprocessor
1492            .get_preprocessor_metadata(&substituted_source, false)
1493            .map_err(|inner| ComposerError {
1494                inner,
1495                source: ErrSource::Constructing {
1496                    path: file_path.to_owned(),
1497                    source: source.to_owned(),
1498                    offset: 0,
1499                },
1500            })?;
1501        let module_name = as_name.or(module_name);
1502        if module_name.is_none() {
1503            return Err(ComposerError {
1504                inner: ComposerErrorInner::NoModuleName,
1505                source: ErrSource::Constructing {
1506                    path: file_path.to_owned(),
1507                    source: source.to_owned(),
1508                    offset: 0,
1509                },
1510            });
1511        }
1512        let module_name = module_name.unwrap();
1513
1514        debug!(
1515            "adding module definition for {} with defs: {:?}",
1516            module_name, shader_defs
1517        );
1518
1519        // add custom imports
1520        let additional_imports = additional_imports.to_vec();
1521        imports.extend(
1522            additional_imports
1523                .iter()
1524                .cloned()
1525                .map(|def| ImportDefWithOffset {
1526                    definition: def,
1527                    offset: 0,
1528                }),
1529        );
1530
1531        for import in &imports {
1532            // we require modules already added so that we can capture the shader_defs that may impact us by impacting our dependencies
1533            let module_set = self
1534                .module_sets
1535                .get(&import.definition.import)
1536                .ok_or_else(|| ComposerError {
1537                    inner: ComposerErrorInner::ImportNotFound(
1538                        import.definition.import.clone(),
1539                        import.offset,
1540                    ),
1541                    source: ErrSource::Constructing {
1542                        path: file_path.to_owned(),
1543                        source: substituted_source.to_owned(),
1544                        offset: 0,
1545                    },
1546                })?;
1547            effective_defs.extend(module_set.effective_defs.iter().cloned());
1548            shader_defs.extend(
1549                module_set
1550                    .shader_defs
1551                    .iter()
1552                    .map(|def| (def.0.clone(), *def.1)),
1553            );
1554        }
1555
1556        // remove defs that are already specified through our imports
1557        effective_defs.retain(|name| !shader_defs.contains_key(name));
1558
1559        // can't gracefully report errors for more modules. perhaps this should be a warning
1560        assert!((self.module_sets.len() as u32) < u32::MAX >> SPAN_SHIFT);
1561        let module_index = self.module_sets.len() + 1;
1562
1563        let module_set = ComposableModuleDefinition {
1564            name: module_name.clone(),
1565            sanitized_source: substituted_source,
1566            file_path: file_path.to_owned(),
1567            language,
1568            effective_defs: effective_defs.into_iter().collect(),
1569            all_imports: imports.into_iter().map(|id| id.definition.import).collect(),
1570            additional_imports,
1571            shader_defs,
1572            module_index,
1573            modules: Default::default(),
1574        };
1575
1576        // invalidate dependent modules if this module already exists
1577        self.remove_composable_module(&module_name);
1578
1579        self.module_sets.insert(module_name.clone(), module_set);
1580        self.module_index.insert(module_index, module_name.clone());
1581        Ok(self.module_sets.get(&module_name).unwrap())
1582    }
1583
1584    /// remove a composable module. also removes modules that depend on this module, as we cannot be sure about
1585    /// the completeness of their effective shader defs any more...
1586    pub fn remove_composable_module(&mut self, module_name: &str) {
1587        // todo this could be improved by making effective defs an Option<HashSet> and populating on demand?
1588        let mut dependent_sets = Vec::new();
1589
1590        if self.module_sets.remove(module_name).is_some() {
1591            dependent_sets.extend(self.module_sets.iter().filter_map(|(dependent_name, set)| {
1592                if set.all_imports.contains(module_name) {
1593                    Some(dependent_name.clone())
1594                } else {
1595                    None
1596                }
1597            }));
1598        }
1599
1600        for dependent_set in dependent_sets {
1601            self.remove_composable_module(&dependent_set);
1602        }
1603    }
1604
1605    /// build a naga shader module
1606    pub fn make_naga_module(
1607        &mut self,
1608        desc: NagaModuleDescriptor,
1609    ) -> Result<naga::Module, ComposerError> {
1610        let NagaModuleDescriptor {
1611            source,
1612            file_path,
1613            shader_type,
1614            mut shader_defs,
1615            additional_imports,
1616        } = desc;
1617
1618        let sanitized_source = self.sanitize_and_set_auto_bindings(source);
1619
1620        let PreprocessorMetaData {
1621            name,
1622            defines,
1623            imports,
1624            ..
1625        } = self
1626            .preprocessor
1627            .get_preprocessor_metadata(&sanitized_source, true)
1628            .map_err(|inner| ComposerError {
1629                inner,
1630                source: ErrSource::Constructing {
1631                    path: file_path.to_owned(),
1632                    source: sanitized_source.to_owned(),
1633                    offset: 0,
1634                },
1635            })?;
1636        shader_defs.extend(defines);
1637
1638        let name = name.unwrap_or_default();
1639
1640        // make sure imports have been added
1641        // and gather additional defs specified at module level
1642        for (import_name, offset) in imports
1643            .iter()
1644            .map(|id| (&id.definition.import, id.offset))
1645            .chain(additional_imports.iter().map(|ai| (&ai.import, 0)))
1646        {
1647            if let Some(module_set) = self.module_sets.get(import_name) {
1648                for (def, value) in &module_set.shader_defs {
1649                    if let Some(prior_value) = shader_defs.insert(def.clone(), *value) {
1650                        if prior_value != *value {
1651                            return Err(ComposerError {
1652                                inner: ComposerErrorInner::InconsistentShaderDefValue {
1653                                    def: def.clone(),
1654                                },
1655                                source: ErrSource::Constructing {
1656                                    path: file_path.to_owned(),
1657                                    source: sanitized_source.to_owned(),
1658                                    offset: 0,
1659                                },
1660                            });
1661                        }
1662                    }
1663                }
1664            } else {
1665                return Err(ComposerError {
1666                    inner: ComposerErrorInner::ImportNotFound(import_name.clone(), offset),
1667                    source: ErrSource::Constructing {
1668                        path: file_path.to_owned(),
1669                        source: sanitized_source,
1670                        offset: 0,
1671                    },
1672                });
1673            }
1674        }
1675        self.ensure_imports(
1676            imports.iter().map(|import| &import.definition),
1677            &shader_defs,
1678        )?;
1679        self.ensure_imports(additional_imports, &shader_defs)?;
1680
1681        let definition = ComposableModuleDefinition {
1682            name,
1683            sanitized_source: sanitized_source.clone(),
1684            language: shader_type.into(),
1685            file_path: file_path.to_owned(),
1686            module_index: 0,
1687            additional_imports: additional_imports.to_vec(),
1688            // we don't care about these for creating a top-level module
1689            effective_defs: Default::default(),
1690            all_imports: Default::default(),
1691            shader_defs: Default::default(),
1692            modules: Default::default(),
1693        };
1694
1695        let PreprocessOutput {
1696            preprocessed_source,
1697            imports,
1698        } = self
1699            .preprocessor
1700            .preprocess(&sanitized_source, &shader_defs, self.validate)
1701            .map_err(|inner| ComposerError {
1702                inner,
1703                source: ErrSource::Constructing {
1704                    path: file_path.to_owned(),
1705                    source: sanitized_source,
1706                    offset: 0,
1707                },
1708            })?;
1709
1710        let composable = self
1711            .create_composable_module(
1712                &definition,
1713                String::from(""),
1714                &shader_defs,
1715                false,
1716                false,
1717                &preprocessed_source,
1718                imports,
1719            )
1720            .map_err(|e| ComposerError {
1721                inner: e.inner,
1722                source: ErrSource::Constructing {
1723                    path: definition.file_path.to_owned(),
1724                    source: preprocessed_source.clone(),
1725                    offset: e.source.offset(),
1726                },
1727            })?;
1728
1729        let mut derived = DerivedModule::default();
1730
1731        let mut already_added = Default::default();
1732        for import in &composable.imports {
1733            self.add_import(
1734                &mut derived,
1735                import,
1736                &shader_defs,
1737                false,
1738                &mut already_added,
1739            );
1740        }
1741
1742        Self::add_composable_data(&mut derived, &composable, None, 0, false);
1743
1744        let stage = match shader_type {
1745            #[cfg(feature = "glsl")]
1746            ShaderType::GlslVertex => Some(naga::ShaderStage::Vertex),
1747            #[cfg(feature = "glsl")]
1748            ShaderType::GlslFragment => Some(naga::ShaderStage::Fragment),
1749            _ => None,
1750        };
1751
1752        let mut entry_points = Vec::default();
1753        derived.set_shader_source(&composable.module_ir, 0);
1754        for ep in &composable.module_ir.entry_points {
1755            let mapped_func = derived.localize_function(&ep.function);
1756            entry_points.push(EntryPoint {
1757                name: ep.name.clone(),
1758                function: mapped_func,
1759                stage: stage.unwrap_or(ep.stage),
1760                early_depth_test: ep.early_depth_test,
1761                workgroup_size: ep.workgroup_size,
1762            });
1763        }
1764
1765        let mut naga_module = naga::Module {
1766            entry_points,
1767            ..derived.into()
1768        };
1769
1770        // apply overrides
1771        if !composable.override_functions.is_empty() {
1772            let mut redirect = Redirector::new(naga_module);
1773
1774            for (base_function, overrides) in composable.override_functions {
1775                let mut omit = HashSet::default();
1776
1777                let mut original = base_function;
1778                for replacement in overrides {
1779                    let (_h_orig, _h_replace) = redirect
1780                        .redirect_function(&original, &replacement, &omit)
1781                        .map_err(|e| ComposerError {
1782                            inner: e.into(),
1783                            source: ErrSource::Constructing {
1784                                path: file_path.to_owned(),
1785                                source: preprocessed_source.clone(),
1786                                offset: composable.start_offset,
1787                            },
1788                        })?;
1789                    omit.insert(replacement.clone());
1790                    original = replacement;
1791                }
1792            }
1793
1794            naga_module = redirect.into_module().map_err(|e| ComposerError {
1795                inner: e.into(),
1796                source: ErrSource::Constructing {
1797                    path: file_path.to_owned(),
1798                    source: preprocessed_source.clone(),
1799                    offset: composable.start_offset,
1800                },
1801            })?;
1802        }
1803
1804        // validation
1805        if self.validate {
1806            let info = self.create_validator().validate(&naga_module);
1807            match info {
1808                Ok(_) => Ok(naga_module),
1809                Err(e) => {
1810                    let original_span = e.spans().last();
1811                    let err_source = match original_span.and_then(|(span, _)| span.to_range()) {
1812                        Some(rng) => {
1813                            let module_index = rng.start >> SPAN_SHIFT;
1814                            match module_index {
1815                                0 => ErrSource::Constructing {
1816                                    path: file_path.to_owned(),
1817                                    source: preprocessed_source.clone(),
1818                                    offset: composable.start_offset,
1819                                },
1820                                _ => {
1821                                    let module_name =
1822                                        self.module_index.get(&module_index).unwrap().clone();
1823                                    let offset = self
1824                                        .module_sets
1825                                        .get(&module_name)
1826                                        .unwrap()
1827                                        .get_module(&shader_defs)
1828                                        .unwrap()
1829                                        .start_offset;
1830                                    ErrSource::Module {
1831                                        name: module_name,
1832                                        offset,
1833                                        defs: shader_defs.clone(),
1834                                    }
1835                                }
1836                            }
1837                        }
1838                        None => ErrSource::Constructing {
1839                            path: file_path.to_owned(),
1840                            source: preprocessed_source.clone(),
1841                            offset: composable.start_offset,
1842                        },
1843                    };
1844
1845                    Err(ComposerError {
1846                        inner: ComposerErrorInner::ShaderValidationError(e),
1847                        source: err_source,
1848                    })
1849                }
1850            }
1851        } else {
1852            Ok(naga_module)
1853        }
1854    }
1855}
1856
1857static PREPROCESSOR: once_cell::sync::Lazy<Preprocessor> =
1858    once_cell::sync::Lazy::new(Preprocessor::default);
1859
1860/// Get module name and all required imports (ignoring shader_defs) from a shader string
1861pub fn get_preprocessor_data(
1862    source: &str,
1863) -> (
1864    Option<String>,
1865    Vec<ImportDefinition>,
1866    HashMap<String, ShaderDefValue>,
1867) {
1868    if let Ok(PreprocessorMetaData {
1869        name,
1870        imports,
1871        defines,
1872        ..
1873    }) = PREPROCESSOR.get_preprocessor_metadata(source, true)
1874    {
1875        (
1876            name,
1877            imports
1878                .into_iter()
1879                .map(|import_with_offset| import_with_offset.definition)
1880                .collect(),
1881            defines,
1882        )
1883    } else {
1884        // if errors occur we return nothing; the actual error will be displayed when the caller attempts to use the shader
1885        Default::default()
1886    }
1887}