// naga/valid/expression.rs

1use super::{
2    compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ModuleInfo,
3    ShaderStages, TypeFlags,
4};
5use crate::arena::UniqueArena;
6
7use crate::{
8    arena::Handle,
9    proc::{IndexableLengthError, ResolveError},
10};
11
12#[derive(Clone, Debug, thiserror::Error)]
13#[cfg_attr(test, derive(PartialEq))]
14pub enum ExpressionError {
15    #[error("Doesn't exist")]
16    DoesntExist,
17    #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
18    NotInScope,
19    #[error("Base type {0:?} is not compatible with this expression")]
20    InvalidBaseType(Handle<crate::Expression>),
21    #[error("Accessing with index {0:?} can't be done")]
22    InvalidIndexType(Handle<crate::Expression>),
23    #[error("Accessing {0:?} via a negative index is invalid")]
24    NegativeIndex(Handle<crate::Expression>),
25    #[error("Accessing index {1} is out of {0:?} bounds")]
26    IndexOutOfBounds(Handle<crate::Expression>, u32),
27    #[error("The expression {0:?} may only be indexed by a constant")]
28    IndexMustBeConstant(Handle<crate::Expression>),
29    #[error("Function argument {0:?} doesn't exist")]
30    FunctionArgumentDoesntExist(u32),
31    #[error("Loading of {0:?} can't be done")]
32    InvalidPointerType(Handle<crate::Expression>),
33    #[error("Array length of {0:?} can't be done")]
34    InvalidArrayType(Handle<crate::Expression>),
35    #[error("Get intersection of {0:?} can't be done")]
36    InvalidRayQueryType(Handle<crate::Expression>),
37    #[error("Splatting {0:?} can't be done")]
38    InvalidSplatType(Handle<crate::Expression>),
39    #[error("Swizzling {0:?} can't be done")]
40    InvalidVectorType(Handle<crate::Expression>),
41    #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
42    InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
43    #[error(transparent)]
44    Compose(#[from] super::ComposeError),
45    #[error(transparent)]
46    IndexableLength(#[from] IndexableLengthError),
47    #[error("Operation {0:?} can't work with {1:?}")]
48    InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
49    #[error("Operation {0:?} can't work with {1:?} and {2:?}")]
50    InvalidBinaryOperandTypes(
51        crate::BinaryOperator,
52        Handle<crate::Expression>,
53        Handle<crate::Expression>,
54    ),
55    #[error("Selecting is not possible")]
56    InvalidSelectTypes,
57    #[error("Relational argument {0:?} is not a boolean vector")]
58    InvalidBooleanVector(Handle<crate::Expression>),
59    #[error("Relational argument {0:?} is not a float")]
60    InvalidFloatArgument(Handle<crate::Expression>),
61    #[error("Type resolution failed")]
62    Type(#[from] ResolveError),
63    #[error("Not a global variable")]
64    ExpectedGlobalVariable,
65    #[error("Not a global variable or a function argument")]
66    ExpectedGlobalOrArgument,
67    #[error("Needs to be an binding array instead of {0:?}")]
68    ExpectedBindingArrayType(Handle<crate::Type>),
69    #[error("Needs to be an image instead of {0:?}")]
70    ExpectedImageType(Handle<crate::Type>),
71    #[error("Needs to be an image instead of {0:?}")]
72    ExpectedSamplerType(Handle<crate::Type>),
73    #[error("Unable to operate on image class {0:?}")]
74    InvalidImageClass(crate::ImageClass),
75    #[error("Derivatives can only be taken from scalar and vector floats")]
76    InvalidDerivative,
77    #[error("Image array index parameter is misplaced")]
78    InvalidImageArrayIndex,
79    #[error("Inappropriate sample or level-of-detail index for texel access")]
80    InvalidImageOtherIndex,
81    #[error("Image array index type of {0:?} is not an integer scalar")]
82    InvalidImageArrayIndexType(Handle<crate::Expression>),
83    #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
84    InvalidImageOtherIndexType(Handle<crate::Expression>),
85    #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
86    InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
87    #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
88    ComparisonSamplingMismatch {
89        image: crate::ImageClass,
90        sampler: bool,
91        has_ref: bool,
92    },
93    #[error("Sample offset must be a const-expression")]
94    InvalidSampleOffsetExprType,
95    #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
96    InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
97    #[error("Depth reference {0:?} is not a scalar float")]
98    InvalidDepthReference(Handle<crate::Expression>),
99    #[error("Depth sample level can only be Auto or Zero")]
100    InvalidDepthSampleLevel,
101    #[error("Gather level can only be Zero")]
102    InvalidGatherLevel,
103    #[error("Gather component {0:?} doesn't exist in the image")]
104    InvalidGatherComponent(crate::SwizzleComponent),
105    #[error("Gather can't be done for image dimension {0:?}")]
106    InvalidGatherDimension(crate::ImageDimension),
107    #[error("Sample level (exact) type {0:?} is not a scalar float")]
108    InvalidSampleLevelExactType(Handle<crate::Expression>),
109    #[error("Sample level (bias) type {0:?} is not a scalar float")]
110    InvalidSampleLevelBiasType(Handle<crate::Expression>),
111    #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
112    InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
113    #[error("Unable to cast")]
114    InvalidCastArgument,
115    #[error("Invalid argument count for {0:?}")]
116    WrongArgumentCount(crate::MathFunction),
117    #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
118    InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
119    #[error("Atomic result type can't be {0:?}")]
120    InvalidAtomicResultType(Handle<crate::Type>),
121    #[error(
122        "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
123    )]
124    InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
125    #[error("Shader requires capability {0:?}")]
126    MissingCapabilities(super::Capabilities),
127    #[error(transparent)]
128    Literal(#[from] LiteralError),
129    #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
130    UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
131}
132
133#[derive(Clone, Debug, thiserror::Error)]
134#[cfg_attr(test, derive(PartialEq))]
135pub enum ConstExpressionError {
136    #[error("The expression is not a constant or override expression")]
137    NonConstOrOverride,
138    #[error("The expression is not a fully evaluated constant expression")]
139    NonFullyEvaluatedConst,
140    #[error(transparent)]
141    Compose(#[from] super::ComposeError),
142    #[error("Splatting {0:?} can't be done")]
143    InvalidSplatType(Handle<crate::Expression>),
144    #[error("Type resolution failed")]
145    Type(#[from] ResolveError),
146    #[error(transparent)]
147    Literal(#[from] LiteralError),
148    #[error(transparent)]
149    Width(#[from] super::r#type::WidthError),
150}
151
152#[derive(Clone, Debug, thiserror::Error)]
153#[cfg_attr(test, derive(PartialEq))]
154pub enum LiteralError {
155    #[error("Float literal is NaN")]
156    NaN,
157    #[error("Float literal is infinite")]
158    Infinity,
159    #[error(transparent)]
160    Width(#[from] super::r#type::WidthError),
161}
162
163struct ExpressionTypeResolver<'a> {
164    root: Handle<crate::Expression>,
165    types: &'a UniqueArena<crate::Type>,
166    info: &'a FunctionInfo,
167}
168
169impl<'a> std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> {
170    type Output = crate::TypeInner;
171
172    #[allow(clippy::panic)]
173    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
174        if handle < self.root {
175            self.info[handle].ty.inner_with(self.types)
176        } else {
177            // `Validator::validate_module_handles` should have caught this.
178            panic!(
179                "Depends on {:?}, which has not been processed yet",
180                self.root
181            )
182        }
183    }
184}
185
186impl super::Validator {
187    pub(super) fn validate_const_expression(
188        &self,
189        handle: Handle<crate::Expression>,
190        gctx: crate::proc::GlobalCtx,
191        mod_info: &ModuleInfo,
192        global_expr_kind: &crate::proc::ExpressionKindTracker,
193    ) -> Result<(), ConstExpressionError> {
194        use crate::Expression as E;
195
196        if !global_expr_kind.is_const_or_override(handle) {
197            return Err(ConstExpressionError::NonConstOrOverride);
198        }
199
200        match gctx.global_expressions[handle] {
201            E::Literal(literal) => {
202                self.validate_literal(literal)?;
203            }
204            E::Constant(_) | E::ZeroValue(_) => {}
205            E::Compose { ref components, ty } => {
206                validate_compose(
207                    ty,
208                    gctx,
209                    components.iter().map(|&handle| mod_info[handle].clone()),
210                )?;
211            }
212            E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
213                crate::TypeInner::Scalar { .. } => {}
214                _ => return Err(ConstExpressionError::InvalidSplatType(value)),
215            },
216            _ if global_expr_kind.is_const(handle) || !self.allow_overrides => {
217                return Err(ConstExpressionError::NonFullyEvaluatedConst)
218            }
219            // the constant evaluator will report errors about override-expressions
220            _ => {}
221        }
222
223        Ok(())
224    }
225
226    #[allow(clippy::too_many_arguments)]
227    pub(super) fn validate_expression(
228        &self,
229        root: Handle<crate::Expression>,
230        expression: &crate::Expression,
231        function: &crate::Function,
232        module: &crate::Module,
233        info: &FunctionInfo,
234        mod_info: &ModuleInfo,
235        global_expr_kind: &crate::proc::ExpressionKindTracker,
236    ) -> Result<ShaderStages, ExpressionError> {
237        use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
238
239        let resolver = ExpressionTypeResolver {
240            root,
241            types: &module.types,
242            info,
243        };
244
245        let stages = match *expression {
246            E::Access { base, index } => {
247                let base_type = &resolver[base];
248                // See the documentation for `Expression::Access`.
249                let dynamic_indexing_restricted = match *base_type {
250                    Ti::Vector { .. } => false,
251                    Ti::Matrix { .. } | Ti::Array { .. } => true,
252                    Ti::Pointer { .. }
253                    | Ti::ValuePointer { size: Some(_), .. }
254                    | Ti::BindingArray { .. } => false,
255                    ref other => {
256                        log::error!("Indexing of {:?}", other);
257                        return Err(ExpressionError::InvalidBaseType(base));
258                    }
259                };
260                match resolver[index] {
261                    //TODO: only allow one of these
262                    Ti::Scalar(Sc {
263                        kind: Sk::Sint | Sk::Uint,
264                        ..
265                    }) => {}
266                    ref other => {
267                        log::error!("Indexing by {:?}", other);
268                        return Err(ExpressionError::InvalidIndexType(index));
269                    }
270                }
271                if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index() {
272                    return Err(ExpressionError::IndexMustBeConstant(base));
273                }
274
275                // If we know both the length and the index, we can do the
276                // bounds check now.
277                if let crate::proc::IndexableLength::Known(known_length) =
278                    base_type.indexable_length(module)?
279                {
280                    match module
281                        .to_ctx()
282                        .eval_expr_to_u32_from(index, &function.expressions)
283                    {
284                        Ok(value) => {
285                            if value >= known_length {
286                                return Err(ExpressionError::IndexOutOfBounds(base, value));
287                            }
288                        }
289                        Err(crate::proc::U32EvalError::Negative) => {
290                            return Err(ExpressionError::NegativeIndex(base))
291                        }
292                        Err(crate::proc::U32EvalError::NonConst) => {}
293                    }
294                }
295
296                ShaderStages::all()
297            }
298            E::AccessIndex { base, index } => {
299                fn resolve_index_limit(
300                    module: &crate::Module,
301                    top: Handle<crate::Expression>,
302                    ty: &crate::TypeInner,
303                    top_level: bool,
304                ) -> Result<u32, ExpressionError> {
305                    let limit = match *ty {
306                        Ti::Vector { size, .. }
307                        | Ti::ValuePointer {
308                            size: Some(size), ..
309                        } => size as u32,
310                        Ti::Matrix { columns, .. } => columns as u32,
311                        Ti::Array {
312                            size: crate::ArraySize::Constant(len),
313                            ..
314                        } => len.get(),
315                        Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
316                        Ti::Pointer { base, .. } if top_level => {
317                            resolve_index_limit(module, top, &module.types[base].inner, false)?
318                        }
319                        Ti::Struct { ref members, .. } => members.len() as u32,
320                        ref other => {
321                            log::error!("Indexing of {:?}", other);
322                            return Err(ExpressionError::InvalidBaseType(top));
323                        }
324                    };
325                    Ok(limit)
326                }
327
328                let limit = resolve_index_limit(module, base, &resolver[base], true)?;
329                if index >= limit {
330                    return Err(ExpressionError::IndexOutOfBounds(base, limit));
331                }
332                ShaderStages::all()
333            }
334            E::Splat { size: _, value } => match resolver[value] {
335                Ti::Scalar { .. } => ShaderStages::all(),
336                ref other => {
337                    log::error!("Splat scalar type {:?}", other);
338                    return Err(ExpressionError::InvalidSplatType(value));
339                }
340            },
341            E::Swizzle {
342                size,
343                vector,
344                pattern,
345            } => {
346                let vec_size = match resolver[vector] {
347                    Ti::Vector { size: vec_size, .. } => vec_size,
348                    ref other => {
349                        log::error!("Swizzle vector type {:?}", other);
350                        return Err(ExpressionError::InvalidVectorType(vector));
351                    }
352                };
353                for &sc in pattern[..size as usize].iter() {
354                    if sc as u8 >= vec_size as u8 {
355                        return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
356                    }
357                }
358                ShaderStages::all()
359            }
360            E::Literal(literal) => {
361                self.validate_literal(literal)?;
362                ShaderStages::all()
363            }
364            E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
365            E::Compose { ref components, ty } => {
366                validate_compose(
367                    ty,
368                    module.to_ctx(),
369                    components.iter().map(|&handle| info[handle].ty.clone()),
370                )?;
371                ShaderStages::all()
372            }
373            E::FunctionArgument(index) => {
374                if index >= function.arguments.len() as u32 {
375                    return Err(ExpressionError::FunctionArgumentDoesntExist(index));
376                }
377                ShaderStages::all()
378            }
379            E::GlobalVariable(_handle) => ShaderStages::all(),
380            E::LocalVariable(_handle) => ShaderStages::all(),
381            E::Load { pointer } => {
382                match resolver[pointer] {
383                    Ti::Pointer { base, .. }
384                        if self.types[base.index()]
385                            .flags
386                            .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
387                    Ti::ValuePointer { .. } => {}
388                    ref other => {
389                        log::error!("Loading {:?}", other);
390                        return Err(ExpressionError::InvalidPointerType(pointer));
391                    }
392                }
393                ShaderStages::all()
394            }
395            E::ImageSample {
396                image,
397                sampler,
398                gather,
399                coordinate,
400                array_index,
401                offset,
402                level,
403                depth_ref,
404            } => {
405                // check the validity of expressions
406                let image_ty = Self::global_var_ty(module, function, image)?;
407                let sampler_ty = Self::global_var_ty(module, function, sampler)?;
408
409                let comparison = match module.types[sampler_ty].inner {
410                    Ti::Sampler { comparison } => comparison,
411                    _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
412                };
413
414                let (class, dim) = match module.types[image_ty].inner {
415                    Ti::Image {
416                        class,
417                        arrayed,
418                        dim,
419                    } => {
420                        // check the array property
421                        if arrayed != array_index.is_some() {
422                            return Err(ExpressionError::InvalidImageArrayIndex);
423                        }
424                        if let Some(expr) = array_index {
425                            match resolver[expr] {
426                                Ti::Scalar(Sc {
427                                    kind: Sk::Sint | Sk::Uint,
428                                    ..
429                                }) => {}
430                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
431                            }
432                        }
433                        (class, dim)
434                    }
435                    _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
436                };
437
438                // check sampling and comparison properties
439                let image_depth = match class {
440                    crate::ImageClass::Sampled {
441                        kind: crate::ScalarKind::Float,
442                        multi: false,
443                    } => false,
444                    crate::ImageClass::Sampled {
445                        kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
446                        multi: false,
447                    } if gather.is_some() => false,
448                    crate::ImageClass::Depth { multi: false } => true,
449                    _ => return Err(ExpressionError::InvalidImageClass(class)),
450                };
451                if comparison != depth_ref.is_some() || (comparison && !image_depth) {
452                    return Err(ExpressionError::ComparisonSamplingMismatch {
453                        image: class,
454                        sampler: comparison,
455                        has_ref: depth_ref.is_some(),
456                    });
457                }
458
459                // check texture coordinates type
460                let num_components = match dim {
461                    crate::ImageDimension::D1 => 1,
462                    crate::ImageDimension::D2 => 2,
463                    crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
464                };
465                match resolver[coordinate] {
466                    Ti::Scalar(Sc {
467                        kind: Sk::Float, ..
468                    }) if num_components == 1 => {}
469                    Ti::Vector {
470                        size,
471                        scalar:
472                            Sc {
473                                kind: Sk::Float, ..
474                            },
475                    } if size as u32 == num_components => {}
476                    _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
477                }
478
479                // check constant offset
480                if let Some(const_expr) = offset {
481                    if !global_expr_kind.is_const(const_expr) {
482                        return Err(ExpressionError::InvalidSampleOffsetExprType);
483                    }
484
485                    match *mod_info[const_expr].inner_with(&module.types) {
486                        Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
487                        Ti::Vector {
488                            size,
489                            scalar: Sc { kind: Sk::Sint, .. },
490                        } if size as u32 == num_components => {}
491                        _ => {
492                            return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
493                        }
494                    }
495                }
496
497                // check depth reference type
498                if let Some(expr) = depth_ref {
499                    match resolver[expr] {
500                        Ti::Scalar(Sc {
501                            kind: Sk::Float, ..
502                        }) => {}
503                        _ => return Err(ExpressionError::InvalidDepthReference(expr)),
504                    }
505                    match level {
506                        crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
507                        _ => return Err(ExpressionError::InvalidDepthSampleLevel),
508                    }
509                }
510
511                if let Some(component) = gather {
512                    match dim {
513                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
514                        crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
515                            return Err(ExpressionError::InvalidGatherDimension(dim))
516                        }
517                    };
518                    let max_component = match class {
519                        crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
520                        _ => crate::SwizzleComponent::W,
521                    };
522                    if component > max_component {
523                        return Err(ExpressionError::InvalidGatherComponent(component));
524                    }
525                    match level {
526                        crate::SampleLevel::Zero => {}
527                        _ => return Err(ExpressionError::InvalidGatherLevel),
528                    }
529                }
530
531                // check level properties
532                match level {
533                    crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
534                    crate::SampleLevel::Zero => ShaderStages::all(),
535                    crate::SampleLevel::Exact(expr) => {
536                        match resolver[expr] {
537                            Ti::Scalar(Sc {
538                                kind: Sk::Float, ..
539                            }) => {}
540                            _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
541                        }
542                        ShaderStages::all()
543                    }
544                    crate::SampleLevel::Bias(expr) => {
545                        match resolver[expr] {
546                            Ti::Scalar(Sc {
547                                kind: Sk::Float, ..
548                            }) => {}
549                            _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
550                        }
551                        ShaderStages::FRAGMENT
552                    }
553                    crate::SampleLevel::Gradient { x, y } => {
554                        match resolver[x] {
555                            Ti::Scalar(Sc {
556                                kind: Sk::Float, ..
557                            }) if num_components == 1 => {}
558                            Ti::Vector {
559                                size,
560                                scalar:
561                                    Sc {
562                                        kind: Sk::Float, ..
563                                    },
564                            } if size as u32 == num_components => {}
565                            _ => {
566                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
567                            }
568                        }
569                        match resolver[y] {
570                            Ti::Scalar(Sc {
571                                kind: Sk::Float, ..
572                            }) if num_components == 1 => {}
573                            Ti::Vector {
574                                size,
575                                scalar:
576                                    Sc {
577                                        kind: Sk::Float, ..
578                                    },
579                            } if size as u32 == num_components => {}
580                            _ => {
581                                return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
582                            }
583                        }
584                        ShaderStages::all()
585                    }
586                }
587            }
588            E::ImageLoad {
589                image,
590                coordinate,
591                array_index,
592                sample,
593                level,
594            } => {
595                let ty = Self::global_var_ty(module, function, image)?;
596                match module.types[ty].inner {
597                    Ti::Image {
598                        class,
599                        arrayed,
600                        dim,
601                    } => {
602                        match resolver[coordinate].image_storage_coordinates() {
603                            Some(coord_dim) if coord_dim == dim => {}
604                            _ => {
605                                return Err(ExpressionError::InvalidImageCoordinateType(
606                                    dim, coordinate,
607                                ))
608                            }
609                        };
610                        if arrayed != array_index.is_some() {
611                            return Err(ExpressionError::InvalidImageArrayIndex);
612                        }
613                        if let Some(expr) = array_index {
614                            match resolver[expr] {
615                                Ti::Scalar(Sc {
616                                    kind: Sk::Sint | Sk::Uint,
617                                    width: _,
618                                }) => {}
619                                _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
620                            }
621                        }
622
623                        match (sample, class.is_multisampled()) {
624                            (None, false) => {}
625                            (Some(sample), true) => {
626                                if resolver[sample].scalar_kind() != Some(Sk::Sint) {
627                                    return Err(ExpressionError::InvalidImageOtherIndexType(
628                                        sample,
629                                    ));
630                                }
631                            }
632                            _ => {
633                                return Err(ExpressionError::InvalidImageOtherIndex);
634                            }
635                        }
636
637                        match (level, class.is_mipmapped()) {
638                            (None, false) => {}
639                            (Some(level), true) => {
640                                if resolver[level].scalar_kind() != Some(Sk::Sint) {
641                                    return Err(ExpressionError::InvalidImageOtherIndexType(level));
642                                }
643                            }
644                            _ => {
645                                return Err(ExpressionError::InvalidImageOtherIndex);
646                            }
647                        }
648                    }
649                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
650                }
651                ShaderStages::all()
652            }
653            E::ImageQuery { image, query } => {
654                let ty = Self::global_var_ty(module, function, image)?;
655                match module.types[ty].inner {
656                    Ti::Image { class, arrayed, .. } => {
657                        let good = match query {
658                            crate::ImageQuery::NumLayers => arrayed,
659                            crate::ImageQuery::Size { level: None } => true,
660                            crate::ImageQuery::Size { level: Some(_) }
661                            | crate::ImageQuery::NumLevels => class.is_mipmapped(),
662                            crate::ImageQuery::NumSamples => class.is_multisampled(),
663                        };
664                        if !good {
665                            return Err(ExpressionError::InvalidImageClass(class));
666                        }
667                    }
668                    _ => return Err(ExpressionError::ExpectedImageType(ty)),
669                }
670                ShaderStages::all()
671            }
672            E::Unary { op, expr } => {
673                use crate::UnaryOperator as Uo;
674                let inner = &resolver[expr];
675                match (op, inner.scalar_kind()) {
676                    (Uo::Negate, Some(Sk::Float | Sk::Sint))
677                    | (Uo::LogicalNot, Some(Sk::Bool))
678                    | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
679                    other => {
680                        log::error!("Op {:?} kind {:?}", op, other);
681                        return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
682                    }
683                }
684                ShaderStages::all()
685            }
686            E::Binary { op, left, right } => {
687                use crate::BinaryOperator as Bo;
688                let left_inner = &resolver[left];
689                let right_inner = &resolver[right];
690                let good = match op {
691                    Bo::Add | Bo::Subtract => match *left_inner {
692                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
693                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
694                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
695                        },
696                        Ti::Matrix { .. } => left_inner == right_inner,
697                        _ => false,
698                    },
699                    Bo::Divide | Bo::Modulo => match *left_inner {
700                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
701                            Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
702                            Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
703                        },
704                        _ => false,
705                    },
706                    Bo::Multiply => {
707                        let kind_allowed = match left_inner.scalar_kind() {
708                            Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
709                            Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
710                        };
711                        let types_match = match (left_inner, right_inner) {
712                            // Straight scalar and mixed scalar/vector.
713                            (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
714                            | (
715                                &Ti::Vector {
716                                    scalar: scalar1, ..
717                                },
718                                &Ti::Scalar(scalar2),
719                            )
720                            | (
721                                &Ti::Scalar(scalar1),
722                                &Ti::Vector {
723                                    scalar: scalar2, ..
724                                },
725                            ) => scalar1 == scalar2,
726                            // Scalar/matrix.
727                            (
728                                &Ti::Scalar(Sc {
729                                    kind: Sk::Float, ..
730                                }),
731                                &Ti::Matrix { .. },
732                            )
733                            | (
734                                &Ti::Matrix { .. },
735                                &Ti::Scalar(Sc {
736                                    kind: Sk::Float, ..
737                                }),
738                            ) => true,
739                            // Vector/vector.
740                            (
741                                &Ti::Vector {
742                                    size: size1,
743                                    scalar: scalar1,
744                                },
745                                &Ti::Vector {
746                                    size: size2,
747                                    scalar: scalar2,
748                                },
749                            ) => scalar1 == scalar2 && size1 == size2,
750                            // Matrix * vector.
751                            (
752                                &Ti::Matrix { columns, .. },
753                                &Ti::Vector {
754                                    size,
755                                    scalar:
756                                        Sc {
757                                            kind: Sk::Float, ..
758                                        },
759                                },
760                            ) => columns == size,
761                            // Vector * matrix.
762                            (
763                                &Ti::Vector {
764                                    size,
765                                    scalar:
766                                        Sc {
767                                            kind: Sk::Float, ..
768                                        },
769                                },
770                                &Ti::Matrix { rows, .. },
771                            ) => size == rows,
772                            (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
773                                columns == rows
774                            }
775                            _ => false,
776                        };
777                        let left_width = left_inner.scalar_width().unwrap_or(0);
778                        let right_width = right_inner.scalar_width().unwrap_or(0);
779                        kind_allowed && types_match && left_width == right_width
780                    }
781                    Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
782                    Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
783                        match *left_inner {
784                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
785                                Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
786                                Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
787                            },
788                            ref other => {
789                                log::error!("Op {:?} left type {:?}", op, other);
790                                false
791                            }
792                        }
793                    }
794                    Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
795                        Ti::Scalar(Sc { kind: Sk::Bool, .. })
796                        | Ti::Vector {
797                            scalar: Sc { kind: Sk::Bool, .. },
798                            ..
799                        } => left_inner == right_inner,
800                        ref other => {
801                            log::error!("Op {:?} left type {:?}", op, other);
802                            false
803                        }
804                    },
805                    Bo::And | Bo::InclusiveOr => match *left_inner {
806                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
807                            Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
808                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
809                        },
810                        ref other => {
811                            log::error!("Op {:?} left type {:?}", op, other);
812                            false
813                        }
814                    },
815                    Bo::ExclusiveOr => match *left_inner {
816                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
817                            Sk::Sint | Sk::Uint => left_inner == right_inner,
818                            Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
819                        },
820                        ref other => {
821                            log::error!("Op {:?} left type {:?}", op, other);
822                            false
823                        }
824                    },
825                    Bo::ShiftLeft | Bo::ShiftRight => {
826                        let (base_size, base_scalar) = match *left_inner {
827                            Ti::Scalar(scalar) => (Ok(None), scalar),
828                            Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
829                            ref other => {
830                                log::error!("Op {:?} base type {:?}", op, other);
831                                (Err(()), Sc::BOOL)
832                            }
833                        };
834                        let shift_size = match *right_inner {
835                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
836                            Ti::Vector {
837                                size,
838                                scalar: Sc { kind: Sk::Uint, .. },
839                            } => Ok(Some(size)),
840                            ref other => {
841                                log::error!("Op {:?} shift type {:?}", op, other);
842                                Err(())
843                            }
844                        };
845                        match base_scalar.kind {
846                            Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
847                            Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
848                        }
849                    }
850                };
851                if !good {
852                    log::error!(
853                        "Left: {:?} of type {:?}",
854                        function.expressions[left],
855                        left_inner
856                    );
857                    log::error!(
858                        "Right: {:?} of type {:?}",
859                        function.expressions[right],
860                        right_inner
861                    );
862                    return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
863                }
864                ShaderStages::all()
865            }
866            E::Select {
867                condition,
868                accept,
869                reject,
870            } => {
871                let accept_inner = &resolver[accept];
872                let reject_inner = &resolver[reject];
873                let condition_good = match resolver[condition] {
874                    Ti::Scalar(Sc {
875                        kind: Sk::Bool,
876                        width: _,
877                    }) => {
878                        // When `condition` is a single boolean, `accept` and
879                        // `reject` can be vectors or scalars.
880                        match *accept_inner {
881                            Ti::Scalar { .. } | Ti::Vector { .. } => true,
882                            _ => false,
883                        }
884                    }
885                    Ti::Vector {
886                        size,
887                        scalar:
888                            Sc {
889                                kind: Sk::Bool,
890                                width: _,
891                            },
892                    } => match *accept_inner {
893                        Ti::Vector {
894                            size: other_size, ..
895                        } => size == other_size,
896                        _ => false,
897                    },
898                    _ => false,
899                };
900                if !condition_good || accept_inner != reject_inner {
901                    return Err(ExpressionError::InvalidSelectTypes);
902                }
903                ShaderStages::all()
904            }
905            E::Derivative { expr, .. } => {
906                match resolver[expr] {
907                    Ti::Scalar(Sc {
908                        kind: Sk::Float, ..
909                    })
910                    | Ti::Vector {
911                        scalar:
912                            Sc {
913                                kind: Sk::Float, ..
914                            },
915                        ..
916                    } => {}
917                    _ => return Err(ExpressionError::InvalidDerivative),
918                }
919                ShaderStages::FRAGMENT
920            }
921            E::Relational { fun, argument } => {
922                use crate::RelationalFunction as Rf;
923                let argument_inner = &resolver[argument];
924                match fun {
925                    Rf::All | Rf::Any => match *argument_inner {
926                        Ti::Vector {
927                            scalar: Sc { kind: Sk::Bool, .. },
928                            ..
929                        } => {}
930                        ref other => {
931                            log::error!("All/Any of type {:?}", other);
932                            return Err(ExpressionError::InvalidBooleanVector(argument));
933                        }
934                    },
935                    Rf::IsNan | Rf::IsInf => match *argument_inner {
936                        Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
937                            if scalar.kind == Sk::Float => {}
938                        ref other => {
939                            log::error!("Float test of type {:?}", other);
940                            return Err(ExpressionError::InvalidFloatArgument(argument));
941                        }
942                    },
943                }
944                ShaderStages::all()
945            }
946            E::Math {
947                fun,
948                arg,
949                arg1,
950                arg2,
951                arg3,
952            } => {
953                use crate::MathFunction as Mf;
954
955                let resolve = |arg| &resolver[arg];
956                let arg_ty = resolve(arg);
957                let arg1_ty = arg1.map(resolve);
958                let arg2_ty = arg2.map(resolve);
959                let arg3_ty = arg3.map(resolve);
960                match fun {
961                    Mf::Abs => {
962                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
963                            return Err(ExpressionError::WrongArgumentCount(fun));
964                        }
965                        let good = match *arg_ty {
966                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
967                                scalar.kind != Sk::Bool
968                            }
969                            _ => false,
970                        };
971                        if !good {
972                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
973                        }
974                    }
975                    Mf::Min | Mf::Max => {
976                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
977                            (Some(ty1), None, None) => ty1,
978                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
979                        };
980                        let good = match *arg_ty {
981                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
982                                scalar.kind != Sk::Bool
983                            }
984                            _ => false,
985                        };
986                        if !good {
987                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
988                        }
989                        if arg1_ty != arg_ty {
990                            return Err(ExpressionError::InvalidArgumentType(
991                                fun,
992                                1,
993                                arg1.unwrap(),
994                            ));
995                        }
996                    }
997                    Mf::Clamp => {
998                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
999                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1000                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1001                        };
1002                        let good = match *arg_ty {
1003                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1004                                scalar.kind != Sk::Bool
1005                            }
1006                            _ => false,
1007                        };
1008                        if !good {
1009                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1010                        }
1011                        if arg1_ty != arg_ty {
1012                            return Err(ExpressionError::InvalidArgumentType(
1013                                fun,
1014                                1,
1015                                arg1.unwrap(),
1016                            ));
1017                        }
1018                        if arg2_ty != arg_ty {
1019                            return Err(ExpressionError::InvalidArgumentType(
1020                                fun,
1021                                2,
1022                                arg2.unwrap(),
1023                            ));
1024                        }
1025                    }
1026                    Mf::Saturate
1027                    | Mf::Cos
1028                    | Mf::Cosh
1029                    | Mf::Sin
1030                    | Mf::Sinh
1031                    | Mf::Tan
1032                    | Mf::Tanh
1033                    | Mf::Acos
1034                    | Mf::Asin
1035                    | Mf::Atan
1036                    | Mf::Asinh
1037                    | Mf::Acosh
1038                    | Mf::Atanh
1039                    | Mf::Radians
1040                    | Mf::Degrees
1041                    | Mf::Ceil
1042                    | Mf::Floor
1043                    | Mf::Round
1044                    | Mf::Fract
1045                    | Mf::Trunc
1046                    | Mf::Exp
1047                    | Mf::Exp2
1048                    | Mf::Log
1049                    | Mf::Log2
1050                    | Mf::Length
1051                    | Mf::Sqrt
1052                    | Mf::InverseSqrt => {
1053                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1054                            return Err(ExpressionError::WrongArgumentCount(fun));
1055                        }
1056                        match *arg_ty {
1057                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1058                                if scalar.kind == Sk::Float => {}
1059                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1060                        }
1061                    }
1062                    Mf::Sign => {
1063                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1064                            return Err(ExpressionError::WrongArgumentCount(fun));
1065                        }
1066                        match *arg_ty {
1067                            Ti::Scalar(Sc {
1068                                kind: Sk::Float | Sk::Sint,
1069                                ..
1070                            })
1071                            | Ti::Vector {
1072                                scalar:
1073                                    Sc {
1074                                        kind: Sk::Float | Sk::Sint,
1075                                        ..
1076                                    },
1077                                ..
1078                            } => {}
1079                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1080                        }
1081                    }
1082                    Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
1083                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1084                            (Some(ty1), None, None) => ty1,
1085                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1086                        };
1087                        match *arg_ty {
1088                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1089                                if scalar.kind == Sk::Float => {}
1090                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1091                        }
1092                        if arg1_ty != arg_ty {
1093                            return Err(ExpressionError::InvalidArgumentType(
1094                                fun,
1095                                1,
1096                                arg1.unwrap(),
1097                            ));
1098                        }
1099                    }
1100                    Mf::Modf | Mf::Frexp => {
1101                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1102                            return Err(ExpressionError::WrongArgumentCount(fun));
1103                        }
1104                        if !matches!(*arg_ty,
1105                                     Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1106                                     if scalar.kind == Sk::Float)
1107                        {
1108                            return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
1109                        }
1110                    }
1111                    Mf::Ldexp => {
1112                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1113                            (Some(ty1), None, None) => ty1,
1114                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1115                        };
1116                        let size0 = match *arg_ty {
1117                            Ti::Scalar(Sc {
1118                                kind: Sk::Float, ..
1119                            }) => None,
1120                            Ti::Vector {
1121                                scalar:
1122                                    Sc {
1123                                        kind: Sk::Float, ..
1124                                    },
1125                                size,
1126                            } => Some(size),
1127                            _ => {
1128                                return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1129                            }
1130                        };
1131                        let good = match *arg1_ty {
1132                            Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
1133                            Ti::Vector {
1134                                size,
1135                                scalar: Sc { kind: Sk::Sint, .. },
1136                            } if Some(size) == size0 => true,
1137                            _ => false,
1138                        };
1139                        if !good {
1140                            return Err(ExpressionError::InvalidArgumentType(
1141                                fun,
1142                                1,
1143                                arg1.unwrap(),
1144                            ));
1145                        }
1146                    }
1147                    Mf::Dot => {
1148                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1149                            (Some(ty1), None, None) => ty1,
1150                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1151                        };
1152                        match *arg_ty {
1153                            Ti::Vector {
1154                                scalar:
1155                                    Sc {
1156                                        kind: Sk::Float | Sk::Sint | Sk::Uint,
1157                                        ..
1158                                    },
1159                                ..
1160                            } => {}
1161                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1162                        }
1163                        if arg1_ty != arg_ty {
1164                            return Err(ExpressionError::InvalidArgumentType(
1165                                fun,
1166                                1,
1167                                arg1.unwrap(),
1168                            ));
1169                        }
1170                    }
1171                    Mf::Outer | Mf::Cross | Mf::Reflect => {
1172                        let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1173                            (Some(ty1), None, None) => ty1,
1174                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1175                        };
1176                        match *arg_ty {
1177                            Ti::Vector {
1178                                scalar:
1179                                    Sc {
1180                                        kind: Sk::Float, ..
1181                                    },
1182                                ..
1183                            } => {}
1184                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1185                        }
1186                        if arg1_ty != arg_ty {
1187                            return Err(ExpressionError::InvalidArgumentType(
1188                                fun,
1189                                1,
1190                                arg1.unwrap(),
1191                            ));
1192                        }
1193                    }
1194                    Mf::Refract => {
1195                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1196                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1197                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1198                        };
1199
1200                        match *arg_ty {
1201                            Ti::Vector {
1202                                scalar:
1203                                    Sc {
1204                                        kind: Sk::Float, ..
1205                                    },
1206                                ..
1207                            } => {}
1208                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1209                        }
1210
1211                        if arg1_ty != arg_ty {
1212                            return Err(ExpressionError::InvalidArgumentType(
1213                                fun,
1214                                1,
1215                                arg1.unwrap(),
1216                            ));
1217                        }
1218
1219                        match (arg_ty, arg2_ty) {
1220                            (
1221                                &Ti::Vector {
1222                                    scalar:
1223                                        Sc {
1224                                            width: vector_width,
1225                                            ..
1226                                        },
1227                                    ..
1228                                },
1229                                &Ti::Scalar(Sc {
1230                                    width: scalar_width,
1231                                    kind: Sk::Float,
1232                                }),
1233                            ) if vector_width == scalar_width => {}
1234                            _ => {
1235                                return Err(ExpressionError::InvalidArgumentType(
1236                                    fun,
1237                                    2,
1238                                    arg2.unwrap(),
1239                                ))
1240                            }
1241                        }
1242                    }
1243                    Mf::Normalize => {
1244                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1245                            return Err(ExpressionError::WrongArgumentCount(fun));
1246                        }
1247                        match *arg_ty {
1248                            Ti::Vector {
1249                                scalar:
1250                                    Sc {
1251                                        kind: Sk::Float, ..
1252                                    },
1253                                ..
1254                            } => {}
1255                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1256                        }
1257                    }
1258                    Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
1259                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1260                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1261                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1262                        };
1263                        match *arg_ty {
1264                            Ti::Scalar(Sc {
1265                                kind: Sk::Float, ..
1266                            })
1267                            | Ti::Vector {
1268                                scalar:
1269                                    Sc {
1270                                        kind: Sk::Float, ..
1271                                    },
1272                                ..
1273                            } => {}
1274                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1275                        }
1276                        if arg1_ty != arg_ty {
1277                            return Err(ExpressionError::InvalidArgumentType(
1278                                fun,
1279                                1,
1280                                arg1.unwrap(),
1281                            ));
1282                        }
1283                        if arg2_ty != arg_ty {
1284                            return Err(ExpressionError::InvalidArgumentType(
1285                                fun,
1286                                2,
1287                                arg2.unwrap(),
1288                            ));
1289                        }
1290                    }
1291                    Mf::Mix => {
1292                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1293                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1294                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1295                        };
1296                        let arg_width = match *arg_ty {
1297                            Ti::Scalar(Sc {
1298                                kind: Sk::Float,
1299                                width,
1300                            })
1301                            | Ti::Vector {
1302                                scalar:
1303                                    Sc {
1304                                        kind: Sk::Float,
1305                                        width,
1306                                    },
1307                                ..
1308                            } => width,
1309                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1310                        };
1311                        if arg1_ty != arg_ty {
1312                            return Err(ExpressionError::InvalidArgumentType(
1313                                fun,
1314                                1,
1315                                arg1.unwrap(),
1316                            ));
1317                        }
1318                        // the last argument can always be a scalar
1319                        match *arg2_ty {
1320                            Ti::Scalar(Sc {
1321                                kind: Sk::Float,
1322                                width,
1323                            }) if width == arg_width => {}
1324                            _ if arg2_ty == arg_ty => {}
1325                            _ => {
1326                                return Err(ExpressionError::InvalidArgumentType(
1327                                    fun,
1328                                    2,
1329                                    arg2.unwrap(),
1330                                ));
1331                            }
1332                        }
1333                    }
1334                    Mf::Inverse | Mf::Determinant => {
1335                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1336                            return Err(ExpressionError::WrongArgumentCount(fun));
1337                        }
1338                        let good = match *arg_ty {
1339                            Ti::Matrix { columns, rows, .. } => columns == rows,
1340                            _ => false,
1341                        };
1342                        if !good {
1343                            return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1344                        }
1345                    }
1346                    Mf::Transpose => {
1347                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1348                            return Err(ExpressionError::WrongArgumentCount(fun));
1349                        }
1350                        match *arg_ty {
1351                            Ti::Matrix { .. } => {}
1352                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1353                        }
1354                    }
1355                    // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1356                    Mf::CountLeadingZeros
1357                    | Mf::CountTrailingZeros
1358                    | Mf::CountOneBits
1359                    | Mf::ReverseBits
1360                    | Mf::FindMsb
1361                    | Mf::FindLsb => {
1362                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1363                            return Err(ExpressionError::WrongArgumentCount(fun));
1364                        }
1365                        match *arg_ty {
1366                            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1367                                Sk::Sint | Sk::Uint => {
1368                                    if scalar.width != 4 {
1369                                        return Err(ExpressionError::UnsupportedWidth(
1370                                            fun,
1371                                            scalar.kind,
1372                                            scalar.width,
1373                                        ));
1374                                    }
1375                                }
1376                                _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1377                            },
1378                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1379                        }
1380                    }
1381                    Mf::InsertBits => {
1382                        let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1383                            (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
1384                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1385                        };
1386                        match *arg_ty {
1387                            Ti::Scalar(Sc {
1388                                kind: Sk::Sint | Sk::Uint,
1389                                ..
1390                            })
1391                            | Ti::Vector {
1392                                scalar:
1393                                    Sc {
1394                                        kind: Sk::Sint | Sk::Uint,
1395                                        ..
1396                                    },
1397                                ..
1398                            } => {}
1399                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1400                        }
1401                        if arg1_ty != arg_ty {
1402                            return Err(ExpressionError::InvalidArgumentType(
1403                                fun,
1404                                1,
1405                                arg1.unwrap(),
1406                            ));
1407                        }
1408                        match *arg2_ty {
1409                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1410                            _ => {
1411                                return Err(ExpressionError::InvalidArgumentType(
1412                                    fun,
1413                                    2,
1414                                    arg2.unwrap(),
1415                                ))
1416                            }
1417                        }
1418                        match *arg3_ty {
1419                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1420                            _ => {
1421                                return Err(ExpressionError::InvalidArgumentType(
1422                                    fun,
1423                                    2,
1424                                    arg3.unwrap(),
1425                                ))
1426                            }
1427                        }
1428                        // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1429                        for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() {
1430                            match *arg {
1431                                Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1432                                    if scalar.width != 4 {
1433                                        return Err(ExpressionError::UnsupportedWidth(
1434                                            fun,
1435                                            scalar.kind,
1436                                            scalar.width,
1437                                        ));
1438                                    }
1439                                }
1440                                _ => {}
1441                            }
1442                        }
1443                    }
1444                    Mf::ExtractBits => {
1445                        let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1446                            (Some(ty1), Some(ty2), None) => (ty1, ty2),
1447                            _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1448                        };
1449                        match *arg_ty {
1450                            Ti::Scalar(Sc {
1451                                kind: Sk::Sint | Sk::Uint,
1452                                ..
1453                            })
1454                            | Ti::Vector {
1455                                scalar:
1456                                    Sc {
1457                                        kind: Sk::Sint | Sk::Uint,
1458                                        ..
1459                                    },
1460                                ..
1461                            } => {}
1462                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1463                        }
1464                        match *arg1_ty {
1465                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1466                            _ => {
1467                                return Err(ExpressionError::InvalidArgumentType(
1468                                    fun,
1469                                    2,
1470                                    arg1.unwrap(),
1471                                ))
1472                            }
1473                        }
1474                        match *arg2_ty {
1475                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1476                            _ => {
1477                                return Err(ExpressionError::InvalidArgumentType(
1478                                    fun,
1479                                    2,
1480                                    arg2.unwrap(),
1481                                ))
1482                            }
1483                        }
1484                        // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
1485                        for &arg in [arg_ty, arg1_ty, arg2_ty].iter() {
1486                            match *arg {
1487                                Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1488                                    if scalar.width != 4 {
1489                                        return Err(ExpressionError::UnsupportedWidth(
1490                                            fun,
1491                                            scalar.kind,
1492                                            scalar.width,
1493                                        ));
1494                                    }
1495                                }
1496                                _ => {}
1497                            }
1498                        }
1499                    }
1500                    Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
1501                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1502                            return Err(ExpressionError::WrongArgumentCount(fun));
1503                        }
1504                        match *arg_ty {
1505                            Ti::Vector {
1506                                size: crate::VectorSize::Bi,
1507                                scalar:
1508                                    Sc {
1509                                        kind: Sk::Float, ..
1510                                    },
1511                            } => {}
1512                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1513                        }
1514                    }
1515                    Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
1516                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1517                            return Err(ExpressionError::WrongArgumentCount(fun));
1518                        }
1519                        match *arg_ty {
1520                            Ti::Vector {
1521                                size: crate::VectorSize::Quad,
1522                                scalar:
1523                                    Sc {
1524                                        kind: Sk::Float, ..
1525                                    },
1526                            } => {}
1527                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1528                        }
1529                    }
1530                    Mf::Unpack2x16float
1531                    | Mf::Unpack2x16snorm
1532                    | Mf::Unpack2x16unorm
1533                    | Mf::Unpack4x8snorm
1534                    | Mf::Unpack4x8unorm => {
1535                        if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1536                            return Err(ExpressionError::WrongArgumentCount(fun));
1537                        }
1538                        match *arg_ty {
1539                            Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1540                            _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1541                        }
1542                    }
1543                }
1544                ShaderStages::all()
1545            }
1546            E::As {
1547                expr,
1548                kind,
1549                convert,
1550            } => {
1551                let mut base_scalar = match resolver[expr] {
1552                    crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1553                        scalar
1554                    }
1555                    crate::TypeInner::Matrix { scalar, .. } => scalar,
1556                    _ => return Err(ExpressionError::InvalidCastArgument),
1557                };
1558                base_scalar.kind = kind;
1559                if let Some(width) = convert {
1560                    base_scalar.width = width;
1561                }
1562                if self.check_width(base_scalar).is_err() {
1563                    return Err(ExpressionError::InvalidCastArgument);
1564                }
1565                ShaderStages::all()
1566            }
1567            E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1568            E::AtomicResult { ty, comparison } => {
1569                let scalar_predicate = |ty: &crate::TypeInner| match ty {
1570                    &crate::TypeInner::Scalar(
1571                        scalar @ Sc {
1572                            kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
1573                            ..
1574                        },
1575                    ) => self.check_width(scalar).is_ok(),
1576                    _ => false,
1577                };
1578                let good = match &module.types[ty].inner {
1579                    ty if !comparison => scalar_predicate(ty),
1580                    &crate::TypeInner::Struct { ref members, .. } if comparison => {
1581                        validate_atomic_compare_exchange_struct(
1582                            &module.types,
1583                            members,
1584                            scalar_predicate,
1585                        )
1586                    }
1587                    _ => false,
1588                };
1589                if !good {
1590                    return Err(ExpressionError::InvalidAtomicResultType(ty));
1591                }
1592                ShaderStages::all()
1593            }
1594            E::WorkGroupUniformLoadResult { ty } => {
1595                if self.types[ty.index()]
1596                    .flags
1597                    // Sized | Constructible is exactly the types currently supported by
1598                    // WorkGroupUniformLoad
1599                    .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1600                {
1601                    ShaderStages::COMPUTE
1602                } else {
1603                    return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1604                }
1605            }
1606            E::ArrayLength(expr) => match resolver[expr] {
1607                Ti::Pointer { base, .. } => {
1608                    let base_ty = &resolver.types[base];
1609                    if let Ti::Array {
1610                        size: crate::ArraySize::Dynamic,
1611                        ..
1612                    } = base_ty.inner
1613                    {
1614                        ShaderStages::all()
1615                    } else {
1616                        return Err(ExpressionError::InvalidArrayType(expr));
1617                    }
1618                }
1619                ref other => {
1620                    log::error!("Array length of {:?}", other);
1621                    return Err(ExpressionError::InvalidArrayType(expr));
1622                }
1623            },
1624            E::RayQueryProceedResult => ShaderStages::all(),
1625            E::RayQueryGetIntersection {
1626                query,
1627                committed: _,
1628            } => match resolver[query] {
1629                Ti::Pointer {
1630                    base,
1631                    space: crate::AddressSpace::Function,
1632                } => match resolver.types[base].inner {
1633                    Ti::RayQuery => ShaderStages::all(),
1634                    ref other => {
1635                        log::error!("Intersection result of a pointer to {:?}", other);
1636                        return Err(ExpressionError::InvalidRayQueryType(query));
1637                    }
1638                },
1639                ref other => {
1640                    log::error!("Intersection result of {:?}", other);
1641                    return Err(ExpressionError::InvalidRayQueryType(query));
1642                }
1643            },
1644            E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1645        };
1646        Ok(stages)
1647    }
1648
1649    fn global_var_ty(
1650        module: &crate::Module,
1651        function: &crate::Function,
1652        expr: Handle<crate::Expression>,
1653    ) -> Result<Handle<crate::Type>, ExpressionError> {
1654        use crate::Expression as Ex;
1655
1656        match function.expressions[expr] {
1657            Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1658            Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1659            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1660                match function.expressions[base] {
1661                    Ex::GlobalVariable(var_handle) => {
1662                        let array_ty = module.global_variables[var_handle].ty;
1663
1664                        match module.types[array_ty].inner {
1665                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
1666                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1667                        }
1668                    }
1669                    _ => Err(ExpressionError::ExpectedGlobalVariable),
1670                }
1671            }
1672            _ => Err(ExpressionError::ExpectedGlobalVariable),
1673        }
1674    }
1675
1676    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1677        self.check_width(literal.scalar())?;
1678        check_literal_value(literal)?;
1679
1680        Ok(())
1681    }
1682}
1683
1684pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1685    let is_nan = match literal {
1686        crate::Literal::F64(v) => v.is_nan(),
1687        crate::Literal::F32(v) => v.is_nan(),
1688        _ => false,
1689    };
1690    if is_nan {
1691        return Err(LiteralError::NaN);
1692    }
1693
1694    let is_infinite = match literal {
1695        crate::Literal::F64(v) => v.is_infinite(),
1696        crate::Literal::F32(v) => v.is_infinite(),
1697        _ => false,
1698    };
1699    if is_infinite {
1700        return Err(LiteralError::Infinity);
1701    }
1702
1703    Ok(())
1704}
1705
/// Test helper: build a module whose single function holds only `expr`
/// (covered by one `Emit` statement), then run the validator on it with the
/// `EXPRESSIONS` flag and the given capabilities.
///
/// The validator's result is returned unchanged, so callers can assert either
/// success or a specific error.
#[cfg(all(test, feature = "validate"))]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut func = crate::Function::default();
    func.expressions.append(expr, Span::default());
    // Emit every expression in the arena, i.e. just the one appended above.
    let emitted = func.expressions.range_from(0);
    func.body
        .push(crate::Statement::Emit(emitted), Span::default());

    let mut module = crate::Module::default();
    module.functions.append(func, Span::default());

    super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps).validate(&module)
}
1728
/// Test helper: build a module whose global expression arena holds only
/// `expr`, then run the validator on it with the `CONSTANTS` flag and the
/// given capabilities.
///
/// The validator's result is returned unchanged, so callers can assert either
/// success or a specific error.
#[cfg(all(test, feature = "validate"))]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut module = crate::Module::default();
    module.global_expressions.append(expr, Span::default());

    super::Validator::new(super::ValidationFlags::CONSTANTS, caps).validate(&module)
}
1744
/// An `F64` literal in a function's expression arena is rejected unless the
/// `FLOAT64` capability is enabled.
#[cfg(feature = "validate")]
#[test]
fn f64_runtime_literals() {
    // Fresh expression per call, since the helper takes it by value.
    let literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Without `FLOAT64`, validation must report the missing capability.
    let error = validate_with_expression(literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: super::ExpressionError::Literal(super::LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                ),),
                ..
            },
            ..
        }
    ));

    // With `FLOAT64`, the same literal is accepted.
    let with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_expression(literal(), with_f64).is_ok());
}
1776
/// An `F64` literal in a module's constant expression arena is rejected
/// unless the `FLOAT64` capability is enabled.
#[cfg(feature = "validate")]
#[test]
fn f64_const_literals() {
    // Fresh expression per call, since the helper takes it by value.
    let literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    // Without `FLOAT64`, validation must report the missing capability.
    let error = validate_with_const_expression(literal(), super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: super::ConstExpressionError::Literal(super::LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    // With `FLOAT64`, the same literal is accepted.
    let with_f64 = super::Capabilities::default() | super::Capabilities::FLOAT64;
    assert!(validate_with_const_expression(literal(), with_f64).is_ok());
}
1805
/// An `I64` literal in a function's expression arena is rejected even when
/// every capability is enabled.
#[cfg(feature = "validate")]
#[test]
fn i64_runtime_literals() {
    // There is no capability that enables this.
    let caps = super::Capabilities::all();
    let literal = crate::Expression::Literal(crate::Literal::I64(1729));

    let error = validate_with_expression(literal, caps)
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: super::ExpressionError::Literal(super::LiteralError::Width(
                    super::r#type::WidthError::Unsupported64Bit
                ),),
                ..
            },
            ..
        }
    ));
}
1829
/// An `I64` literal in a module's constant expression arena is rejected even
/// when every capability is enabled.
#[cfg(feature = "validate")]
#[test]
fn i64_const_literals() {
    // There is no capability that enables this.
    let caps = super::Capabilities::all();
    let literal = crate::Expression::Literal(crate::Literal::I64(1729));

    let error = validate_with_const_expression(literal, caps)
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: super::ConstExpressionError::Literal(super::LiteralError::Width(
                super::r#type::WidthError::Unsupported64Bit,
            ),),
            ..
        }
    ));
}