naga/valid/function.rs

1use crate::arena::Handle;
2use crate::arena::{Arena, UniqueArena};
3
4use super::validate_atomic_compare_exchange_struct;
5
6use super::{
7    analyzer::{UniformityDisruptor, UniformityRequirements},
8    ExpressionError, FunctionInfo, ModuleInfo,
9};
10use crate::span::WithSpan;
11use crate::span::{AddSpan as _, MapErrWithSpan as _};
12
13use bit_set::BitSet;
14
15#[derive(Clone, Debug, thiserror::Error)]
16#[cfg_attr(test, derive(PartialEq))]
17pub enum CallError {
18    #[error("Argument {index} expression is invalid")]
19    Argument {
20        index: usize,
21        source: ExpressionError,
22    },
23    #[error("Result expression {0:?} has already been introduced earlier")]
24    ResultAlreadyInScope(Handle<crate::Expression>),
25    #[error("Result value is invalid")]
26    ResultValue(#[source] ExpressionError),
27    #[error("Requires {required} arguments, but {seen} are provided")]
28    ArgumentCount { required: usize, seen: usize },
29    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
30    ArgumentType {
31        index: usize,
32        required: Handle<crate::Type>,
33        seen_expression: Handle<crate::Expression>,
34    },
35    #[error("The emitted expression doesn't match the call")]
36    ExpressionMismatch(Option<Handle<crate::Expression>>),
37}
38
39#[derive(Clone, Debug, thiserror::Error)]
40#[cfg_attr(test, derive(PartialEq))]
41pub enum AtomicError {
42    #[error("Pointer {0:?} to atomic is invalid.")]
43    InvalidPointer(Handle<crate::Expression>),
44    #[error("Operand {0:?} has invalid type.")]
45    InvalidOperand(Handle<crate::Expression>),
46    #[error("Result type for {0:?} doesn't match the statement")]
47    ResultTypeMismatch(Handle<crate::Expression>),
48}
49
50#[derive(Clone, Debug, thiserror::Error)]
51#[cfg_attr(test, derive(PartialEq))]
52pub enum SubgroupError {
53    #[error("Operand {0:?} has invalid type.")]
54    InvalidOperand(Handle<crate::Expression>),
55    #[error("Result type for {0:?} doesn't match the statement")]
56    ResultTypeMismatch(Handle<crate::Expression>),
57    #[error("Support for subgroup operation {0:?} is required")]
58    UnsupportedOperation(super::SubgroupOperationSet),
59    #[error("Unknown operation")]
60    UnknownOperation,
61}
62
63#[derive(Clone, Debug, thiserror::Error)]
64#[cfg_attr(test, derive(PartialEq))]
65pub enum LocalVariableError {
66    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
67    InvalidType(Handle<crate::Type>),
68    #[error("Initializer doesn't match the variable type")]
69    InitializerType,
70    #[error("Initializer is not a const or override expression")]
71    NonConstOrOverrideInitializer,
72}
73
74#[derive(Clone, Debug, thiserror::Error)]
75#[cfg_attr(test, derive(PartialEq))]
76pub enum FunctionError {
77    #[error("Expression {handle:?} is invalid")]
78    Expression {
79        handle: Handle<crate::Expression>,
80        source: ExpressionError,
81    },
82    #[error("Expression {0:?} can't be introduced - it's already in scope")]
83    ExpressionAlreadyInScope(Handle<crate::Expression>),
84    #[error("Local variable {handle:?} '{name}' is invalid")]
85    LocalVariable {
86        handle: Handle<crate::LocalVariable>,
87        name: String,
88        source: LocalVariableError,
89    },
90    #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
91    InvalidArgumentType { index: usize, name: String },
92    #[error("The function's given return type cannot be returned from functions")]
93    NonConstructibleReturnType,
94    #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
95    InvalidArgumentPointerSpace {
96        index: usize,
97        name: String,
98        space: crate::AddressSpace,
99    },
100    #[error("There are instructions after `return`/`break`/`continue`")]
101    InstructionsAfterReturn,
102    #[error("The `break` is used outside of a `loop` or `switch` context")]
103    BreakOutsideOfLoopOrSwitch,
104    #[error("The `continue` is used outside of a `loop` context")]
105    ContinueOutsideOfLoop,
106    #[error("The `return` is called within a `continuing` block")]
107    InvalidReturnSpot,
108    #[error("The `return` value {0:?} does not match the function return value")]
109    InvalidReturnType(Option<Handle<crate::Expression>>),
110    #[error("The `if` condition {0:?} is not a boolean scalar")]
111    InvalidIfType(Handle<crate::Expression>),
112    #[error("The `switch` value {0:?} is not an integer scalar")]
113    InvalidSwitchType(Handle<crate::Expression>),
114    #[error("Multiple `switch` cases for {0:?} are present")]
115    ConflictingSwitchCase(crate::SwitchValue),
116    #[error("The `switch` contains cases with conflicting types")]
117    ConflictingCaseType,
118    #[error("The `switch` is missing a `default` case")]
119    MissingDefaultCase,
120    #[error("Multiple `default` cases are present")]
121    MultipleDefaultCases,
122    #[error("The last `switch` case contains a `fallthrough`")]
123    LastCaseFallTrough,
124    #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
125    InvalidStorePointer(Handle<crate::Expression>),
126    #[error("The value {0:?} can not be stored")]
127    InvalidStoreValue(Handle<crate::Expression>),
128    #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")]
129    InvalidStoreTypes {
130        pointer: Handle<crate::Expression>,
131        value: Handle<crate::Expression>,
132    },
133    #[error("Image store parameters are invalid")]
134    InvalidImageStore(#[source] ExpressionError),
135    #[error("Call to {function:?} is invalid")]
136    InvalidCall {
137        function: Handle<crate::Function>,
138        #[source]
139        error: CallError,
140    },
141    #[error("Atomic operation is invalid")]
142    InvalidAtomic(#[from] AtomicError),
143    #[error("Ray Query {0:?} is not a local variable")]
144    InvalidRayQueryExpression(Handle<crate::Expression>),
145    #[error("Acceleration structure {0:?} is not a matching expression")]
146    InvalidAccelerationStructure(Handle<crate::Expression>),
147    #[error("Ray descriptor {0:?} is not a matching expression")]
148    InvalidRayDescriptor(Handle<crate::Expression>),
149    #[error("Ray Query {0:?} does not have a matching type")]
150    InvalidRayQueryType(Handle<crate::Type>),
151    #[error("Shader requires capability {0:?}")]
152    MissingCapability(super::Capabilities),
153    #[error(
154        "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
155    )]
156    NonUniformControlFlow(
157        UniformityRequirements,
158        Handle<crate::Expression>,
159        UniformityDisruptor,
160    ),
161    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
162    PipelineInputRegularFunction { name: String },
163    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
164    PipelineOutputRegularFunction,
165    #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
166    // The actual load statement will be "pointed to" by the span
167    NonUniformWorkgroupUniformLoad(UniformityDisruptor),
168    // This is only possible with a misbehaving frontend
169    #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
170    WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
171    #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
172    WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
173    #[error("Subgroup operation is invalid")]
174    InvalidSubgroup(#[from] SubgroupError),
175}
176
177bitflags::bitflags! {
178    #[repr(transparent)]
179    #[derive(Clone, Copy)]
180    struct ControlFlowAbility: u8 {
181        /// The control can return out of this block.
182        const RETURN = 0x1;
183        /// The control can break.
184        const BREAK = 0x2;
185        /// The control can continue.
186        const CONTINUE = 0x4;
187    }
188}
189
190struct BlockInfo {
191    stages: super::ShaderStages,
192    finished: bool,
193}
194
195struct BlockContext<'a> {
196    abilities: ControlFlowAbility,
197    info: &'a FunctionInfo,
198    expressions: &'a Arena<crate::Expression>,
199    types: &'a UniqueArena<crate::Type>,
200    local_vars: &'a Arena<crate::LocalVariable>,
201    global_vars: &'a Arena<crate::GlobalVariable>,
202    functions: &'a Arena<crate::Function>,
203    special_types: &'a crate::SpecialTypes,
204    prev_infos: &'a [FunctionInfo],
205    return_type: Option<Handle<crate::Type>>,
206}
207
208impl<'a> BlockContext<'a> {
209    fn new(
210        fun: &'a crate::Function,
211        module: &'a crate::Module,
212        info: &'a FunctionInfo,
213        prev_infos: &'a [FunctionInfo],
214    ) -> Self {
215        Self {
216            abilities: ControlFlowAbility::RETURN,
217            info,
218            expressions: &fun.expressions,
219            types: &module.types,
220            local_vars: &fun.local_variables,
221            global_vars: &module.global_variables,
222            functions: &module.functions,
223            special_types: &module.special_types,
224            prev_infos,
225            return_type: fun.result.as_ref().map(|fr| fr.ty),
226        }
227    }
228
229    const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
230        BlockContext { abilities, ..*self }
231    }
232
233    fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
234        &self.expressions[handle]
235    }
236
237    fn resolve_type_impl(
238        &self,
239        handle: Handle<crate::Expression>,
240        valid_expressions: &BitSet,
241    ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
242        if handle.index() >= self.expressions.len() {
243            Err(ExpressionError::DoesntExist.with_span())
244        } else if !valid_expressions.contains(handle.index()) {
245            Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
246        } else {
247            Ok(self.info[handle].ty.inner_with(self.types))
248        }
249    }
250
251    fn resolve_type(
252        &self,
253        handle: Handle<crate::Expression>,
254        valid_expressions: &BitSet,
255    ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
256        self.resolve_type_impl(handle, valid_expressions)
257            .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
258    }
259
260    fn resolve_pointer_type(
261        &self,
262        handle: Handle<crate::Expression>,
263    ) -> Result<&crate::TypeInner, FunctionError> {
264        if handle.index() >= self.expressions.len() {
265            Err(FunctionError::Expression {
266                handle,
267                source: ExpressionError::DoesntExist,
268            })
269        } else {
270            Ok(self.info[handle].ty.inner_with(self.types))
271        }
272    }
273}
274
275impl super::Validator {
276    fn validate_call(
277        &mut self,
278        function: Handle<crate::Function>,
279        arguments: &[Handle<crate::Expression>],
280        result: Option<Handle<crate::Expression>>,
281        context: &BlockContext,
282    ) -> Result<super::ShaderStages, WithSpan<CallError>> {
283        let fun = &context.functions[function];
284        if fun.arguments.len() != arguments.len() {
285            return Err(CallError::ArgumentCount {
286                required: fun.arguments.len(),
287                seen: arguments.len(),
288            }
289            .with_span());
290        }
291        for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
292            let ty = context
293                .resolve_type_impl(expr, &self.valid_expression_set)
294                .map_err_inner(|source| {
295                    CallError::Argument { index, source }
296                        .with_span_handle(expr, context.expressions)
297                })?;
298            let arg_inner = &context.types[arg.ty].inner;
299            if !ty.equivalent(arg_inner, context.types) {
300                return Err(CallError::ArgumentType {
301                    index,
302                    required: arg.ty,
303                    seen_expression: expr,
304                }
305                .with_span_handle(expr, context.expressions));
306            }
307        }
308
309        if let Some(expr) = result {
310            if self.valid_expression_set.insert(expr.index()) {
311                self.valid_expression_list.push(expr);
312            } else {
313                return Err(CallError::ResultAlreadyInScope(expr)
314                    .with_span_handle(expr, context.expressions));
315            }
316            match context.expressions[expr] {
317                crate::Expression::CallResult(callee)
318                    if fun.result.is_some() && callee == function => {}
319                _ => {
320                    return Err(CallError::ExpressionMismatch(result)
321                        .with_span_handle(expr, context.expressions))
322                }
323            }
324        } else if fun.result.is_some() {
325            return Err(CallError::ExpressionMismatch(result).with_span());
326        }
327
328        let callee_info = &context.prev_infos[function.index()];
329        Ok(callee_info.available_stages)
330    }
331
332    fn emit_expression(
333        &mut self,
334        handle: Handle<crate::Expression>,
335        context: &BlockContext,
336    ) -> Result<(), WithSpan<FunctionError>> {
337        if self.valid_expression_set.insert(handle.index()) {
338            self.valid_expression_list.push(handle);
339            Ok(())
340        } else {
341            Err(FunctionError::ExpressionAlreadyInScope(handle)
342                .with_span_handle(handle, context.expressions))
343        }
344    }
345
346    fn validate_atomic(
347        &mut self,
348        pointer: Handle<crate::Expression>,
349        fun: &crate::AtomicFunction,
350        value: Handle<crate::Expression>,
351        result: Handle<crate::Expression>,
352        context: &BlockContext,
353    ) -> Result<(), WithSpan<FunctionError>> {
354        let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
355        let ptr_scalar = match *pointer_inner {
356            crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
357                crate::TypeInner::Atomic(scalar) => scalar,
358                ref other => {
359                    log::error!("Atomic pointer to type {:?}", other);
360                    return Err(AtomicError::InvalidPointer(pointer)
361                        .with_span_handle(pointer, context.expressions)
362                        .into_other());
363                }
364            },
365            ref other => {
366                log::error!("Atomic on type {:?}", other);
367                return Err(AtomicError::InvalidPointer(pointer)
368                    .with_span_handle(pointer, context.expressions)
369                    .into_other());
370            }
371        };
372
373        let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
374        match *value_inner {
375            crate::TypeInner::Scalar(scalar) if scalar == ptr_scalar => {}
376            ref other => {
377                log::error!("Atomic operand type {:?}", other);
378                return Err(AtomicError::InvalidOperand(value)
379                    .with_span_handle(value, context.expressions)
380                    .into_other());
381            }
382        }
383
384        if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
385            if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner {
386                log::error!("Atomic exchange comparison has a different type from the value");
387                return Err(AtomicError::InvalidOperand(cmp)
388                    .with_span_handle(cmp, context.expressions)
389                    .into_other());
390            }
391        }
392
393        self.emit_expression(result, context)?;
394        match context.expressions[result] {
395            crate::Expression::AtomicResult { ty, comparison }
396                if {
397                    let scalar_predicate =
398                        |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(ptr_scalar);
399                    match &context.types[ty].inner {
400                        ty if !comparison => scalar_predicate(ty),
401                        &crate::TypeInner::Struct { ref members, .. } if comparison => {
402                            validate_atomic_compare_exchange_struct(
403                                context.types,
404                                members,
405                                scalar_predicate,
406                            )
407                        }
408                        _ => false,
409                    }
410                } => {}
411            _ => {
412                return Err(AtomicError::ResultTypeMismatch(result)
413                    .with_span_handle(result, context.expressions)
414                    .into_other())
415            }
416        }
417        Ok(())
418    }
419    fn validate_subgroup_operation(
420        &mut self,
421        op: &crate::SubgroupOperation,
422        collective_op: &crate::CollectiveOperation,
423        argument: Handle<crate::Expression>,
424        result: Handle<crate::Expression>,
425        context: &BlockContext,
426    ) -> Result<(), WithSpan<FunctionError>> {
427        let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
428
429        let (is_scalar, scalar) = match *argument_inner {
430            crate::TypeInner::Scalar(scalar) => (true, scalar),
431            crate::TypeInner::Vector { scalar, .. } => (false, scalar),
432            _ => {
433                log::error!("Subgroup operand type {:?}", argument_inner);
434                return Err(SubgroupError::InvalidOperand(argument)
435                    .with_span_handle(argument, context.expressions)
436                    .into_other());
437            }
438        };
439
440        use crate::ScalarKind as sk;
441        use crate::SubgroupOperation as sg;
442        match (scalar.kind, *op) {
443            (sk::Bool, sg::All | sg::Any) if is_scalar => {}
444            (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
445            (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
446
447            (_, _) => {
448                log::error!("Subgroup operand type {:?}", argument_inner);
449                return Err(SubgroupError::InvalidOperand(argument)
450                    .with_span_handle(argument, context.expressions)
451                    .into_other());
452            }
453        };
454
455        use crate::CollectiveOperation as co;
456        match (*collective_op, *op) {
457            (
458                co::Reduce,
459                sg::All
460                | sg::Any
461                | sg::Add
462                | sg::Mul
463                | sg::Min
464                | sg::Max
465                | sg::And
466                | sg::Or
467                | sg::Xor,
468            ) => {}
469            (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
470
471            (_, _) => {
472                return Err(SubgroupError::UnknownOperation.with_span().into_other());
473            }
474        };
475
476        self.emit_expression(result, context)?;
477        match context.expressions[result] {
478            crate::Expression::SubgroupOperationResult { ty }
479                if { &context.types[ty].inner == argument_inner } => {}
480            _ => {
481                return Err(SubgroupError::ResultTypeMismatch(result)
482                    .with_span_handle(result, context.expressions)
483                    .into_other())
484            }
485        }
486        Ok(())
487    }
488    fn validate_subgroup_gather(
489        &mut self,
490        mode: &crate::GatherMode,
491        argument: Handle<crate::Expression>,
492        result: Handle<crate::Expression>,
493        context: &BlockContext,
494    ) -> Result<(), WithSpan<FunctionError>> {
495        match *mode {
496            crate::GatherMode::BroadcastFirst => {}
497            crate::GatherMode::Broadcast(index)
498            | crate::GatherMode::Shuffle(index)
499            | crate::GatherMode::ShuffleDown(index)
500            | crate::GatherMode::ShuffleUp(index)
501            | crate::GatherMode::ShuffleXor(index) => {
502                let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
503                match *index_ty {
504                    crate::TypeInner::Scalar(crate::Scalar::U32) => {}
505                    _ => {
506                        log::error!(
507                            "Subgroup gather index type {:?}, expected unsigned int",
508                            index_ty
509                        );
510                        return Err(SubgroupError::InvalidOperand(argument)
511                            .with_span_handle(index, context.expressions)
512                            .into_other());
513                    }
514                }
515            }
516        }
517        let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
518        if !matches!(*argument_inner,
519            crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
520            if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
521        ) {
522            log::error!("Subgroup gather operand type {:?}", argument_inner);
523            return Err(SubgroupError::InvalidOperand(argument)
524                .with_span_handle(argument, context.expressions)
525                .into_other());
526        }
527
528        self.emit_expression(result, context)?;
529        match context.expressions[result] {
530            crate::Expression::SubgroupOperationResult { ty }
531                if { &context.types[ty].inner == argument_inner } => {}
532            _ => {
533                return Err(SubgroupError::ResultTypeMismatch(result)
534                    .with_span_handle(result, context.expressions)
535                    .into_other())
536            }
537        }
538        Ok(())
539    }
540
541    fn validate_block_impl(
542        &mut self,
543        statements: &crate::Block,
544        context: &BlockContext,
545    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
546        use crate::{AddressSpace, Statement as S, TypeInner as Ti};
547        let mut finished = false;
548        let mut stages = super::ShaderStages::all();
549        for (statement, &span) in statements.span_iter() {
550            if finished {
551                return Err(FunctionError::InstructionsAfterReturn
552                    .with_span_static(span, "instructions after return"));
553            }
554            match *statement {
555                S::Emit(ref range) => {
556                    for handle in range.clone() {
557                        self.emit_expression(handle, context)?;
558                    }
559                }
560                S::Block(ref block) => {
561                    let info = self.validate_block(block, context)?;
562                    stages &= info.stages;
563                    finished = info.finished;
564                }
565                S::If {
566                    condition,
567                    ref accept,
568                    ref reject,
569                } => {
570                    match *context.resolve_type(condition, &self.valid_expression_set)? {
571                        Ti::Scalar(crate::Scalar {
572                            kind: crate::ScalarKind::Bool,
573                            width: _,
574                        }) => {}
575                        _ => {
576                            return Err(FunctionError::InvalidIfType(condition)
577                                .with_span_handle(condition, context.expressions))
578                        }
579                    }
580                    stages &= self.validate_block(accept, context)?.stages;
581                    stages &= self.validate_block(reject, context)?.stages;
582                }
583                S::Switch {
584                    selector,
585                    ref cases,
586                } => {
587                    let uint = match context
588                        .resolve_type(selector, &self.valid_expression_set)?
589                        .scalar_kind()
590                    {
591                        Some(crate::ScalarKind::Uint) => true,
592                        Some(crate::ScalarKind::Sint) => false,
593                        _ => {
594                            return Err(FunctionError::InvalidSwitchType(selector)
595                                .with_span_handle(selector, context.expressions))
596                        }
597                    };
598                    self.switch_values.clear();
599                    for case in cases {
600                        match case.value {
601                            crate::SwitchValue::I32(_) if !uint => {}
602                            crate::SwitchValue::U32(_) if uint => {}
603                            crate::SwitchValue::Default => {}
604                            _ => {
605                                return Err(FunctionError::ConflictingCaseType.with_span_static(
606                                    case.body
607                                        .span_iter()
608                                        .next()
609                                        .map_or(Default::default(), |(_, s)| *s),
610                                    "conflicting switch arm here",
611                                ));
612                            }
613                        };
614                        if !self.switch_values.insert(case.value) {
615                            return Err(match case.value {
616                                crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
617                                    .with_span_static(
618                                        case.body
619                                            .span_iter()
620                                            .next()
621                                            .map_or(Default::default(), |(_, s)| *s),
622                                        "duplicated switch arm here",
623                                    ),
624                                _ => FunctionError::ConflictingSwitchCase(case.value)
625                                    .with_span_static(
626                                        case.body
627                                            .span_iter()
628                                            .next()
629                                            .map_or(Default::default(), |(_, s)| *s),
630                                        "conflicting switch arm here",
631                                    ),
632                            });
633                        }
634                    }
635                    if !self.switch_values.contains(&crate::SwitchValue::Default) {
636                        return Err(FunctionError::MissingDefaultCase
637                            .with_span_static(span, "missing default case"));
638                    }
639                    if let Some(case) = cases.last() {
640                        if case.fall_through {
641                            return Err(FunctionError::LastCaseFallTrough.with_span_static(
642                                case.body
643                                    .span_iter()
644                                    .next()
645                                    .map_or(Default::default(), |(_, s)| *s),
646                                "bad switch arm here",
647                            ));
648                        }
649                    }
650                    let pass_through_abilities = context.abilities
651                        & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
652                    let sub_context =
653                        context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
654                    for case in cases {
655                        stages &= self.validate_block(&case.body, &sub_context)?.stages;
656                    }
657                }
658                S::Loop {
659                    ref body,
660                    ref continuing,
661                    break_if,
662                } => {
663                    // special handling for block scoping is needed here,
664                    // because the continuing{} block inherits the scope
665                    let base_expression_count = self.valid_expression_list.len();
666                    let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
667                    stages &= self
668                        .validate_block_impl(
669                            body,
670                            &context.with_abilities(
671                                pass_through_abilities
672                                    | ControlFlowAbility::BREAK
673                                    | ControlFlowAbility::CONTINUE,
674                            ),
675                        )?
676                        .stages;
677                    stages &= self
678                        .validate_block_impl(
679                            continuing,
680                            &context.with_abilities(ControlFlowAbility::empty()),
681                        )?
682                        .stages;
683
684                    if let Some(condition) = break_if {
685                        match *context.resolve_type(condition, &self.valid_expression_set)? {
686                            Ti::Scalar(crate::Scalar {
687                                kind: crate::ScalarKind::Bool,
688                                width: _,
689                            }) => {}
690                            _ => {
691                                return Err(FunctionError::InvalidIfType(condition)
692                                    .with_span_handle(condition, context.expressions))
693                            }
694                        }
695                    }
696
697                    for handle in self.valid_expression_list.drain(base_expression_count..) {
698                        self.valid_expression_set.remove(handle.index());
699                    }
700                }
701                S::Break => {
702                    if !context.abilities.contains(ControlFlowAbility::BREAK) {
703                        return Err(FunctionError::BreakOutsideOfLoopOrSwitch
704                            .with_span_static(span, "invalid break"));
705                    }
706                    finished = true;
707                }
708                S::Continue => {
709                    if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
710                        return Err(FunctionError::ContinueOutsideOfLoop
711                            .with_span_static(span, "invalid continue"));
712                    }
713                    finished = true;
714                }
715                S::Return { value } => {
716                    if !context.abilities.contains(ControlFlowAbility::RETURN) {
717                        return Err(FunctionError::InvalidReturnSpot
718                            .with_span_static(span, "invalid return"));
719                    }
720                    let value_ty = value
721                        .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
722                        .transpose()?;
723                    let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
724                    // We can't return pointers, but it seems best not to embed that
725                    // assumption here, so use `TypeInner::equivalent` for comparison.
726                    let okay = match (value_ty, expected_ty) {
727                        (None, None) => true,
728                        (Some(value_inner), Some(expected_inner)) => {
729                            value_inner.equivalent(expected_inner, context.types)
730                        }
731                        (_, _) => false,
732                    };
733
734                    if !okay {
735                        log::error!(
736                            "Returning {:?} where {:?} is expected",
737                            value_ty,
738                            expected_ty
739                        );
740                        if let Some(handle) = value {
741                            return Err(FunctionError::InvalidReturnType(value)
742                                .with_span_handle(handle, context.expressions));
743                        } else {
744                            return Err(FunctionError::InvalidReturnType(value)
745                                .with_span_static(span, "invalid return"));
746                        }
747                    }
748                    finished = true;
749                }
750                S::Kill => {
751                    stages &= super::ShaderStages::FRAGMENT;
752                    finished = true;
753                }
754                S::Barrier(barrier) => {
755                    stages &= super::ShaderStages::COMPUTE;
756                    if barrier.contains(crate::Barrier::SUB_GROUP) {
757                        if !self.capabilities.contains(
758                            super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
759                        ) {
760                            return Err(FunctionError::MissingCapability(
761                                super::Capabilities::SUBGROUP
762                                    | super::Capabilities::SUBGROUP_BARRIER,
763                            )
764                            .with_span_static(span, "missing capability for this operation"));
765                        }
766                        if !self
767                            .subgroup_operations
768                            .contains(super::SubgroupOperationSet::BASIC)
769                        {
770                            return Err(FunctionError::InvalidSubgroup(
771                                SubgroupError::UnsupportedOperation(
772                                    super::SubgroupOperationSet::BASIC,
773                                ),
774                            )
775                            .with_span_static(span, "support for this operation is not present"));
776                        }
777                    }
778                }
779                S::Store { pointer, value } => {
780                    let mut current = pointer;
781                    loop {
782                        let _ = context
783                            .resolve_pointer_type(current)
784                            .map_err(|e| e.with_span())?;
785                        match context.expressions[current] {
786                            crate::Expression::Access { base, .. }
787                            | crate::Expression::AccessIndex { base, .. } => current = base,
788                            crate::Expression::LocalVariable(_)
789                            | crate::Expression::GlobalVariable(_)
790                            | crate::Expression::FunctionArgument(_) => break,
791                            _ => {
792                                return Err(FunctionError::InvalidStorePointer(current)
793                                    .with_span_handle(pointer, context.expressions))
794                            }
795                        }
796                    }
797
798                    let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
799                    match *value_ty {
800                        Ti::Image { .. } | Ti::Sampler { .. } => {
801                            return Err(FunctionError::InvalidStoreValue(value)
802                                .with_span_handle(value, context.expressions));
803                        }
804                        _ => {}
805                    }
806
807                    let pointer_ty = context
808                        .resolve_pointer_type(pointer)
809                        .map_err(|e| e.with_span())?;
810
811                    let good = match *pointer_ty {
812                        Ti::Pointer { base, space: _ } => match context.types[base].inner {
813                            Ti::Atomic(scalar) => *value_ty == Ti::Scalar(scalar),
814                            ref other => value_ty == other,
815                        },
816                        Ti::ValuePointer {
817                            size: Some(size),
818                            scalar,
819                            space: _,
820                        } => *value_ty == Ti::Vector { size, scalar },
821                        Ti::ValuePointer {
822                            size: None,
823                            scalar,
824                            space: _,
825                        } => *value_ty == Ti::Scalar(scalar),
826                        _ => false,
827                    };
828                    if !good {
829                        return Err(FunctionError::InvalidStoreTypes { pointer, value }
830                            .with_span()
831                            .with_handle(pointer, context.expressions)
832                            .with_handle(value, context.expressions));
833                    }
834
835                    if let Some(space) = pointer_ty.pointer_space() {
836                        if !space.access().contains(crate::StorageAccess::STORE) {
837                            return Err(FunctionError::InvalidStorePointer(pointer)
838                                .with_span_static(
839                                    context.expressions.get_span(pointer),
840                                    "writing to this location is not permitted",
841                                ));
842                        }
843                    }
844                }
845                S::ImageStore {
846                    image,
847                    coordinate,
848                    array_index,
849                    value,
850                } => {
851                    //Note: this code uses a lot of `FunctionError::InvalidImageStore`,
852                    // and could probably be refactored.
853                    let var = match *context.get_expression(image) {
854                        crate::Expression::GlobalVariable(var_handle) => {
855                            &context.global_vars[var_handle]
856                        }
857                        // We're looking at a binding index situation, so punch through the index and look at the global behind it.
858                        crate::Expression::Access { base, .. }
859                        | crate::Expression::AccessIndex { base, .. } => {
860                            match *context.get_expression(base) {
861                                crate::Expression::GlobalVariable(var_handle) => {
862                                    &context.global_vars[var_handle]
863                                }
864                                _ => {
865                                    return Err(FunctionError::InvalidImageStore(
866                                        ExpressionError::ExpectedGlobalVariable,
867                                    )
868                                    .with_span_handle(image, context.expressions))
869                                }
870                            }
871                        }
872                        _ => {
873                            return Err(FunctionError::InvalidImageStore(
874                                ExpressionError::ExpectedGlobalVariable,
875                            )
876                            .with_span_handle(image, context.expressions))
877                        }
878                    };
879
880                    // Punch through a binding array to get the underlying type
881                    let global_ty = match context.types[var.ty].inner {
882                        Ti::BindingArray { base, .. } => &context.types[base].inner,
883                        ref inner => inner,
884                    };
885
886                    let value_ty = match *global_ty {
887                        Ti::Image {
888                            class,
889                            arrayed,
890                            dim,
891                        } => {
892                            match context
893                                .resolve_type(coordinate, &self.valid_expression_set)?
894                                .image_storage_coordinates()
895                            {
896                                Some(coord_dim) if coord_dim == dim => {}
897                                _ => {
898                                    return Err(FunctionError::InvalidImageStore(
899                                        ExpressionError::InvalidImageCoordinateType(
900                                            dim, coordinate,
901                                        ),
902                                    )
903                                    .with_span_handle(coordinate, context.expressions));
904                                }
905                            };
906                            if arrayed != array_index.is_some() {
907                                return Err(FunctionError::InvalidImageStore(
908                                    ExpressionError::InvalidImageArrayIndex,
909                                )
910                                .with_span_handle(coordinate, context.expressions));
911                            }
912                            if let Some(expr) = array_index {
913                                match *context.resolve_type(expr, &self.valid_expression_set)? {
914                                    Ti::Scalar(crate::Scalar {
915                                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
916                                        width: _,
917                                    }) => {}
918                                    _ => {
919                                        return Err(FunctionError::InvalidImageStore(
920                                            ExpressionError::InvalidImageArrayIndexType(expr),
921                                        )
922                                        .with_span_handle(expr, context.expressions));
923                                    }
924                                }
925                            }
926                            match class {
927                                crate::ImageClass::Storage { format, .. } => {
928                                    crate::TypeInner::Vector {
929                                        size: crate::VectorSize::Quad,
930                                        scalar: crate::Scalar {
931                                            kind: format.into(),
932                                            width: 4,
933                                        },
934                                    }
935                                }
936                                _ => {
937                                    return Err(FunctionError::InvalidImageStore(
938                                        ExpressionError::InvalidImageClass(class),
939                                    )
940                                    .with_span_handle(image, context.expressions));
941                                }
942                            }
943                        }
944                        _ => {
945                            return Err(FunctionError::InvalidImageStore(
946                                ExpressionError::ExpectedImageType(var.ty),
947                            )
948                            .with_span()
949                            .with_handle(var.ty, context.types)
950                            .with_handle(image, context.expressions))
951                        }
952                    };
953
954                    if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
955                        return Err(FunctionError::InvalidStoreValue(value)
956                            .with_span_handle(value, context.expressions));
957                    }
958                }
959                S::Call {
960                    function,
961                    ref arguments,
962                    result,
963                } => match self.validate_call(function, arguments, result, context) {
964                    Ok(callee_stages) => stages &= callee_stages,
965                    Err(error) => {
966                        return Err(error.and_then(|error| {
967                            FunctionError::InvalidCall { function, error }
968                                .with_span_static(span, "invalid function call")
969                        }))
970                    }
971                },
972                S::Atomic {
973                    pointer,
974                    ref fun,
975                    value,
976                    result,
977                } => {
978                    self.validate_atomic(pointer, fun, value, result, context)?;
979                }
980                S::WorkGroupUniformLoad { pointer, result } => {
981                    stages &= super::ShaderStages::COMPUTE;
982                    let pointer_inner =
983                        context.resolve_type(pointer, &self.valid_expression_set)?;
984                    match *pointer_inner {
985                        Ti::Pointer {
986                            space: AddressSpace::WorkGroup,
987                            ..
988                        } => {}
989                        Ti::ValuePointer {
990                            space: AddressSpace::WorkGroup,
991                            ..
992                        } => {}
993                        _ => {
994                            return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
995                                .with_span_static(span, "WorkGroupUniformLoad"))
996                        }
997                    }
998                    self.emit_expression(result, context)?;
999                    let ty = match &context.expressions[result] {
1000                        &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
1001                        _ => {
1002                            return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
1003                                result,
1004                            )
1005                            .with_span_static(span, "WorkGroupUniformLoad"));
1006                        }
1007                    };
1008                    let expected_pointer_inner = Ti::Pointer {
1009                        base: ty,
1010                        space: AddressSpace::WorkGroup,
1011                    };
1012                    if !expected_pointer_inner.equivalent(pointer_inner, context.types) {
1013                        return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1014                            .with_span_static(span, "WorkGroupUniformLoad"));
1015                    }
1016                }
1017                S::RayQuery { query, ref fun } => {
1018                    let query_var = match *context.get_expression(query) {
1019                        crate::Expression::LocalVariable(var) => &context.local_vars[var],
1020                        ref other => {
1021                            log::error!("Unexpected ray query expression {other:?}");
1022                            return Err(FunctionError::InvalidRayQueryExpression(query)
1023                                .with_span_static(span, "invalid query expression"));
1024                        }
1025                    };
1026                    match context.types[query_var.ty].inner {
1027                        Ti::RayQuery => {}
1028                        ref other => {
1029                            log::error!("Unexpected ray query type {other:?}");
1030                            return Err(FunctionError::InvalidRayQueryType(query_var.ty)
1031                                .with_span_static(span, "invalid query type"));
1032                        }
1033                    }
1034                    match *fun {
1035                        crate::RayQueryFunction::Initialize {
1036                            acceleration_structure,
1037                            descriptor,
1038                        } => {
1039                            match *context
1040                                .resolve_type(acceleration_structure, &self.valid_expression_set)?
1041                            {
1042                                Ti::AccelerationStructure => {}
1043                                _ => {
1044                                    return Err(FunctionError::InvalidAccelerationStructure(
1045                                        acceleration_structure,
1046                                    )
1047                                    .with_span_static(span, "invalid acceleration structure"))
1048                                }
1049                            }
1050                            let desc_ty_given =
1051                                context.resolve_type(descriptor, &self.valid_expression_set)?;
1052                            let desc_ty_expected = context
1053                                .special_types
1054                                .ray_desc
1055                                .map(|handle| &context.types[handle].inner);
1056                            if Some(desc_ty_given) != desc_ty_expected {
1057                                return Err(FunctionError::InvalidRayDescriptor(descriptor)
1058                                    .with_span_static(span, "invalid ray descriptor"));
1059                            }
1060                        }
1061                        crate::RayQueryFunction::Proceed { result } => {
1062                            self.emit_expression(result, context)?;
1063                        }
1064                        crate::RayQueryFunction::Terminate => {}
1065                    }
1066                }
1067                S::SubgroupBallot { result, predicate } => {
1068                    stages &= self.subgroup_stages;
1069                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1070                        return Err(FunctionError::MissingCapability(
1071                            super::Capabilities::SUBGROUP,
1072                        )
1073                        .with_span_static(span, "missing capability for this operation"));
1074                    }
1075                    if !self
1076                        .subgroup_operations
1077                        .contains(super::SubgroupOperationSet::BALLOT)
1078                    {
1079                        return Err(FunctionError::InvalidSubgroup(
1080                            SubgroupError::UnsupportedOperation(
1081                                super::SubgroupOperationSet::BALLOT,
1082                            ),
1083                        )
1084                        .with_span_static(span, "support for this operation is not present"));
1085                    }
1086                    if let Some(predicate) = predicate {
1087                        let predicate_inner =
1088                            context.resolve_type(predicate, &self.valid_expression_set)?;
1089                        if !matches!(
1090                            *predicate_inner,
1091                            crate::TypeInner::Scalar(crate::Scalar::BOOL,)
1092                        ) {
1093                            log::error!(
1094                                "Subgroup ballot predicate type {:?} expected bool",
1095                                predicate_inner
1096                            );
1097                            return Err(SubgroupError::InvalidOperand(predicate)
1098                                .with_span_handle(predicate, context.expressions)
1099                                .into_other());
1100                        }
1101                    }
1102                    self.emit_expression(result, context)?;
1103                }
1104                S::SubgroupCollectiveOperation {
1105                    ref op,
1106                    ref collective_op,
1107                    argument,
1108                    result,
1109                } => {
1110                    stages &= self.subgroup_stages;
1111                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1112                        return Err(FunctionError::MissingCapability(
1113                            super::Capabilities::SUBGROUP,
1114                        )
1115                        .with_span_static(span, "missing capability for this operation"));
1116                    }
1117                    let operation = op.required_operations();
1118                    if !self.subgroup_operations.contains(operation) {
1119                        return Err(FunctionError::InvalidSubgroup(
1120                            SubgroupError::UnsupportedOperation(operation),
1121                        )
1122                        .with_span_static(span, "support for this operation is not present"));
1123                    }
1124                    self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
1125                }
1126                S::SubgroupGather {
1127                    ref mode,
1128                    argument,
1129                    result,
1130                } => {
1131                    stages &= self.subgroup_stages;
1132                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1133                        return Err(FunctionError::MissingCapability(
1134                            super::Capabilities::SUBGROUP,
1135                        )
1136                        .with_span_static(span, "missing capability for this operation"));
1137                    }
1138                    let operation = mode.required_operations();
1139                    if !self.subgroup_operations.contains(operation) {
1140                        return Err(FunctionError::InvalidSubgroup(
1141                            SubgroupError::UnsupportedOperation(operation),
1142                        )
1143                        .with_span_static(span, "support for this operation is not present"));
1144                    }
1145                    self.validate_subgroup_gather(mode, argument, result, context)?;
1146                }
1147            }
1148        }
1149        Ok(BlockInfo { stages, finished })
1150    }
1151
1152    fn validate_block(
1153        &mut self,
1154        statements: &crate::Block,
1155        context: &BlockContext,
1156    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
1157        let base_expression_count = self.valid_expression_list.len();
1158        let info = self.validate_block_impl(statements, context)?;
1159        for handle in self.valid_expression_list.drain(base_expression_count..) {
1160            self.valid_expression_set.remove(handle.index());
1161        }
1162        Ok(info)
1163    }
1164
1165    fn validate_local_var(
1166        &self,
1167        var: &crate::LocalVariable,
1168        gctx: crate::proc::GlobalCtx,
1169        fun_info: &FunctionInfo,
1170        local_expr_kind: &crate::proc::ExpressionKindTracker,
1171    ) -> Result<(), LocalVariableError> {
1172        log::debug!("var {:?}", var);
1173        let type_info = self
1174            .types
1175            .get(var.ty.index())
1176            .ok_or(LocalVariableError::InvalidType(var.ty))?;
1177        if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) {
1178            return Err(LocalVariableError::InvalidType(var.ty));
1179        }
1180
1181        if let Some(init) = var.init {
1182            let decl_ty = &gctx.types[var.ty].inner;
1183            let init_ty = fun_info[init].ty.inner_with(gctx.types);
1184            if !decl_ty.equivalent(init_ty, gctx.types) {
1185                return Err(LocalVariableError::InitializerType);
1186            }
1187
1188            if !local_expr_kind.is_const_or_override(init) {
1189                return Err(LocalVariableError::NonConstOrOverrideInitializer);
1190            }
1191        }
1192
1193        Ok(())
1194    }
1195
1196    pub(super) fn validate_function(
1197        &mut self,
1198        fun: &crate::Function,
1199        module: &crate::Module,
1200        mod_info: &ModuleInfo,
1201        entry_point: bool,
1202        global_expr_kind: &crate::proc::ExpressionKindTracker,
1203    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1204        let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
1205
1206        let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
1207
1208        for (var_handle, var) in fun.local_variables.iter() {
1209            self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
1210                .map_err(|source| {
1211                    FunctionError::LocalVariable {
1212                        handle: var_handle,
1213                        name: var.name.clone().unwrap_or_default(),
1214                        source,
1215                    }
1216                    .with_span_handle(var.ty, &module.types)
1217                    .with_handle(var_handle, &fun.local_variables)
1218                })?;
1219        }
1220
1221        for (index, argument) in fun.arguments.iter().enumerate() {
1222            match module.types[argument.ty].inner.pointer_space() {
1223                Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
1224                Some(other) => {
1225                    return Err(FunctionError::InvalidArgumentPointerSpace {
1226                        index,
1227                        name: argument.name.clone().unwrap_or_default(),
1228                        space: other,
1229                    }
1230                    .with_span_handle(argument.ty, &module.types))
1231                }
1232            }
1233            // Check for the least informative error last.
1234            if !self.types[argument.ty.index()]
1235                .flags
1236                .contains(super::TypeFlags::ARGUMENT)
1237            {
1238                return Err(FunctionError::InvalidArgumentType {
1239                    index,
1240                    name: argument.name.clone().unwrap_or_default(),
1241                }
1242                .with_span_handle(argument.ty, &module.types));
1243            }
1244
1245            if !entry_point && argument.binding.is_some() {
1246                return Err(FunctionError::PipelineInputRegularFunction {
1247                    name: argument.name.clone().unwrap_or_default(),
1248                }
1249                .with_span_handle(argument.ty, &module.types));
1250            }
1251        }
1252
1253        if let Some(ref result) = fun.result {
1254            if !self.types[result.ty.index()]
1255                .flags
1256                .contains(super::TypeFlags::CONSTRUCTIBLE)
1257            {
1258                return Err(FunctionError::NonConstructibleReturnType
1259                    .with_span_handle(result.ty, &module.types));
1260            }
1261
1262            if !entry_point && result.binding.is_some() {
1263                return Err(FunctionError::PipelineOutputRegularFunction
1264                    .with_span_handle(result.ty, &module.types));
1265            }
1266        }
1267
1268        self.valid_expression_set.clear();
1269        self.valid_expression_list.clear();
1270        for (handle, expr) in fun.expressions.iter() {
1271            if expr.needs_pre_emit() {
1272                self.valid_expression_set.insert(handle.index());
1273            }
1274            if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1275                match self.validate_expression(
1276                    handle,
1277                    expr,
1278                    fun,
1279                    module,
1280                    &info,
1281                    mod_info,
1282                    global_expr_kind,
1283                ) {
1284                    Ok(stages) => info.available_stages &= stages,
1285                    Err(source) => {
1286                        return Err(FunctionError::Expression { handle, source }
1287                            .with_span_handle(handle, &fun.expressions))
1288                    }
1289                }
1290            }
1291        }
1292
1293        if self.flags.contains(super::ValidationFlags::BLOCKS) {
1294            let stages = self
1295                .validate_block(
1296                    &fun.body,
1297                    &BlockContext::new(fun, module, &info, &mod_info.functions),
1298                )?
1299                .stages;
1300            info.available_stages &= stages;
1301        }
1302        Ok(info)
1303    }
1304}