1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use crate::{
14 arena::Handle,
15 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
16 FastHashSet,
17};
18use bit_set::BitSet;
19use std::ops;
20
21use crate::span::{AddSpan as _, WithSpan};
25pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
26pub use compose::ComposeError;
27pub use expression::{check_literal_value, LiteralError};
28pub use expression::{ConstExpressionError, ExpressionError};
29pub use function::{CallError, FunctionError, LocalVariableError};
30pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
31pub use r#type::{Disalignment, TypeError, TypeFlags, WidthError};
32
33use self::handles::InvalidHandleError;
34
35bitflags::bitflags! {
36 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
50 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
51 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
52 pub struct ValidationFlags: u8 {
53 const EXPRESSIONS = 0x1;
55 const BLOCKS = 0x2;
57 const CONTROL_FLOW_UNIFORMITY = 0x4;
59 const STRUCT_LAYOUTS = 0x8;
61 const CONSTANTS = 0x10;
63 const BINDINGS = 0x20;
65 }
66}
67
68impl Default for ValidationFlags {
69 fn default() -> Self {
70 Self::all()
71 }
72}
73
74bitflags::bitflags! {
75 #[must_use]
77 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
78 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
79 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
80 pub struct Capabilities: u32 {
81 const PUSH_CONSTANT = 0x1;
83 const FLOAT64 = 0x2;
85 const PRIMITIVE_INDEX = 0x4;
87 const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
89 const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
91 const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
93 const CLIP_DISTANCE = 0x40;
95 const CULL_DISTANCE = 0x80;
97 const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
99 const MULTIVIEW = 0x200;
101 const EARLY_DEPTH_TEST = 0x400;
103 const MULTISAMPLED_SHADING = 0x800;
105 const RAY_QUERY = 0x1000;
107 const DUAL_SOURCE_BLENDING = 0x2000;
109 const CUBE_ARRAY_TEXTURES = 0x4000;
111 const SHADER_INT64 = 0x8000;
113 const SUBGROUP = 0x10000;
115 const SUBGROUP_BARRIER = 0x20000;
117 }
118}
119
120impl Default for Capabilities {
121 fn default() -> Self {
122 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
123 }
124}
125
126bitflags::bitflags! {
127 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
129 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
130 #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
131 pub struct SubgroupOperationSet: u8 {
132 const BASIC = 1 << 0;
134 const VOTE = 1 << 1;
136 const ARITHMETIC = 1 << 2;
138 const BALLOT = 1 << 3;
140 const SHUFFLE = 1 << 4;
142 const SHUFFLE_RELATIVE = 1 << 5;
144 }
152}
153
154impl super::SubgroupOperation {
155 const fn required_operations(&self) -> SubgroupOperationSet {
156 use SubgroupOperationSet as S;
157 match *self {
158 Self::All | Self::Any => S::VOTE,
159 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
160 S::ARITHMETIC
161 }
162 }
163 }
164}
165
166impl super::GatherMode {
167 const fn required_operations(&self) -> SubgroupOperationSet {
168 use SubgroupOperationSet as S;
169 match *self {
170 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
171 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
172 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
173 }
174 }
175}
176
177bitflags::bitflags! {
178 #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
180 #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
181 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
182 pub struct ShaderStages: u8 {
183 const VERTEX = 0x1;
184 const FRAGMENT = 0x2;
185 const COMPUTE = 0x4;
186 }
187}
188
/// Information gathered about a module by a successful validation.
///
/// Returned by [`Validator::validate`] and [`Validator::validate_no_overrides`].
/// It can be indexed by type, function, or global-expression handle through
/// its `ops::Index` implementations.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    /// Flags for each type, in type-arena order (indexed by `Handle<Type>`).
    type_flags: Vec<TypeFlags>,
    /// Analysis of each function, in function-arena order (indexed by
    /// `Handle<Function>`).
    functions: Vec<FunctionInfo>,
    /// Analysis of each entry point, in `Module::entry_points` order.
    entry_points: Vec<FunctionInfo>,
    /// One type resolution per global (module-scope) expression, indexed by
    /// `Handle<Expression>` into `Module::global_expressions`.
    const_expression_types: Box<[TypeResolution]>,
}
198
199impl ops::Index<Handle<crate::Type>> for ModuleInfo {
200 type Output = TypeFlags;
201 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
202 &self.type_flags[handle.index()]
203 }
204}
205
206impl ops::Index<Handle<crate::Function>> for ModuleInfo {
207 type Output = FunctionInfo;
208 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
209 &self.functions[handle.index()]
210 }
211}
212
213impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
214 type Output = TypeResolution;
215 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
216 &self.const_expression_types[handle.index()]
217 }
218}
219
/// Validates [`crate::Module`]s.
///
/// Holds the validation configuration (flags, capabilities, subgroup
/// settings) plus per-module scratch state that [`Validator::reset`] clears
/// before each run.
#[derive(Debug)]
pub struct Validator {
    /// Which validation passes to run.
    flags: ValidationFlags,
    /// Optional features modules may use.
    capabilities: Capabilities,
    /// Set through [`Validator::subgroup_stages`].
    subgroup_stages: ShaderStages,
    /// Set through [`Validator::subgroup_operations`].
    subgroup_operations: SubgroupOperationSet,
    /// Per-type validation info, indexed by `Handle<Type>`.
    types: Vec<r#type::TypeInner>,
    /// Type layouts, recomputed for each module in `validate_impl`.
    layouter: Layouter,
    // Scratch state cleared by `reset`; presumably filled by the
    // entry-point / function validators defined in other files.
    location_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    // NOTE(review): in this file the field is only initialized and cleared
    // (`reset`); the allow suggests nothing reads it. Confirm before removing.
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: BitSet,
    /// Override IDs seen so far in the current module, used to reject
    /// duplicates.
    override_ids: FastHashSet<u16>,
    /// Whether override declarations are accepted. Set by
    /// [`Validator::validate`] (true) and [`Validator::validate_no_overrides`]
    /// (false).
    allow_overrides: bool,
}
237
/// Reasons a module-scope constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer is not a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's type is not equivalent to the declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The declared type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
248
/// Reasons an override declaration can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    /// The override has neither a name nor an explicit ID.
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override in the same module already uses this ID.
    #[error("Override ID must be unique")]
    DuplicateID,
    /// The initializer is not a const- or override-expression.
    // NOTE(review): not constructed anywhere in this file.
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's type is not equivalent to the declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The declared type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not one of the accepted scalars: bool, i32, u32,
    /// f32 or f64.
    #[error("The type is not a scalar")]
    TypeNotScalar,
    /// Overrides were rejected, as by [`Validator::validate_no_overrides`].
    #[error("Override declarations are not allowed")]
    NotAllowed,
}
267
/// Error produced by [`Validator::validate`] and
/// [`Validator::validate_no_overrides`], wrapped in a [`WithSpan`].
///
/// Most variants identify the offending module item by handle (or stage) and
/// name, and carry the item-specific error as `source`. `name` is an empty
/// string for unnamed items.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// Handle validation of the module failed.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// Computing type layouts failed.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global (module-scope) expression failed processing or
    /// const-expression validation.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    /// A module-scope constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// An override declaration failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation, or shares both its stage and its
    /// name with an earlier entry point ([`EntryPointError::Conflict`]).
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    /// The module is corrupted.
    // NOTE(review): not constructed anywhere in this file.
    #[error("Module is corrupted")]
    Corrupted,
}
319
320impl crate::TypeInner {
321 const fn is_sized(&self) -> bool {
322 match *self {
323 Self::Scalar { .. }
324 | Self::Vector { .. }
325 | Self::Matrix { .. }
326 | Self::Array {
327 size: crate::ArraySize::Constant(_),
328 ..
329 }
330 | Self::Atomic { .. }
331 | Self::Pointer { .. }
332 | Self::ValuePointer { .. }
333 | Self::Struct { .. } => true,
334 Self::Array { .. }
335 | Self::Image { .. }
336 | Self::Sampler { .. }
337 | Self::AccelerationStructure
338 | Self::RayQuery
339 | Self::BindingArray { .. } => false,
340 }
341 }
342
343 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
345 match *self {
346 Self::Scalar(crate::Scalar {
347 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
348 ..
349 }) => Some(crate::ImageDimension::D1),
350 Self::Vector {
351 size: crate::VectorSize::Bi,
352 scalar:
353 crate::Scalar {
354 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
355 ..
356 },
357 } => Some(crate::ImageDimension::D2),
358 Self::Vector {
359 size: crate::VectorSize::Tri,
360 scalar:
361 crate::Scalar {
362 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
363 ..
364 },
365 } => Some(crate::ImageDimension::D3),
366 _ => None,
367 }
368 }
369}
370
impl Validator {
    /// Construct a new validator instance.
    ///
    /// `flags` selects which validation passes run, and `capabilities` lists
    /// the optional features modules are allowed to use. Subgroup stages and
    /// operations start out empty; set them with
    /// [`Validator::subgroup_stages`] and [`Validator::subgroup_operations`].
    /// Overrides are permitted until a validation entry point says otherwise.
    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
        Validator {
            flags,
            capabilities,
            subgroup_stages: ShaderStages::empty(),
            subgroup_operations: SubgroupOperationSet::empty(),
            types: Vec::new(),
            layouter: Layouter::default(),
            location_mask: BitSet::new(),
            ep_resource_bindings: FastHashSet::default(),
            switch_values: FastHashSet::default(),
            valid_expression_list: Vec::new(),
            valid_expression_set: BitSet::new(),
            override_ids: FastHashSet::default(),
            allow_overrides: true,
        }
    }

    /// Set the shader stages recorded for subgroup operations.
    ///
    /// Returns `self` so calls can be chained.
    // NOTE(review): presumably consulted by the function/entry-point
    // validators defined in other files; not visible here.
    pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
        self.subgroup_stages = stages;
        self
    }

    /// Set the subgroup operation categories recorded as supported.
    ///
    /// Returns `self` so calls can be chained.
    pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
        self.subgroup_operations = operations;
        self
    }

    /// Clear the per-module state left over from a previous validation.
    ///
    /// Configuration is left untouched: `flags`, `capabilities`, the subgroup
    /// settings and `allow_overrides`.
    pub fn reset(&mut self) {
        self.types.clear();
        self.layouter.clear();
        self.location_mask.clear();
        self.ep_resource_bindings.clear();
        self.switch_values.clear();
        self.valid_expression_list.clear();
        self.valid_expression_set.clear();
        self.override_ids.clear();
    }

    /// Check a single module-scope constant.
    ///
    /// The checks run in this order:
    /// 1. The constant's type must be constructible.
    /// 2. Its initializer must be a const-expression according to
    ///    `global_expr_kind`.
    /// 3. The initializer's resolved type (read from `mod_info`, which must
    ///    already hold the global expression types) must be equivalent to the
    ///    declared type.
    ///
    /// # Errors
    ///
    /// Returns the [`ConstantError`] for the first check that fails.
    fn validate_constant(
        &self,
        handle: Handle<crate::Constant>,
        gctx: crate::proc::GlobalCtx,
        mod_info: &ModuleInfo,
        global_expr_kind: &ExpressionKindTracker,
    ) -> Result<(), ConstantError> {
        let con = &gctx.constants[handle];

        let type_info = &self.types[con.ty.index()];
        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
            return Err(ConstantError::NonConstructibleType);
        }

        if !global_expr_kind.is_const(con.init) {
            return Err(ConstantError::InitializerExprType);
        }

        let decl_ty = &gctx.types[con.ty].inner;
        let init_ty = mod_info[con.init].inner_with(gctx.types);
        if !decl_ty.equivalent(init_ty, gctx.types) {
            return Err(ConstantError::InvalidType);
        }

        Ok(())
    }

    /// Check a single override declaration.
    ///
    /// The checks run in this order:
    /// 1. Overrides must be allowed (see `allow_overrides`).
    /// 2. The override needs a name or an ID.
    /// 3. Its ID, if any, must not have been seen before in this module.
    /// 4. Its type must be constructible.
    /// 5. Its type must be a bool, i32, u32, f32 or f64 scalar.
    /// 6. If it has an initializer, the initializer's type must be equivalent
    ///    to the declared type.
    ///
    /// Side effect: the ID is recorded in `override_ids` as soon as step 3
    /// passes, even if a later check fails.
    ///
    /// # Errors
    ///
    /// Returns the [`OverrideError`] for the first check that fails.
    fn validate_override(
        &mut self,
        handle: Handle<crate::Override>,
        gctx: crate::proc::GlobalCtx,
        mod_info: &ModuleInfo,
    ) -> Result<(), OverrideError> {
        if !self.allow_overrides {
            return Err(OverrideError::NotAllowed);
        }

        let o = &gctx.overrides[handle];

        if o.name.is_none() && o.id.is_none() {
            return Err(OverrideError::MissingNameAndID);
        }

        if let Some(id) = o.id {
            // `insert` returns false when the ID was already present.
            if !self.override_ids.insert(id) {
                return Err(OverrideError::DuplicateID);
            }
        }

        let type_info = &self.types[o.ty.index()];
        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
            return Err(OverrideError::NonConstructibleType);
        }

        let decl_ty = &gctx.types[o.ty].inner;
        // Only these scalar types are accepted; any other scalar width/kind is
        // reported as `TypeNotScalar` as well.
        match decl_ty {
            &crate::TypeInner::Scalar(scalar) => match scalar {
                crate::Scalar::BOOL
                | crate::Scalar::I32
                | crate::Scalar::U32
                | crate::Scalar::F32
                | crate::Scalar::F64 => {}
                _ => return Err(OverrideError::TypeNotScalar),
            },
            _ => return Err(OverrideError::TypeNotScalar),
        }

        if let Some(init) = o.init {
            let init_ty = mod_info[init].inner_with(gctx.types);
            if !decl_ty.equivalent(init_ty, gctx.types) {
                return Err(OverrideError::InvalidType);
            }
        }

        Ok(())
    }

    /// Validate `module`, permitting override declarations.
    ///
    /// # Errors
    ///
    /// Returns the first problem found, as a [`ValidationError`] with span
    /// information.
    pub fn validate(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.allow_overrides = true;
        self.validate_impl(module)
    }

    /// Validate `module`, rejecting any override declaration with
    /// [`OverrideError::NotAllowed`].
    ///
    /// NOTE: overrides are only checked when [`ValidationFlags::CONSTANTS`]
    /// is set. With that flag cleared, a module containing overrides is not
    /// rejected by this method.
    ///
    /// # Errors
    ///
    /// Returns the first problem found, as a [`ValidationError`] with span
    /// information.
    pub fn validate_no_overrides(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.allow_overrides = false;
        self.validate_impl(module)
    }

    /// Shared implementation of [`Validator::validate`] and
    /// [`Validator::validate_no_overrides`]. The caller sets
    /// `allow_overrides` first.
    ///
    /// Passes run in this order, and the first failure is returned:
    /// 1. handle validation of the whole module;
    /// 2. layout computation for all types;
    /// 3. per-type validation, filling `self.types` and `type_flags`;
    /// 4. `process_const_expression` on every global expression;
    /// 5. if [`ValidationFlags::CONSTANTS`] is set: global const-expressions,
    ///    then constants, then overrides;
    /// 6. global variables;
    /// 7. functions;
    /// 8. entry points, each first checked for a duplicate (stage, name)
    ///    pair.
    fn validate_impl(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.reset();
        self.reset_types(module.types.len());

        // Handle validity is checked before anything indexes the arenas.
        Self::validate_module_handles(module).map_err(|e| e.with_span())?;

        self.layouter.update(module.to_ctx()).map_err(|e| {
            let handle = e.ty;
            ValidationError::from(e).with_span_handle(handle, &module.types)
        })?;

        // Filler for `const_expression_types`; presumably every slot is
        // overwritten by `process_const_expression` below.
        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
            kind: crate::ScalarKind::Bool,
            width: 0,
        }));

        let mut mod_info = ModuleInfo {
            type_flags: Vec::with_capacity(module.types.len()),
            functions: Vec::with_capacity(module.functions.len()),
            entry_points: Vec::with_capacity(module.entry_points.len()),
            const_expression_types: vec![placeholder; module.global_expressions.len()]
                .into_boxed_slice(),
        };

        // Types are visited in arena order, so `type_flags` ends up indexed by
        // type handle.
        for (handle, ty) in module.types.iter() {
            let ty_info = self
                .validate_type(handle, module.to_ctx())
                .map_err(|source| {
                    ValidationError::Type {
                        handle,
                        name: ty.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(handle, &module.types)
                })?;
            mod_info.type_flags.push(ty_info.flags);
            self.types[handle.index()] = ty_info;
        }

        {
            // Global expressions are resolved with an empty local-variable
            // arena and an empty slice (presumably the function arguments).
            let t = crate::Arena::new();
            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
            for (handle, _) in module.global_expressions.iter() {
                mod_info
                    .process_const_expression(handle, &resolve_context, module.to_ctx())
                    .map_err(|source| {
                        ValidationError::ConstExpression { handle, source }
                            .with_span_handle(handle, &module.global_expressions)
                    })?
            }
        }

        let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);

        // Override rejection for `validate_no_overrides` lives inside this
        // gate as well.
        if self.flags.contains(ValidationFlags::CONSTANTS) {
            for (handle, _) in module.global_expressions.iter() {
                self.validate_const_expression(
                    handle,
                    module.to_ctx(),
                    &mod_info,
                    &global_expr_kind,
                )
                .map_err(|source| {
                    ValidationError::ConstExpression { handle, source }
                        .with_span_handle(handle, &module.global_expressions)
                })?
            }

            for (handle, constant) in module.constants.iter() {
                self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
                    .map_err(|source| {
                        ValidationError::Constant {
                            handle,
                            name: constant.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.constants)
                    })?
            }

            for (handle, override_) in module.overrides.iter() {
                self.validate_override(handle, module.to_ctx(), &mod_info)
                    .map_err(|source| {
                        ValidationError::Override {
                            handle,
                            name: override_.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.overrides)
                    })?
            }
        }

        for (var_handle, var) in module.global_variables.iter() {
            self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
                .map_err(|source| {
                    ValidationError::GlobalVariable {
                        handle: var_handle,
                        name: var.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(var_handle, &module.global_variables)
                })?;
        }

        // Functions are validated in arena order so `mod_info.functions` is
        // indexed by function handle.
        for (handle, fun) in module.functions.iter() {
            match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
                Ok(info) => mod_info.functions.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::Function {
                            handle,
                            name: fun.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.functions)
                    }))
                }
            }
        }

        // Two entry points may share a name only if their stages differ.
        let mut ep_map = FastHashSet::default();
        for ep in module.entry_points.iter() {
            if !ep_map.insert((ep.stage, &ep.name)) {
                return Err(ValidationError::EntryPoint {
                    stage: ep.stage,
                    name: ep.name.clone(),
                    source: EntryPointError::Conflict,
                }
                .with_span());
            }

            match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
                Ok(info) => mod_info.entry_points.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::EntryPoint {
                            stage: ep.stage,
                            name: ep.name.clone(),
                            source,
                        }
                        .with_span()
                    }));
                }
            }
        }

        Ok(mod_info)
    }
}
664
665fn validate_atomic_compare_exchange_struct(
666 types: &crate::UniqueArena<crate::Type>,
667 members: &[crate::StructMember],
668 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
669) -> bool {
670 members.len() == 2
671 && members[0].name.as_deref() == Some("old_value")
672 && scalar_predicate(&types[members[0].ty].inner)
673 && members[1].name.as_deref() == Some("exchanged")
674 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
675}