naga/valid/
mod.rs

1/*!
2Shader validator.
3*/
4
5mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use crate::{
14    arena::Handle,
15    proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
16    FastHashSet,
17};
18use bit_set::BitSet;
19use std::ops;
20
21//TODO: analyze the model at the same time as we validate it,
22// merge the corresponding matches over expressions and statements.
23
24use crate::span::{AddSpan as _, WithSpan};
25pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
26pub use compose::ComposeError;
27pub use expression::{check_literal_value, LiteralError};
28pub use expression::{ConstExpressionError, ExpressionError};
29pub use function::{CallError, FunctionError, LocalVariableError};
30pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
31pub use r#type::{Disalignment, TypeError, TypeFlags, WidthError};
32
33use self::handles::InvalidHandleError;
34
35bitflags::bitflags! {
36    /// Validation flags.
37    ///
38    /// If you are working with trusted shaders, then you may be able
39    /// to save some time by skipping validation.
40    ///
41    /// If you do not perform full validation, invalid shaders may
42    /// cause Naga to panic. If you do perform full validation and
43    /// [`Validator::validate`] returns `Ok`, then Naga promises that
44    /// code generation will either succeed or return an error; it
45    /// should never panic.
46    ///
47    /// The default value for `ValidationFlags` is
48    /// `ValidationFlags::all()`.
49    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
50    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
51    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
52    pub struct ValidationFlags: u8 {
53        /// Expressions.
54        const EXPRESSIONS = 0x1;
55        /// Statements and blocks of them.
56        const BLOCKS = 0x2;
57        /// Uniformity of control flow for operations that require it.
58        const CONTROL_FLOW_UNIFORMITY = 0x4;
59        /// Host-shareable structure layouts.
60        const STRUCT_LAYOUTS = 0x8;
61        /// Constants.
62        const CONSTANTS = 0x10;
63        /// Group, binding, and location attributes.
64        const BINDINGS = 0x20;
65    }
66}
67
68impl Default for ValidationFlags {
69    fn default() -> Self {
70        Self::all()
71    }
72}
73
74bitflags::bitflags! {
75    /// Allowed IR capabilities.
76    #[must_use]
77    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
78    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
79    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
80    pub struct Capabilities: u32 {
81        /// Support for [`AddressSpace:PushConstant`].
82        const PUSH_CONSTANT = 0x1;
83        /// Float values with width = 8.
84        const FLOAT64 = 0x2;
85        /// Support for [`Builtin:PrimitiveIndex`].
86        const PRIMITIVE_INDEX = 0x4;
87        /// Support for non-uniform indexing of sampled textures and storage buffer arrays.
88        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
89        /// Support for non-uniform indexing of uniform buffers and storage texture arrays.
90        const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
91        /// Support for non-uniform indexing of samplers.
92        const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
93        /// Support for [`Builtin::ClipDistance`].
94        const CLIP_DISTANCE = 0x40;
95        /// Support for [`Builtin::CullDistance`].
96        const CULL_DISTANCE = 0x80;
97        /// Support for 16-bit normalized storage texture formats.
98        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
99        /// Support for [`BuiltIn::ViewIndex`].
100        const MULTIVIEW = 0x200;
101        /// Support for `early_depth_test`.
102        const EARLY_DEPTH_TEST = 0x400;
103        /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`].
104        const MULTISAMPLED_SHADING = 0x800;
105        /// Support for ray queries and acceleration structures.
106        const RAY_QUERY = 0x1000;
107        /// Support for generating two sources for blending from fragment shaders.
108        const DUAL_SOURCE_BLENDING = 0x2000;
109        /// Support for arrayed cube textures.
110        const CUBE_ARRAY_TEXTURES = 0x4000;
111        /// Support for 64-bit signed and unsigned integers.
112        const SHADER_INT64 = 0x8000;
113        /// Support for subgroup operations.
114        const SUBGROUP = 0x10000;
115        /// Support for subgroup barriers.
116        const SUBGROUP_BARRIER = 0x20000;
117    }
118}
119
120impl Default for Capabilities {
121    fn default() -> Self {
122        Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
123    }
124}
125
126bitflags::bitflags! {
127    /// Supported subgroup operations
128    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
129    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
130    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
131    pub struct SubgroupOperationSet: u8 {
132        /// Elect, Barrier
133        const BASIC = 1 << 0;
134        /// Any, All
135        const VOTE = 1 << 1;
136        /// reductions, scans
137        const ARITHMETIC = 1 << 2;
138        /// ballot, broadcast
139        const BALLOT = 1 << 3;
140        /// shuffle, shuffle xor
141        const SHUFFLE = 1 << 4;
142        /// shuffle up, down
143        const SHUFFLE_RELATIVE = 1 << 5;
144        // We don't support these operations yet
145        // /// Clustered
146        // const CLUSTERED = 1 << 6;
147        // /// Quad supported
148        // const QUAD_FRAGMENT_COMPUTE = 1 << 7;
149        // /// Quad supported in all stages
150        // const QUAD_ALL_STAGES = 1 << 8;
151    }
152}
153
154impl super::SubgroupOperation {
155    const fn required_operations(&self) -> SubgroupOperationSet {
156        use SubgroupOperationSet as S;
157        match *self {
158            Self::All | Self::Any => S::VOTE,
159            Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
160                S::ARITHMETIC
161            }
162        }
163    }
164}
165
166impl super::GatherMode {
167    const fn required_operations(&self) -> SubgroupOperationSet {
168        use SubgroupOperationSet as S;
169        match *self {
170            Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
171            Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
172            Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
173        }
174    }
175}
176
177bitflags::bitflags! {
178    /// Validation flags.
179    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
180    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
181    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
182    pub struct ShaderStages: u8 {
183        const VERTEX = 0x1;
184        const FRAGMENT = 0x2;
185        const COMPUTE = 0x4;
186    }
187}
188
/// Information gathered about a module during validation.
///
/// Returned on success by [`Validator::validate`] and
/// [`Validator::validate_no_overrides`]. It can be indexed by type, function,
/// and global-expression handles.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type in the module's type arena, in arena order.
    type_flags: Vec<TypeFlags>,
    /// Analysis of each function in the module's function arena, in arena order.
    functions: Vec<FunctionInfo>,
    /// Analysis of each entry point, in the order of the module's entry point list.
    entry_points: Vec<FunctionInfo>,
    /// Resolved type of each expression in the module's global expression arena.
    /// Starts out filled with placeholders that validation overwrites.
    const_expression_types: Box<[TypeResolution]>,
}
198
199impl ops::Index<Handle<crate::Type>> for ModuleInfo {
200    type Output = TypeFlags;
201    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
202        &self.type_flags[handle.index()]
203    }
204}
205
206impl ops::Index<Handle<crate::Function>> for ModuleInfo {
207    type Output = FunctionInfo;
208    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
209        &self.functions[handle.index()]
210    }
211}
212
213impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
214    type Output = TypeResolution;
215    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
216        &self.const_expression_types[handle.index()]
217    }
218}
219
/// Checks a [`crate::Module`] for validity.
///
/// A single validator can be reused across modules: every validation run
/// begins by calling [`Validator::reset`] to discard per-module state.
#[derive(Debug)]
pub struct Validator {
    /// Which categories of checks to perform.
    flags: ValidationFlags,
    /// Optional IR features modules are allowed to use.
    capabilities: Capabilities,
    /// Stages set via [`Validator::subgroup_stages`].
    subgroup_stages: ShaderStages,
    /// Operation classes set via [`Validator::subgroup_operations`].
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation results, indexed by type handle index.
    types: Vec<r#type::TypeInfo>,
    /// Type layouts, recomputed for each validated module.
    layouter: Layouter,
    // NOTE(review): the next fields are only cleared in this file; the code
    // that fills and reads them lives in the submodules.
    location_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // NOTE(review): the reason for this `allow(dead_code)` is not visible here;
    // `reset` only clears this set. Confirm whether the field is still needed.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: BitSet,
    /// Override IDs seen so far in this run, used to reject duplicates.
    override_ids: FastHashSet<u16>,
    /// Whether override declarations are accepted. `validate` sets it to
    /// `true`, `validate_no_overrides` to `false`.
    allow_overrides: bool,
}
237
/// Reasons a [`crate::Constant`] can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's type is not equivalent to the declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The declared type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
}
248
/// Reasons a [`crate::Override`] can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    /// The override has neither a name nor a numeric ID.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the module already uses this ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    /// The initializer is neither a const-expression nor an override-expression.
    /// NOTE(review): not returned by `Validator::validate_override` in this file.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's type is not equivalent to the declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The declared type lacks [`TypeFlags::CONSTRUCTIBLE`].
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not one of `bool`, `i32`, `u32`, `f32`, or `f64`.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    /// The module was validated with [`Validator::validate_no_overrides`].
    #[error("Override declarations are not allowed")]
    NotAllowed,
}
267
/// An error found while validating a module, identifying the item at fault.
///
/// For items with an optional name, `name` is empty when the item is unnamed.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// A handle in the module failed the handle-validity pass.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// The type `handle` is invalid.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// The global expression `handle` is invalid.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// The constant `handle` is invalid.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// The override `handle` is invalid.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// The global variable `handle` is invalid.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// The function `handle` is invalid.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// The entry point `name` for `stage` is invalid, or shares its stage and
    /// name with an earlier entry point.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// NOTE(review): not produced anywhere in this file; confirm where it is used.
    #[error("Module is corrupted")]
    Corrupted,
}
319
320impl crate::TypeInner {
321    const fn is_sized(&self) -> bool {
322        match *self {
323            Self::Scalar { .. }
324            | Self::Vector { .. }
325            | Self::Matrix { .. }
326            | Self::Array {
327                size: crate::ArraySize::Constant(_),
328                ..
329            }
330            | Self::Atomic { .. }
331            | Self::Pointer { .. }
332            | Self::ValuePointer { .. }
333            | Self::Struct { .. } => true,
334            Self::Array { .. }
335            | Self::Image { .. }
336            | Self::Sampler { .. }
337            | Self::AccelerationStructure
338            | Self::RayQuery
339            | Self::BindingArray { .. } => false,
340        }
341    }
342
343    /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
344    const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
345        match *self {
346            Self::Scalar(crate::Scalar {
347                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
348                ..
349            }) => Some(crate::ImageDimension::D1),
350            Self::Vector {
351                size: crate::VectorSize::Bi,
352                scalar:
353                    crate::Scalar {
354                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
355                        ..
356                    },
357            } => Some(crate::ImageDimension::D2),
358            Self::Vector {
359                size: crate::VectorSize::Tri,
360                scalar:
361                    crate::Scalar {
362                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
363                        ..
364                    },
365            } => Some(crate::ImageDimension::D3),
366            _ => None,
367        }
368    }
369}
370
371impl Validator {
372    /// Construct a new validator instance.
373    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
374        Validator {
375            flags,
376            capabilities,
377            subgroup_stages: ShaderStages::empty(),
378            subgroup_operations: SubgroupOperationSet::empty(),
379            types: Vec::new(),
380            layouter: Layouter::default(),
381            location_mask: BitSet::new(),
382            ep_resource_bindings: FastHashSet::default(),
383            switch_values: FastHashSet::default(),
384            valid_expression_list: Vec::new(),
385            valid_expression_set: BitSet::new(),
386            override_ids: FastHashSet::default(),
387            allow_overrides: true,
388        }
389    }
390
391    pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
392        self.subgroup_stages = stages;
393        self
394    }
395
396    pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
397        self.subgroup_operations = operations;
398        self
399    }
400
401    /// Reset the validator internals
402    pub fn reset(&mut self) {
403        self.types.clear();
404        self.layouter.clear();
405        self.location_mask.clear();
406        self.ep_resource_bindings.clear();
407        self.switch_values.clear();
408        self.valid_expression_list.clear();
409        self.valid_expression_set.clear();
410        self.override_ids.clear();
411    }
412
413    fn validate_constant(
414        &self,
415        handle: Handle<crate::Constant>,
416        gctx: crate::proc::GlobalCtx,
417        mod_info: &ModuleInfo,
418        global_expr_kind: &ExpressionKindTracker,
419    ) -> Result<(), ConstantError> {
420        let con = &gctx.constants[handle];
421
422        let type_info = &self.types[con.ty.index()];
423        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
424            return Err(ConstantError::NonConstructibleType);
425        }
426
427        if !global_expr_kind.is_const(con.init) {
428            return Err(ConstantError::InitializerExprType);
429        }
430
431        let decl_ty = &gctx.types[con.ty].inner;
432        let init_ty = mod_info[con.init].inner_with(gctx.types);
433        if !decl_ty.equivalent(init_ty, gctx.types) {
434            return Err(ConstantError::InvalidType);
435        }
436
437        Ok(())
438    }
439
440    fn validate_override(
441        &mut self,
442        handle: Handle<crate::Override>,
443        gctx: crate::proc::GlobalCtx,
444        mod_info: &ModuleInfo,
445    ) -> Result<(), OverrideError> {
446        if !self.allow_overrides {
447            return Err(OverrideError::NotAllowed);
448        }
449
450        let o = &gctx.overrides[handle];
451
452        if o.name.is_none() && o.id.is_none() {
453            return Err(OverrideError::MissingNameAndID);
454        }
455
456        if let Some(id) = o.id {
457            if !self.override_ids.insert(id) {
458                return Err(OverrideError::DuplicateID);
459            }
460        }
461
462        let type_info = &self.types[o.ty.index()];
463        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
464            return Err(OverrideError::NonConstructibleType);
465        }
466
467        let decl_ty = &gctx.types[o.ty].inner;
468        match decl_ty {
469            &crate::TypeInner::Scalar(scalar) => match scalar {
470                crate::Scalar::BOOL
471                | crate::Scalar::I32
472                | crate::Scalar::U32
473                | crate::Scalar::F32
474                | crate::Scalar::F64 => {}
475                _ => return Err(OverrideError::TypeNotScalar),
476            },
477            _ => return Err(OverrideError::TypeNotScalar),
478        }
479
480        if let Some(init) = o.init {
481            let init_ty = mod_info[init].inner_with(gctx.types);
482            if !decl_ty.equivalent(init_ty, gctx.types) {
483                return Err(OverrideError::InvalidType);
484            }
485        }
486
487        Ok(())
488    }
489
490    /// Check the given module to be valid.
491    pub fn validate(
492        &mut self,
493        module: &crate::Module,
494    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
495        self.allow_overrides = true;
496        self.validate_impl(module)
497    }
498
499    /// Check the given module to be valid.
500    ///
501    /// With the additional restriction that overrides are not present.
502    pub fn validate_no_overrides(
503        &mut self,
504        module: &crate::Module,
505    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
506        self.allow_overrides = false;
507        self.validate_impl(module)
508    }
509
510    fn validate_impl(
511        &mut self,
512        module: &crate::Module,
513    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
514        self.reset();
515        self.reset_types(module.types.len());
516
517        Self::validate_module_handles(module).map_err(|e| e.with_span())?;
518
519        self.layouter.update(module.to_ctx()).map_err(|e| {
520            let handle = e.ty;
521            ValidationError::from(e).with_span_handle(handle, &module.types)
522        })?;
523
524        // These should all get overwritten.
525        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
526            kind: crate::ScalarKind::Bool,
527            width: 0,
528        }));
529
530        let mut mod_info = ModuleInfo {
531            type_flags: Vec::with_capacity(module.types.len()),
532            functions: Vec::with_capacity(module.functions.len()),
533            entry_points: Vec::with_capacity(module.entry_points.len()),
534            const_expression_types: vec![placeholder; module.global_expressions.len()]
535                .into_boxed_slice(),
536        };
537
538        for (handle, ty) in module.types.iter() {
539            let ty_info = self
540                .validate_type(handle, module.to_ctx())
541                .map_err(|source| {
542                    ValidationError::Type {
543                        handle,
544                        name: ty.name.clone().unwrap_or_default(),
545                        source,
546                    }
547                    .with_span_handle(handle, &module.types)
548                })?;
549            mod_info.type_flags.push(ty_info.flags);
550            self.types[handle.index()] = ty_info;
551        }
552
553        {
554            let t = crate::Arena::new();
555            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
556            for (handle, _) in module.global_expressions.iter() {
557                mod_info
558                    .process_const_expression(handle, &resolve_context, module.to_ctx())
559                    .map_err(|source| {
560                        ValidationError::ConstExpression { handle, source }
561                            .with_span_handle(handle, &module.global_expressions)
562                    })?
563            }
564        }
565
566        let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
567
568        if self.flags.contains(ValidationFlags::CONSTANTS) {
569            for (handle, _) in module.global_expressions.iter() {
570                self.validate_const_expression(
571                    handle,
572                    module.to_ctx(),
573                    &mod_info,
574                    &global_expr_kind,
575                )
576                .map_err(|source| {
577                    ValidationError::ConstExpression { handle, source }
578                        .with_span_handle(handle, &module.global_expressions)
579                })?
580            }
581
582            for (handle, constant) in module.constants.iter() {
583                self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
584                    .map_err(|source| {
585                        ValidationError::Constant {
586                            handle,
587                            name: constant.name.clone().unwrap_or_default(),
588                            source,
589                        }
590                        .with_span_handle(handle, &module.constants)
591                    })?
592            }
593
594            for (handle, override_) in module.overrides.iter() {
595                self.validate_override(handle, module.to_ctx(), &mod_info)
596                    .map_err(|source| {
597                        ValidationError::Override {
598                            handle,
599                            name: override_.name.clone().unwrap_or_default(),
600                            source,
601                        }
602                        .with_span_handle(handle, &module.overrides)
603                    })?
604            }
605        }
606
607        for (var_handle, var) in module.global_variables.iter() {
608            self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
609                .map_err(|source| {
610                    ValidationError::GlobalVariable {
611                        handle: var_handle,
612                        name: var.name.clone().unwrap_or_default(),
613                        source,
614                    }
615                    .with_span_handle(var_handle, &module.global_variables)
616                })?;
617        }
618
619        for (handle, fun) in module.functions.iter() {
620            match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
621                Ok(info) => mod_info.functions.push(info),
622                Err(error) => {
623                    return Err(error.and_then(|source| {
624                        ValidationError::Function {
625                            handle,
626                            name: fun.name.clone().unwrap_or_default(),
627                            source,
628                        }
629                        .with_span_handle(handle, &module.functions)
630                    }))
631                }
632            }
633        }
634
635        let mut ep_map = FastHashSet::default();
636        for ep in module.entry_points.iter() {
637            if !ep_map.insert((ep.stage, &ep.name)) {
638                return Err(ValidationError::EntryPoint {
639                    stage: ep.stage,
640                    name: ep.name.clone(),
641                    source: EntryPointError::Conflict,
642                }
643                .with_span()); // TODO: keep some EP span information?
644            }
645
646            match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
647                Ok(info) => mod_info.entry_points.push(info),
648                Err(error) => {
649                    return Err(error.and_then(|source| {
650                        ValidationError::EntryPoint {
651                            stage: ep.stage,
652                            name: ep.name.clone(),
653                            source,
654                        }
655                        .with_span()
656                    }));
657                }
658            }
659        }
660
661        Ok(mod_info)
662    }
663}
664
665fn validate_atomic_compare_exchange_struct(
666    types: &crate::UniqueArena<crate::Type>,
667    members: &[crate::StructMember],
668    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
669) -> bool {
670    members.len() == 2
671        && members[0].name.as_deref() == Some("old_value")
672        && scalar_predicate(&types[members[0].ty].inner)
673        && members[1].name.as_deref() == Some("exchanged")
674        && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
675}