1use crate::arena::Handle;
2use crate::arena::{Arena, UniqueArena};
3
4use super::validate_atomic_compare_exchange_struct;
5
6use super::{
7 analyzer::{UniformityDisruptor, UniformityRequirements},
8 ExpressionError, FunctionInfo, ModuleInfo,
9};
10use crate::span::WithSpan;
11use crate::span::{AddSpan as _, MapErrWithSpan as _};
12
13use bit_set::BitSet;
14
/// Errors reported while validating a function call statement.
///
/// Wrapped by [`FunctionError::InvalidCall`] together with the callee handle.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CallError {
    /// The type of the argument expression at position `index` could not be
    /// resolved; `source` carries the underlying expression error.
    #[error("Argument {index} expression is invalid")]
    Argument {
        index: usize,
        source: ExpressionError,
    },
    /// The call's result expression was already present in the set of valid
    /// (in-scope) expressions.
    #[error("Result expression {0:?} has already been introduced earlier")]
    ResultAlreadyInScope(Handle<crate::Expression>),
    #[error("Result value is invalid")]
    ResultValue(#[source] ExpressionError),
    /// The number of supplied arguments differs from the callee's declared
    /// argument count.
    #[error("Requires {required} arguments, but {seen} are provided")]
    ArgumentCount { required: usize, seen: usize },
    /// The type of an argument expression is not equivalent to the type of
    /// the corresponding callee argument.
    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
    ArgumentType {
        index: usize,
        required: Handle<crate::Type>,
        seen_expression: Handle<crate::Expression>,
    },
    /// The result expression is not a `CallResult` for this callee, or the
    /// callee returns a value but the call has no result expression.
    #[error("The emitted expression doesn't match the call")]
    ExpressionMismatch(Option<Handle<crate::Expression>>),
}
38
/// Errors reported while validating an atomic statement.
///
/// Converts into [`FunctionError::InvalidAtomic`].
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AtomicError {
    /// The pointer expression is not a pointer to an atomic type.
    #[error("Pointer {0:?} to atomic is invalid.")]
    InvalidPointer(Handle<crate::Expression>),
    /// The operand (or exchange comparison value) does not have the scalar
    /// type of the atomic being operated on.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The result expression is not an `AtomicResult` of the expected type.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
}
49
/// Errors reported while validating subgroup statements.
///
/// Converts into [`FunctionError::InvalidSubgroup`].
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum SubgroupError {
    /// An operand's type is not accepted by the subgroup operation.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The result expression is not a `SubgroupOperationResult` whose type
    /// equals the argument's type.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
    /// The validator was not configured with the operation set that the
    /// statement needs.
    #[error("Support for subgroup operation {0:?} is required")]
    UnsupportedOperation(super::SubgroupOperationSet),
    /// The combination of collective operation and subgroup operation is not
    /// one of the accepted pairs.
    #[error("Unknown operation")]
    UnknownOperation,
}
62
/// Errors describing an invalid local variable declaration.
///
/// Carried as the `source` of [`FunctionError::LocalVariable`].
// NOTE(review): the code that produces these variants is outside this part of
// the file; the per-variant meaning is given by the error messages below.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
    InvalidType(Handle<crate::Type>),
    #[error("Initializer doesn't match the variable type")]
    InitializerType,
    #[error("Initializer is not a const or override expression")]
    NonConstOrOverrideInitializer,
}
73
/// Errors reported while validating a function: its signature, local
/// variables, expressions, and the statements of its body.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionError {
    /// An expression referenced by a statement could not be resolved;
    /// `source` carries the underlying expression error.
    #[error("Expression {handle:?} is invalid")]
    Expression {
        handle: Handle<crate::Expression>,
        source: ExpressionError,
    },
    /// An expression was emitted while already being part of the set of
    /// valid (in-scope) expressions.
    #[error("Expression {0:?} can't be introduced - it's already in scope")]
    ExpressionAlreadyInScope(Handle<crate::Expression>),
    #[error("Local variable {handle:?} '{name}' is invalid")]
    LocalVariable {
        handle: Handle<crate::LocalVariable>,
        name: String,
        source: LocalVariableError,
    },
    #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
    InvalidArgumentType { index: usize, name: String },
    #[error("The function's given return type cannot be returned from functions")]
    NonConstructibleReturnType,
    #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
    InvalidArgumentPointerSpace {
        index: usize,
        name: String,
        space: crate::AddressSpace,
    },
    /// A statement follows one that already finished the block.
    #[error("There are instructions after `return`/`break`/`continue`")]
    InstructionsAfterReturn,
    /// `break` was used where the current context does not permit it.
    #[error("The `break` is used outside of a `loop` or `switch` context")]
    BreakOutsideOfLoopOrSwitch,
    /// `continue` was used where the current context does not permit it.
    #[error("The `continue` is used outside of a `loop` context")]
    ContinueOutsideOfLoop,
    /// `return` was used where the current context does not permit it.
    #[error("The `return` is called within a `continuing` block")]
    InvalidReturnSpot,
    /// The returned value (or its absence) is inconsistent with the
    /// function's declared return type.
    #[error("The `return` value {0:?} does not match the function return value")]
    InvalidReturnType(Option<Handle<crate::Expression>>),
    /// Also reported for a loop's `break_if` condition that is not a boolean
    /// scalar.
    #[error("The `if` condition {0:?} is not a boolean scalar")]
    InvalidIfType(Handle<crate::Expression>),
    #[error("The `switch` value {0:?} is not an integer scalar")]
    InvalidSwitchType(Handle<crate::Expression>),
    #[error("Multiple `switch` cases for {0:?} are present")]
    ConflictingSwitchCase(crate::SwitchValue),
    /// A case value's signedness differs from that of the selector.
    #[error("The `switch` contains cases with conflicting types")]
    ConflictingCaseType,
    #[error("The `switch` is missing a `default` case")]
    MissingDefaultCase,
    #[error("Multiple `default` cases are present")]
    MultipleDefaultCases,
    // The variant name is spelled "FallTrough" (sic); it is part of the
    // public interface.
    #[error("The last `switch` case contains a `fallthrough`")]
    LastCaseFallTrough,
    /// The store destination is not reachable from a local variable, global
    /// variable, or function argument through access expressions, or its
    /// address space does not permit stores.
    #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
    InvalidStorePointer(Handle<crate::Expression>),
    #[error("The value {0:?} can not be stored")]
    InvalidStoreValue(Handle<crate::Expression>),
    #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")]
    InvalidStoreTypes {
        pointer: Handle<crate::Expression>,
        value: Handle<crate::Expression>,
    },
    #[error("Image store parameters are invalid")]
    InvalidImageStore(#[source] ExpressionError),
    #[error("Call to {function:?} is invalid")]
    InvalidCall {
        function: Handle<crate::Function>,
        #[source]
        error: CallError,
    },
    #[error("Atomic operation is invalid")]
    InvalidAtomic(#[from] AtomicError),
    #[error("Ray Query {0:?} is not a local variable")]
    InvalidRayQueryExpression(Handle<crate::Expression>),
    #[error("Acceleration structure {0:?} is not a matching expression")]
    InvalidAccelerationStructure(Handle<crate::Expression>),
    #[error("Ray descriptor {0:?} is not a matching expression")]
    InvalidRayDescriptor(Handle<crate::Expression>),
    #[error("Ray Query {0:?} does not have a matching type")]
    InvalidRayQueryType(Handle<crate::Type>),
    /// The validator's configured capabilities do not include the ones the
    /// statement requires.
    #[error("Shader requires capability {0:?}")]
    MissingCapability(super::Capabilities),
    #[error(
        "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
    )]
    NonUniformControlFlow(
        UniformityRequirements,
        Handle<crate::Expression>,
        UniformityDisruptor,
    ),
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
    PipelineInputRegularFunction { name: String },
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
    PipelineOutputRegularFunction,
    #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
    NonUniformWorkgroupUniformLoad(UniformityDisruptor),
    #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
    WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
    #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
    WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
    #[error("Subgroup operation is invalid")]
    InvalidSubgroup(#[from] SubgroupError),
}
176
bitflags::bitflags! {
    /// Set of control-flow statements that are permitted in the block
    /// currently being validated.
    #[repr(transparent)]
    #[derive(Clone, Copy)]
    struct ControlFlowAbility: u8 {
        /// A `return` statement is permitted.
        const RETURN = 0x1;
        /// A `break` statement is permitted.
        const BREAK = 0x2;
        /// A `continue` statement is permitted.
        const CONTINUE = 0x4;
    }
}
189
/// Summary produced by validating a block of statements.
struct BlockInfo {
    /// Shader stages the block remains valid for, after intersecting the
    /// restrictions of its statements.
    stages: super::ShaderStages,
    /// Whether the block ended with a statement that finishes it
    /// (e.g. `break`, `continue`, `return`, or `kill`).
    finished: bool,
}
194
/// Read-only view of everything block validation needs to look up, plus the
/// control-flow abilities that apply to the block being validated.
struct BlockContext<'a> {
    /// Control-flow statements permitted in the current block.
    abilities: ControlFlowAbility,
    /// Analysis info for the function being validated; indexed by expression
    /// handle to obtain the expression's resolved type.
    info: &'a FunctionInfo,
    /// Expression arena of the function being validated.
    expressions: &'a Arena<crate::Expression>,
    /// Type arena of the module.
    types: &'a UniqueArena<crate::Type>,
    /// Local variables of the function being validated.
    local_vars: &'a Arena<crate::LocalVariable>,
    /// Global variables of the module.
    global_vars: &'a Arena<crate::GlobalVariable>,
    /// Functions of the module.
    functions: &'a Arena<crate::Function>,
    /// Special types of the module (e.g. the ray descriptor type).
    special_types: &'a crate::SpecialTypes,
    /// Analysis info of other functions, indexed by function handle index
    /// when looking up a callee.
    prev_infos: &'a [FunctionInfo],
    /// Declared return type of the function being validated, if any.
    return_type: Option<Handle<crate::Type>>,
}
207
208impl<'a> BlockContext<'a> {
209 fn new(
210 fun: &'a crate::Function,
211 module: &'a crate::Module,
212 info: &'a FunctionInfo,
213 prev_infos: &'a [FunctionInfo],
214 ) -> Self {
215 Self {
216 abilities: ControlFlowAbility::RETURN,
217 info,
218 expressions: &fun.expressions,
219 types: &module.types,
220 local_vars: &fun.local_variables,
221 global_vars: &module.global_variables,
222 functions: &module.functions,
223 special_types: &module.special_types,
224 prev_infos,
225 return_type: fun.result.as_ref().map(|fr| fr.ty),
226 }
227 }
228
229 const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
230 BlockContext { abilities, ..*self }
231 }
232
233 fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
234 &self.expressions[handle]
235 }
236
237 fn resolve_type_impl(
238 &self,
239 handle: Handle<crate::Expression>,
240 valid_expressions: &BitSet,
241 ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
242 if handle.index() >= self.expressions.len() {
243 Err(ExpressionError::DoesntExist.with_span())
244 } else if !valid_expressions.contains(handle.index()) {
245 Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
246 } else {
247 Ok(self.info[handle].ty.inner_with(self.types))
248 }
249 }
250
251 fn resolve_type(
252 &self,
253 handle: Handle<crate::Expression>,
254 valid_expressions: &BitSet,
255 ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
256 self.resolve_type_impl(handle, valid_expressions)
257 .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
258 }
259
260 fn resolve_pointer_type(
261 &self,
262 handle: Handle<crate::Expression>,
263 ) -> Result<&crate::TypeInner, FunctionError> {
264 if handle.index() >= self.expressions.len() {
265 Err(FunctionError::Expression {
266 handle,
267 source: ExpressionError::DoesntExist,
268 })
269 } else {
270 Ok(self.info[handle].ty.inner_with(self.types))
271 }
272 }
273}
274
275impl super::Validator {
276 fn validate_call(
277 &mut self,
278 function: Handle<crate::Function>,
279 arguments: &[Handle<crate::Expression>],
280 result: Option<Handle<crate::Expression>>,
281 context: &BlockContext,
282 ) -> Result<super::ShaderStages, WithSpan<CallError>> {
283 let fun = &context.functions[function];
284 if fun.arguments.len() != arguments.len() {
285 return Err(CallError::ArgumentCount {
286 required: fun.arguments.len(),
287 seen: arguments.len(),
288 }
289 .with_span());
290 }
291 for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
292 let ty = context
293 .resolve_type_impl(expr, &self.valid_expression_set)
294 .map_err_inner(|source| {
295 CallError::Argument { index, source }
296 .with_span_handle(expr, context.expressions)
297 })?;
298 let arg_inner = &context.types[arg.ty].inner;
299 if !ty.equivalent(arg_inner, context.types) {
300 return Err(CallError::ArgumentType {
301 index,
302 required: arg.ty,
303 seen_expression: expr,
304 }
305 .with_span_handle(expr, context.expressions));
306 }
307 }
308
309 if let Some(expr) = result {
310 if self.valid_expression_set.insert(expr.index()) {
311 self.valid_expression_list.push(expr);
312 } else {
313 return Err(CallError::ResultAlreadyInScope(expr)
314 .with_span_handle(expr, context.expressions));
315 }
316 match context.expressions[expr] {
317 crate::Expression::CallResult(callee)
318 if fun.result.is_some() && callee == function => {}
319 _ => {
320 return Err(CallError::ExpressionMismatch(result)
321 .with_span_handle(expr, context.expressions))
322 }
323 }
324 } else if fun.result.is_some() {
325 return Err(CallError::ExpressionMismatch(result).with_span());
326 }
327
328 let callee_info = &context.prev_infos[function.index()];
329 Ok(callee_info.available_stages)
330 }
331
332 fn emit_expression(
333 &mut self,
334 handle: Handle<crate::Expression>,
335 context: &BlockContext,
336 ) -> Result<(), WithSpan<FunctionError>> {
337 if self.valid_expression_set.insert(handle.index()) {
338 self.valid_expression_list.push(handle);
339 Ok(())
340 } else {
341 Err(FunctionError::ExpressionAlreadyInScope(handle)
342 .with_span_handle(handle, context.expressions))
343 }
344 }
345
346 fn validate_atomic(
347 &mut self,
348 pointer: Handle<crate::Expression>,
349 fun: &crate::AtomicFunction,
350 value: Handle<crate::Expression>,
351 result: Handle<crate::Expression>,
352 context: &BlockContext,
353 ) -> Result<(), WithSpan<FunctionError>> {
354 let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
355 let ptr_scalar = match *pointer_inner {
356 crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
357 crate::TypeInner::Atomic(scalar) => scalar,
358 ref other => {
359 log::error!("Atomic pointer to type {:?}", other);
360 return Err(AtomicError::InvalidPointer(pointer)
361 .with_span_handle(pointer, context.expressions)
362 .into_other());
363 }
364 },
365 ref other => {
366 log::error!("Atomic on type {:?}", other);
367 return Err(AtomicError::InvalidPointer(pointer)
368 .with_span_handle(pointer, context.expressions)
369 .into_other());
370 }
371 };
372
373 let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
374 match *value_inner {
375 crate::TypeInner::Scalar(scalar) if scalar == ptr_scalar => {}
376 ref other => {
377 log::error!("Atomic operand type {:?}", other);
378 return Err(AtomicError::InvalidOperand(value)
379 .with_span_handle(value, context.expressions)
380 .into_other());
381 }
382 }
383
384 if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
385 if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner {
386 log::error!("Atomic exchange comparison has a different type from the value");
387 return Err(AtomicError::InvalidOperand(cmp)
388 .with_span_handle(cmp, context.expressions)
389 .into_other());
390 }
391 }
392
393 self.emit_expression(result, context)?;
394 match context.expressions[result] {
395 crate::Expression::AtomicResult { ty, comparison }
396 if {
397 let scalar_predicate =
398 |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(ptr_scalar);
399 match &context.types[ty].inner {
400 ty if !comparison => scalar_predicate(ty),
401 &crate::TypeInner::Struct { ref members, .. } if comparison => {
402 validate_atomic_compare_exchange_struct(
403 context.types,
404 members,
405 scalar_predicate,
406 )
407 }
408 _ => false,
409 }
410 } => {}
411 _ => {
412 return Err(AtomicError::ResultTypeMismatch(result)
413 .with_span_handle(result, context.expressions)
414 .into_other())
415 }
416 }
417 Ok(())
418 }
419 fn validate_subgroup_operation(
420 &mut self,
421 op: &crate::SubgroupOperation,
422 collective_op: &crate::CollectiveOperation,
423 argument: Handle<crate::Expression>,
424 result: Handle<crate::Expression>,
425 context: &BlockContext,
426 ) -> Result<(), WithSpan<FunctionError>> {
427 let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
428
429 let (is_scalar, scalar) = match *argument_inner {
430 crate::TypeInner::Scalar(scalar) => (true, scalar),
431 crate::TypeInner::Vector { scalar, .. } => (false, scalar),
432 _ => {
433 log::error!("Subgroup operand type {:?}", argument_inner);
434 return Err(SubgroupError::InvalidOperand(argument)
435 .with_span_handle(argument, context.expressions)
436 .into_other());
437 }
438 };
439
440 use crate::ScalarKind as sk;
441 use crate::SubgroupOperation as sg;
442 match (scalar.kind, *op) {
443 (sk::Bool, sg::All | sg::Any) if is_scalar => {}
444 (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
445 (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
446
447 (_, _) => {
448 log::error!("Subgroup operand type {:?}", argument_inner);
449 return Err(SubgroupError::InvalidOperand(argument)
450 .with_span_handle(argument, context.expressions)
451 .into_other());
452 }
453 };
454
455 use crate::CollectiveOperation as co;
456 match (*collective_op, *op) {
457 (
458 co::Reduce,
459 sg::All
460 | sg::Any
461 | sg::Add
462 | sg::Mul
463 | sg::Min
464 | sg::Max
465 | sg::And
466 | sg::Or
467 | sg::Xor,
468 ) => {}
469 (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
470
471 (_, _) => {
472 return Err(SubgroupError::UnknownOperation.with_span().into_other());
473 }
474 };
475
476 self.emit_expression(result, context)?;
477 match context.expressions[result] {
478 crate::Expression::SubgroupOperationResult { ty }
479 if { &context.types[ty].inner == argument_inner } => {}
480 _ => {
481 return Err(SubgroupError::ResultTypeMismatch(result)
482 .with_span_handle(result, context.expressions)
483 .into_other())
484 }
485 }
486 Ok(())
487 }
488 fn validate_subgroup_gather(
489 &mut self,
490 mode: &crate::GatherMode,
491 argument: Handle<crate::Expression>,
492 result: Handle<crate::Expression>,
493 context: &BlockContext,
494 ) -> Result<(), WithSpan<FunctionError>> {
495 match *mode {
496 crate::GatherMode::BroadcastFirst => {}
497 crate::GatherMode::Broadcast(index)
498 | crate::GatherMode::Shuffle(index)
499 | crate::GatherMode::ShuffleDown(index)
500 | crate::GatherMode::ShuffleUp(index)
501 | crate::GatherMode::ShuffleXor(index) => {
502 let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
503 match *index_ty {
504 crate::TypeInner::Scalar(crate::Scalar::U32) => {}
505 _ => {
506 log::error!(
507 "Subgroup gather index type {:?}, expected unsigned int",
508 index_ty
509 );
510 return Err(SubgroupError::InvalidOperand(argument)
511 .with_span_handle(index, context.expressions)
512 .into_other());
513 }
514 }
515 }
516 }
517 let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
518 if !matches!(*argument_inner,
519 crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
520 if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
521 ) {
522 log::error!("Subgroup gather operand type {:?}", argument_inner);
523 return Err(SubgroupError::InvalidOperand(argument)
524 .with_span_handle(argument, context.expressions)
525 .into_other());
526 }
527
528 self.emit_expression(result, context)?;
529 match context.expressions[result] {
530 crate::Expression::SubgroupOperationResult { ty }
531 if { &context.types[ty].inner == argument_inner } => {}
532 _ => {
533 return Err(SubgroupError::ResultTypeMismatch(result)
534 .with_span_handle(result, context.expressions)
535 .into_other())
536 }
537 }
538 Ok(())
539 }
540
541 fn validate_block_impl(
542 &mut self,
543 statements: &crate::Block,
544 context: &BlockContext,
545 ) -> Result<BlockInfo, WithSpan<FunctionError>> {
546 use crate::{AddressSpace, Statement as S, TypeInner as Ti};
547 let mut finished = false;
548 let mut stages = super::ShaderStages::all();
549 for (statement, &span) in statements.span_iter() {
550 if finished {
551 return Err(FunctionError::InstructionsAfterReturn
552 .with_span_static(span, "instructions after return"));
553 }
554 match *statement {
555 S::Emit(ref range) => {
556 for handle in range.clone() {
557 self.emit_expression(handle, context)?;
558 }
559 }
560 S::Block(ref block) => {
561 let info = self.validate_block(block, context)?;
562 stages &= info.stages;
563 finished = info.finished;
564 }
565 S::If {
566 condition,
567 ref accept,
568 ref reject,
569 } => {
570 match *context.resolve_type(condition, &self.valid_expression_set)? {
571 Ti::Scalar(crate::Scalar {
572 kind: crate::ScalarKind::Bool,
573 width: _,
574 }) => {}
575 _ => {
576 return Err(FunctionError::InvalidIfType(condition)
577 .with_span_handle(condition, context.expressions))
578 }
579 }
580 stages &= self.validate_block(accept, context)?.stages;
581 stages &= self.validate_block(reject, context)?.stages;
582 }
583 S::Switch {
584 selector,
585 ref cases,
586 } => {
587 let uint = match context
588 .resolve_type(selector, &self.valid_expression_set)?
589 .scalar_kind()
590 {
591 Some(crate::ScalarKind::Uint) => true,
592 Some(crate::ScalarKind::Sint) => false,
593 _ => {
594 return Err(FunctionError::InvalidSwitchType(selector)
595 .with_span_handle(selector, context.expressions))
596 }
597 };
598 self.switch_values.clear();
599 for case in cases {
600 match case.value {
601 crate::SwitchValue::I32(_) if !uint => {}
602 crate::SwitchValue::U32(_) if uint => {}
603 crate::SwitchValue::Default => {}
604 _ => {
605 return Err(FunctionError::ConflictingCaseType.with_span_static(
606 case.body
607 .span_iter()
608 .next()
609 .map_or(Default::default(), |(_, s)| *s),
610 "conflicting switch arm here",
611 ));
612 }
613 };
614 if !self.switch_values.insert(case.value) {
615 return Err(match case.value {
616 crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
617 .with_span_static(
618 case.body
619 .span_iter()
620 .next()
621 .map_or(Default::default(), |(_, s)| *s),
622 "duplicated switch arm here",
623 ),
624 _ => FunctionError::ConflictingSwitchCase(case.value)
625 .with_span_static(
626 case.body
627 .span_iter()
628 .next()
629 .map_or(Default::default(), |(_, s)| *s),
630 "conflicting switch arm here",
631 ),
632 });
633 }
634 }
635 if !self.switch_values.contains(&crate::SwitchValue::Default) {
636 return Err(FunctionError::MissingDefaultCase
637 .with_span_static(span, "missing default case"));
638 }
639 if let Some(case) = cases.last() {
640 if case.fall_through {
641 return Err(FunctionError::LastCaseFallTrough.with_span_static(
642 case.body
643 .span_iter()
644 .next()
645 .map_or(Default::default(), |(_, s)| *s),
646 "bad switch arm here",
647 ));
648 }
649 }
650 let pass_through_abilities = context.abilities
651 & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
652 let sub_context =
653 context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
654 for case in cases {
655 stages &= self.validate_block(&case.body, &sub_context)?.stages;
656 }
657 }
658 S::Loop {
659 ref body,
660 ref continuing,
661 break_if,
662 } => {
663 let base_expression_count = self.valid_expression_list.len();
666 let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
667 stages &= self
668 .validate_block_impl(
669 body,
670 &context.with_abilities(
671 pass_through_abilities
672 | ControlFlowAbility::BREAK
673 | ControlFlowAbility::CONTINUE,
674 ),
675 )?
676 .stages;
677 stages &= self
678 .validate_block_impl(
679 continuing,
680 &context.with_abilities(ControlFlowAbility::empty()),
681 )?
682 .stages;
683
684 if let Some(condition) = break_if {
685 match *context.resolve_type(condition, &self.valid_expression_set)? {
686 Ti::Scalar(crate::Scalar {
687 kind: crate::ScalarKind::Bool,
688 width: _,
689 }) => {}
690 _ => {
691 return Err(FunctionError::InvalidIfType(condition)
692 .with_span_handle(condition, context.expressions))
693 }
694 }
695 }
696
697 for handle in self.valid_expression_list.drain(base_expression_count..) {
698 self.valid_expression_set.remove(handle.index());
699 }
700 }
701 S::Break => {
702 if !context.abilities.contains(ControlFlowAbility::BREAK) {
703 return Err(FunctionError::BreakOutsideOfLoopOrSwitch
704 .with_span_static(span, "invalid break"));
705 }
706 finished = true;
707 }
708 S::Continue => {
709 if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
710 return Err(FunctionError::ContinueOutsideOfLoop
711 .with_span_static(span, "invalid continue"));
712 }
713 finished = true;
714 }
715 S::Return { value } => {
716 if !context.abilities.contains(ControlFlowAbility::RETURN) {
717 return Err(FunctionError::InvalidReturnSpot
718 .with_span_static(span, "invalid return"));
719 }
720 let value_ty = value
721 .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
722 .transpose()?;
723 let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
724 let okay = match (value_ty, expected_ty) {
727 (None, None) => true,
728 (Some(value_inner), Some(expected_inner)) => {
729 value_inner.equivalent(expected_inner, context.types)
730 }
731 (_, _) => false,
732 };
733
734 if !okay {
735 log::error!(
736 "Returning {:?} where {:?} is expected",
737 value_ty,
738 expected_ty
739 );
740 if let Some(handle) = value {
741 return Err(FunctionError::InvalidReturnType(value)
742 .with_span_handle(handle, context.expressions));
743 } else {
744 return Err(FunctionError::InvalidReturnType(value)
745 .with_span_static(span, "invalid return"));
746 }
747 }
748 finished = true;
749 }
750 S::Kill => {
751 stages &= super::ShaderStages::FRAGMENT;
752 finished = true;
753 }
754 S::Barrier(barrier) => {
755 stages &= super::ShaderStages::COMPUTE;
756 if barrier.contains(crate::Barrier::SUB_GROUP) {
757 if !self.capabilities.contains(
758 super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
759 ) {
760 return Err(FunctionError::MissingCapability(
761 super::Capabilities::SUBGROUP
762 | super::Capabilities::SUBGROUP_BARRIER,
763 )
764 .with_span_static(span, "missing capability for this operation"));
765 }
766 if !self
767 .subgroup_operations
768 .contains(super::SubgroupOperationSet::BASIC)
769 {
770 return Err(FunctionError::InvalidSubgroup(
771 SubgroupError::UnsupportedOperation(
772 super::SubgroupOperationSet::BASIC,
773 ),
774 )
775 .with_span_static(span, "support for this operation is not present"));
776 }
777 }
778 }
779 S::Store { pointer, value } => {
780 let mut current = pointer;
781 loop {
782 let _ = context
783 .resolve_pointer_type(current)
784 .map_err(|e| e.with_span())?;
785 match context.expressions[current] {
786 crate::Expression::Access { base, .. }
787 | crate::Expression::AccessIndex { base, .. } => current = base,
788 crate::Expression::LocalVariable(_)
789 | crate::Expression::GlobalVariable(_)
790 | crate::Expression::FunctionArgument(_) => break,
791 _ => {
792 return Err(FunctionError::InvalidStorePointer(current)
793 .with_span_handle(pointer, context.expressions))
794 }
795 }
796 }
797
798 let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
799 match *value_ty {
800 Ti::Image { .. } | Ti::Sampler { .. } => {
801 return Err(FunctionError::InvalidStoreValue(value)
802 .with_span_handle(value, context.expressions));
803 }
804 _ => {}
805 }
806
807 let pointer_ty = context
808 .resolve_pointer_type(pointer)
809 .map_err(|e| e.with_span())?;
810
811 let good = match *pointer_ty {
812 Ti::Pointer { base, space: _ } => match context.types[base].inner {
813 Ti::Atomic(scalar) => *value_ty == Ti::Scalar(scalar),
814 ref other => value_ty == other,
815 },
816 Ti::ValuePointer {
817 size: Some(size),
818 scalar,
819 space: _,
820 } => *value_ty == Ti::Vector { size, scalar },
821 Ti::ValuePointer {
822 size: None,
823 scalar,
824 space: _,
825 } => *value_ty == Ti::Scalar(scalar),
826 _ => false,
827 };
828 if !good {
829 return Err(FunctionError::InvalidStoreTypes { pointer, value }
830 .with_span()
831 .with_handle(pointer, context.expressions)
832 .with_handle(value, context.expressions));
833 }
834
835 if let Some(space) = pointer_ty.pointer_space() {
836 if !space.access().contains(crate::StorageAccess::STORE) {
837 return Err(FunctionError::InvalidStorePointer(pointer)
838 .with_span_static(
839 context.expressions.get_span(pointer),
840 "writing to this location is not permitted",
841 ));
842 }
843 }
844 }
845 S::ImageStore {
846 image,
847 coordinate,
848 array_index,
849 value,
850 } => {
851 let var = match *context.get_expression(image) {
854 crate::Expression::GlobalVariable(var_handle) => {
855 &context.global_vars[var_handle]
856 }
857 crate::Expression::Access { base, .. }
859 | crate::Expression::AccessIndex { base, .. } => {
860 match *context.get_expression(base) {
861 crate::Expression::GlobalVariable(var_handle) => {
862 &context.global_vars[var_handle]
863 }
864 _ => {
865 return Err(FunctionError::InvalidImageStore(
866 ExpressionError::ExpectedGlobalVariable,
867 )
868 .with_span_handle(image, context.expressions))
869 }
870 }
871 }
872 _ => {
873 return Err(FunctionError::InvalidImageStore(
874 ExpressionError::ExpectedGlobalVariable,
875 )
876 .with_span_handle(image, context.expressions))
877 }
878 };
879
880 let global_ty = match context.types[var.ty].inner {
882 Ti::BindingArray { base, .. } => &context.types[base].inner,
883 ref inner => inner,
884 };
885
886 let value_ty = match *global_ty {
887 Ti::Image {
888 class,
889 arrayed,
890 dim,
891 } => {
892 match context
893 .resolve_type(coordinate, &self.valid_expression_set)?
894 .image_storage_coordinates()
895 {
896 Some(coord_dim) if coord_dim == dim => {}
897 _ => {
898 return Err(FunctionError::InvalidImageStore(
899 ExpressionError::InvalidImageCoordinateType(
900 dim, coordinate,
901 ),
902 )
903 .with_span_handle(coordinate, context.expressions));
904 }
905 };
906 if arrayed != array_index.is_some() {
907 return Err(FunctionError::InvalidImageStore(
908 ExpressionError::InvalidImageArrayIndex,
909 )
910 .with_span_handle(coordinate, context.expressions));
911 }
912 if let Some(expr) = array_index {
913 match *context.resolve_type(expr, &self.valid_expression_set)? {
914 Ti::Scalar(crate::Scalar {
915 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
916 width: _,
917 }) => {}
918 _ => {
919 return Err(FunctionError::InvalidImageStore(
920 ExpressionError::InvalidImageArrayIndexType(expr),
921 )
922 .with_span_handle(expr, context.expressions));
923 }
924 }
925 }
926 match class {
927 crate::ImageClass::Storage { format, .. } => {
928 crate::TypeInner::Vector {
929 size: crate::VectorSize::Quad,
930 scalar: crate::Scalar {
931 kind: format.into(),
932 width: 4,
933 },
934 }
935 }
936 _ => {
937 return Err(FunctionError::InvalidImageStore(
938 ExpressionError::InvalidImageClass(class),
939 )
940 .with_span_handle(image, context.expressions));
941 }
942 }
943 }
944 _ => {
945 return Err(FunctionError::InvalidImageStore(
946 ExpressionError::ExpectedImageType(var.ty),
947 )
948 .with_span()
949 .with_handle(var.ty, context.types)
950 .with_handle(image, context.expressions))
951 }
952 };
953
954 if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
955 return Err(FunctionError::InvalidStoreValue(value)
956 .with_span_handle(value, context.expressions));
957 }
958 }
959 S::Call {
960 function,
961 ref arguments,
962 result,
963 } => match self.validate_call(function, arguments, result, context) {
964 Ok(callee_stages) => stages &= callee_stages,
965 Err(error) => {
966 return Err(error.and_then(|error| {
967 FunctionError::InvalidCall { function, error }
968 .with_span_static(span, "invalid function call")
969 }))
970 }
971 },
972 S::Atomic {
973 pointer,
974 ref fun,
975 value,
976 result,
977 } => {
978 self.validate_atomic(pointer, fun, value, result, context)?;
979 }
980 S::WorkGroupUniformLoad { pointer, result } => {
981 stages &= super::ShaderStages::COMPUTE;
982 let pointer_inner =
983 context.resolve_type(pointer, &self.valid_expression_set)?;
984 match *pointer_inner {
985 Ti::Pointer {
986 space: AddressSpace::WorkGroup,
987 ..
988 } => {}
989 Ti::ValuePointer {
990 space: AddressSpace::WorkGroup,
991 ..
992 } => {}
993 _ => {
994 return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
995 .with_span_static(span, "WorkGroupUniformLoad"))
996 }
997 }
998 self.emit_expression(result, context)?;
999 let ty = match &context.expressions[result] {
1000 &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
1001 _ => {
1002 return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
1003 result,
1004 )
1005 .with_span_static(span, "WorkGroupUniformLoad"));
1006 }
1007 };
1008 let expected_pointer_inner = Ti::Pointer {
1009 base: ty,
1010 space: AddressSpace::WorkGroup,
1011 };
1012 if !expected_pointer_inner.equivalent(pointer_inner, context.types) {
1013 return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1014 .with_span_static(span, "WorkGroupUniformLoad"));
1015 }
1016 }
1017 S::RayQuery { query, ref fun } => {
1018 let query_var = match *context.get_expression(query) {
1019 crate::Expression::LocalVariable(var) => &context.local_vars[var],
1020 ref other => {
1021 log::error!("Unexpected ray query expression {other:?}");
1022 return Err(FunctionError::InvalidRayQueryExpression(query)
1023 .with_span_static(span, "invalid query expression"));
1024 }
1025 };
1026 match context.types[query_var.ty].inner {
1027 Ti::RayQuery => {}
1028 ref other => {
1029 log::error!("Unexpected ray query type {other:?}");
1030 return Err(FunctionError::InvalidRayQueryType(query_var.ty)
1031 .with_span_static(span, "invalid query type"));
1032 }
1033 }
1034 match *fun {
1035 crate::RayQueryFunction::Initialize {
1036 acceleration_structure,
1037 descriptor,
1038 } => {
1039 match *context
1040 .resolve_type(acceleration_structure, &self.valid_expression_set)?
1041 {
1042 Ti::AccelerationStructure => {}
1043 _ => {
1044 return Err(FunctionError::InvalidAccelerationStructure(
1045 acceleration_structure,
1046 )
1047 .with_span_static(span, "invalid acceleration structure"))
1048 }
1049 }
1050 let desc_ty_given =
1051 context.resolve_type(descriptor, &self.valid_expression_set)?;
1052 let desc_ty_expected = context
1053 .special_types
1054 .ray_desc
1055 .map(|handle| &context.types[handle].inner);
1056 if Some(desc_ty_given) != desc_ty_expected {
1057 return Err(FunctionError::InvalidRayDescriptor(descriptor)
1058 .with_span_static(span, "invalid ray descriptor"));
1059 }
1060 }
1061 crate::RayQueryFunction::Proceed { result } => {
1062 self.emit_expression(result, context)?;
1063 }
1064 crate::RayQueryFunction::Terminate => {}
1065 }
1066 }
1067 S::SubgroupBallot { result, predicate } => {
1068 stages &= self.subgroup_stages;
1069 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1070 return Err(FunctionError::MissingCapability(
1071 super::Capabilities::SUBGROUP,
1072 )
1073 .with_span_static(span, "missing capability for this operation"));
1074 }
1075 if !self
1076 .subgroup_operations
1077 .contains(super::SubgroupOperationSet::BALLOT)
1078 {
1079 return Err(FunctionError::InvalidSubgroup(
1080 SubgroupError::UnsupportedOperation(
1081 super::SubgroupOperationSet::BALLOT,
1082 ),
1083 )
1084 .with_span_static(span, "support for this operation is not present"));
1085 }
1086 if let Some(predicate) = predicate {
1087 let predicate_inner =
1088 context.resolve_type(predicate, &self.valid_expression_set)?;
1089 if !matches!(
1090 *predicate_inner,
1091 crate::TypeInner::Scalar(crate::Scalar::BOOL,)
1092 ) {
1093 log::error!(
1094 "Subgroup ballot predicate type {:?} expected bool",
1095 predicate_inner
1096 );
1097 return Err(SubgroupError::InvalidOperand(predicate)
1098 .with_span_handle(predicate, context.expressions)
1099 .into_other());
1100 }
1101 }
1102 self.emit_expression(result, context)?;
1103 }
1104 S::SubgroupCollectiveOperation {
1105 ref op,
1106 ref collective_op,
1107 argument,
1108 result,
1109 } => {
1110 stages &= self.subgroup_stages;
1111 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1112 return Err(FunctionError::MissingCapability(
1113 super::Capabilities::SUBGROUP,
1114 )
1115 .with_span_static(span, "missing capability for this operation"));
1116 }
1117 let operation = op.required_operations();
1118 if !self.subgroup_operations.contains(operation) {
1119 return Err(FunctionError::InvalidSubgroup(
1120 SubgroupError::UnsupportedOperation(operation),
1121 )
1122 .with_span_static(span, "support for this operation is not present"));
1123 }
1124 self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
1125 }
1126 S::SubgroupGather {
1127 ref mode,
1128 argument,
1129 result,
1130 } => {
1131 stages &= self.subgroup_stages;
1132 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1133 return Err(FunctionError::MissingCapability(
1134 super::Capabilities::SUBGROUP,
1135 )
1136 .with_span_static(span, "missing capability for this operation"));
1137 }
1138 let operation = mode.required_operations();
1139 if !self.subgroup_operations.contains(operation) {
1140 return Err(FunctionError::InvalidSubgroup(
1141 SubgroupError::UnsupportedOperation(operation),
1142 )
1143 .with_span_static(span, "support for this operation is not present"));
1144 }
1145 self.validate_subgroup_gather(mode, argument, result, context)?;
1146 }
1147 }
1148 }
1149 Ok(BlockInfo { stages, finished })
1150 }
1151
1152 fn validate_block(
1153 &mut self,
1154 statements: &crate::Block,
1155 context: &BlockContext,
1156 ) -> Result<BlockInfo, WithSpan<FunctionError>> {
1157 let base_expression_count = self.valid_expression_list.len();
1158 let info = self.validate_block_impl(statements, context)?;
1159 for handle in self.valid_expression_list.drain(base_expression_count..) {
1160 self.valid_expression_set.remove(handle.index());
1161 }
1162 Ok(info)
1163 }
1164
1165 fn validate_local_var(
1166 &self,
1167 var: &crate::LocalVariable,
1168 gctx: crate::proc::GlobalCtx,
1169 fun_info: &FunctionInfo,
1170 local_expr_kind: &crate::proc::ExpressionKindTracker,
1171 ) -> Result<(), LocalVariableError> {
1172 log::debug!("var {:?}", var);
1173 let type_info = self
1174 .types
1175 .get(var.ty.index())
1176 .ok_or(LocalVariableError::InvalidType(var.ty))?;
1177 if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) {
1178 return Err(LocalVariableError::InvalidType(var.ty));
1179 }
1180
1181 if let Some(init) = var.init {
1182 let decl_ty = &gctx.types[var.ty].inner;
1183 let init_ty = fun_info[init].ty.inner_with(gctx.types);
1184 if !decl_ty.equivalent(init_ty, gctx.types) {
1185 return Err(LocalVariableError::InitializerType);
1186 }
1187
1188 if !local_expr_kind.is_const_or_override(init) {
1189 return Err(LocalVariableError::NonConstOrOverrideInitializer);
1190 }
1191 }
1192
1193 Ok(())
1194 }
1195
1196 pub(super) fn validate_function(
1197 &mut self,
1198 fun: &crate::Function,
1199 module: &crate::Module,
1200 mod_info: &ModuleInfo,
1201 entry_point: bool,
1202 global_expr_kind: &crate::proc::ExpressionKindTracker,
1203 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1204 let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
1205
1206 let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
1207
1208 for (var_handle, var) in fun.local_variables.iter() {
1209 self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
1210 .map_err(|source| {
1211 FunctionError::LocalVariable {
1212 handle: var_handle,
1213 name: var.name.clone().unwrap_or_default(),
1214 source,
1215 }
1216 .with_span_handle(var.ty, &module.types)
1217 .with_handle(var_handle, &fun.local_variables)
1218 })?;
1219 }
1220
1221 for (index, argument) in fun.arguments.iter().enumerate() {
1222 match module.types[argument.ty].inner.pointer_space() {
1223 Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
1224 Some(other) => {
1225 return Err(FunctionError::InvalidArgumentPointerSpace {
1226 index,
1227 name: argument.name.clone().unwrap_or_default(),
1228 space: other,
1229 }
1230 .with_span_handle(argument.ty, &module.types))
1231 }
1232 }
1233 if !self.types[argument.ty.index()]
1235 .flags
1236 .contains(super::TypeFlags::ARGUMENT)
1237 {
1238 return Err(FunctionError::InvalidArgumentType {
1239 index,
1240 name: argument.name.clone().unwrap_or_default(),
1241 }
1242 .with_span_handle(argument.ty, &module.types));
1243 }
1244
1245 if !entry_point && argument.binding.is_some() {
1246 return Err(FunctionError::PipelineInputRegularFunction {
1247 name: argument.name.clone().unwrap_or_default(),
1248 }
1249 .with_span_handle(argument.ty, &module.types));
1250 }
1251 }
1252
1253 if let Some(ref result) = fun.result {
1254 if !self.types[result.ty.index()]
1255 .flags
1256 .contains(super::TypeFlags::CONSTRUCTIBLE)
1257 {
1258 return Err(FunctionError::NonConstructibleReturnType
1259 .with_span_handle(result.ty, &module.types));
1260 }
1261
1262 if !entry_point && result.binding.is_some() {
1263 return Err(FunctionError::PipelineOutputRegularFunction
1264 .with_span_handle(result.ty, &module.types));
1265 }
1266 }
1267
1268 self.valid_expression_set.clear();
1269 self.valid_expression_list.clear();
1270 for (handle, expr) in fun.expressions.iter() {
1271 if expr.needs_pre_emit() {
1272 self.valid_expression_set.insert(handle.index());
1273 }
1274 if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1275 match self.validate_expression(
1276 handle,
1277 expr,
1278 fun,
1279 module,
1280 &info,
1281 mod_info,
1282 global_expr_kind,
1283 ) {
1284 Ok(stages) => info.available_stages &= stages,
1285 Err(source) => {
1286 return Err(FunctionError::Expression { handle, source }
1287 .with_span_handle(handle, &fun.expressions))
1288 }
1289 }
1290 }
1291 }
1292
1293 if self.flags.contains(super::ValidationFlags::BLOCKS) {
1294 let stages = self
1295 .validate_block(
1296 &fun.body,
1297 &BlockContext::new(fun, module, &info, &mod_info.functions),
1298 )?
1299 .stages;
1300 info.available_stages &= stages;
1301 }
1302 Ok(info)
1303 }
1304}