// naga/proc/typifier.rs

1use crate::arena::{Arena, Handle, UniqueArena};
2
3use thiserror::Error;
4
5/// The result of computing an expression's type.
6///
7/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent
8/// the (Naga) type it ascribes to some expression.
9///
10/// You might expect such a function to simply return a `Handle<Type>`. However,
11/// we want type resolution to be a read-only process, and that would limit the
12/// possible results to types already present in the expression's associated
13/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are
14/// not certain to be present.
15///
16/// So instead, type resolution returns a `TypeResolution` enum: either a
17/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a
18/// free-floating [`TypeInner`]. This extends the range to cover anything that
19/// can be represented with a `TypeInner` referring to the existing arena.
20///
21/// What sorts of expressions can have types not available in the arena?
22///
23/// -   An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or
24///     [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector`
25///     and `Matrix` represent their element and column types implicitly, not
26///     via a handle, there may not be a suitable type in the expression's
27///     associated arena. Instead, resolving such an expression returns a
28///     `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or
29///     `Vector`.
30///
31/// -   Similarly, the type of an [`Access`] or [`AccessIndex`] expression
32///     applied to a *pointer to* a vector or matrix must produce a *pointer to*
33///     a scalar or vector type. These cannot be represented with a
34///     [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the
35///     arena, and as before, we cannot assume that a suitable scalar or vector
36///     type is there. So we take things one step further and provide
37///     [`TypeInner::ValuePointer`], specifically for the case of pointers to
38///     scalars or vectors. This type fits in a `TypeInner` and is exactly
39///     equivalent to a `Pointer` to a `Vector` or `Scalar`.
40///
41/// So, for example, the type of an `Access` expression applied to a value of type:
42///
43/// ```ignore
44/// TypeInner::Matrix { columns, rows, width }
45/// ```
46///
47/// might be:
48///
49/// ```ignore
50/// TypeResolution::Value(TypeInner::Vector {
51///     size: rows,
52///     kind: ScalarKind::Float,
53///     width,
54/// })
55/// ```
56///
57/// and the type of an access to a pointer of address space `space` to such a
58/// matrix might be:
59///
60/// ```ignore
61/// TypeResolution::Value(TypeInner::ValuePointer {
62///     size: Some(rows),
63///     kind: ScalarKind::Float,
64///     width,
65///     space,
66/// })
67/// ```
68///
69/// [`Handle`]: TypeResolution::Handle
70/// [`Value`]: TypeResolution::Value
71///
72/// [`Access`]: crate::Expression::Access
73/// [`AccessIndex`]: crate::Expression::AccessIndex
74///
75/// [`TypeInner`]: crate::TypeInner
76/// [`Matrix`]: crate::TypeInner::Matrix
77/// [`Pointer`]: crate::TypeInner::Pointer
78/// [`Scalar`]: crate::TypeInner::Scalar
79/// [`ValuePointer`]: crate::TypeInner::ValuePointer
80/// [`Vector`]: crate::TypeInner::Vector
81///
82/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer
83/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer
84#[derive(Debug, PartialEq)]
85#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
86#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
87pub enum TypeResolution {
88    /// A type stored in the associated arena.
89    Handle(Handle<crate::Type>),
90
91    /// A free-floating [`TypeInner`], representing a type that may not be
92    /// available in the associated arena. However, the `TypeInner` itself may
93    /// contain `Handle<Type>` values referring to types from the arena.
94    ///
95    /// [`TypeInner`]: crate::TypeInner
96    Value(crate::TypeInner),
97}
98
99impl TypeResolution {
100    pub const fn handle(&self) -> Option<Handle<crate::Type>> {
101        match *self {
102            Self::Handle(handle) => Some(handle),
103            Self::Value(_) => None,
104        }
105    }
106
107    pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
108        match *self {
109            Self::Handle(handle) => &arena[handle].inner,
110            Self::Value(ref inner) => inner,
111        }
112    }
113}
114
115// Clone is only implemented for numeric variants of `TypeInner`.
116impl Clone for TypeResolution {
117    fn clone(&self) -> Self {
118        use crate::TypeInner as Ti;
119        match *self {
120            Self::Handle(handle) => Self::Handle(handle),
121            Self::Value(ref v) => Self::Value(match *v {
122                Ti::Scalar(scalar) => Ti::Scalar(scalar),
123                Ti::Vector { size, scalar } => Ti::Vector { size, scalar },
124                Ti::Matrix {
125                    rows,
126                    columns,
127                    scalar,
128                } => Ti::Matrix {
129                    rows,
130                    columns,
131                    scalar,
132                },
133                Ti::Pointer { base, space } => Ti::Pointer { base, space },
134                Ti::ValuePointer {
135                    size,
136                    scalar,
137                    space,
138                } => Ti::ValuePointer {
139                    size,
140                    scalar,
141                    space,
142                },
143                _ => unreachable!("Unexpected clone type: {:?}", v),
144            }),
145        }
146    }
147}
148
/// Reasons [`ResolveContext::resolve`] can fail to ascribe a type to an
/// expression.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ResolveError {
    /// A constant index is out of range for the value produced by `expr`.
    /// For example, an `AccessIndex` whose `index` exceeds a vector's size,
    /// a matrix's column count, or a struct's member count.
    #[error("Index {index} is out of bounds for expression {expr:?}")]
    OutOfBoundsIndex {
        expr: Handle<crate::Expression>,
        index: u32,
    },
    /// `expr` has a type that cannot be accessed at all. `indexed` is `true`
    /// for an `AccessIndex` and `false` for a dynamic `Access`.
    #[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
    InvalidAccess {
        expr: Handle<crate::Expression>,
        indexed: bool,
    },
    /// The base was a pointer, but its pointee type `ty` cannot be accessed.
    /// `indexed` is `true` for an `AccessIndex` and `false` for an `Access`.
    #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
    InvalidSubAccess {
        ty: Handle<crate::Type>,
        indexed: bool,
    },
    /// The expression was required to be a scalar (e.g. the operand of a
    /// `Splat`) but is not.
    #[error("Invalid scalar {0:?}")]
    InvalidScalar(Handle<crate::Expression>),
    /// The expression was required to be a vector (e.g. the operand of a
    /// `Swizzle`) but is not.
    #[error("Invalid vector {0:?}")]
    InvalidVector(Handle<crate::Expression>),
    /// The operand of a `Load` is neither a `Pointer` nor a `ValuePointer`.
    #[error("Invalid pointer {0:?}")]
    InvalidPointer(Handle<crate::Expression>),
    /// The image operand of an image sample, load, or query does not have an
    /// `Image` type.
    #[error("Invalid image {0:?}")]
    InvalidImage(Handle<crate::Expression>),
    /// The function named `name` is not defined.
    #[error("Function {name} not defined")]
    FunctionNotDefined { name: String },
    /// A function without a return type was used where a result type was
    /// needed. Presumably a call result; the raising site is not in this
    /// excerpt, so confirm.
    #[error("Function without return type")]
    FunctionReturnsVoid,
    /// Operand types do not fit the operator or builtin being applied. The
    /// string describes the offending operand types, e.g. `"{left:?} * {right:?}"`.
    #[error("Incompatible operands: {0}")]
    IncompatibleOperands(String),
    /// `Expression::FunctionArgument(i)` names an argument index that is not
    /// present in the context's `arguments`.
    #[error("Function argument {0} doesn't exist")]
    FunctionArgumentNotFound(u32),
    /// A predeclared special type needed for the result (e.g. the `modf` or
    /// `frexp` result struct) was not found in `SpecialTypes::predeclared_types`.
    #[error("Special type is not registered within the module")]
    MissingSpecialType,
}
185
/// The module- and function-level data that [`ResolveContext::resolve`]
/// consults when ascribing a type to an expression.
///
/// Typically built with [`ResolveContext::with_locals`] from a
/// [`crate::Module`] plus one function's local variables and arguments.
pub struct ResolveContext<'a> {
    /// Module constants. `Expression::Constant(h)` resolves to `constants[h].ty`.
    pub constants: &'a Arena<crate::Constant>,
    /// Module overrides. `Expression::Override(h)` resolves to `overrides[h].ty`.
    pub overrides: &'a Arena<crate::Override>,
    /// The type arena that `TypeResolution::Handle` results index into.
    pub types: &'a UniqueArena<crate::Type>,
    /// Special types of the module. Consulted for predeclared result types
    /// such as those of `modf` and `frexp`.
    pub special_types: &'a crate::SpecialTypes,
    /// Module global variables. `Expression::GlobalVariable` resolves to the
    /// variable's type for `AddressSpace::Handle` globals, and to a pointer to
    /// that type otherwise.
    pub global_vars: &'a Arena<crate::GlobalVariable>,
    /// Local variables of the function being resolved.
    /// `Expression::LocalVariable` resolves to a `Function`-space pointer to
    /// the variable's type.
    pub local_vars: &'a Arena<crate::LocalVariable>,
    /// Module functions. Not consulted in this excerpt of `resolve`;
    /// presumably used for call-result types (confirm).
    pub functions: &'a Arena<crate::Function>,
    /// Arguments of the function being resolved.
    /// `Expression::FunctionArgument(i)` resolves to `arguments[i].ty`.
    pub arguments: &'a [crate::FunctionArgument],
}
196
197impl<'a> ResolveContext<'a> {
198    /// Initialize a resolve context from the module.
199    pub const fn with_locals(
200        module: &'a crate::Module,
201        local_vars: &'a Arena<crate::LocalVariable>,
202        arguments: &'a [crate::FunctionArgument],
203    ) -> Self {
204        Self {
205            constants: &module.constants,
206            overrides: &module.overrides,
207            types: &module.types,
208            special_types: &module.special_types,
209            global_vars: &module.global_variables,
210            local_vars,
211            functions: &module.functions,
212            arguments,
213        }
214    }
215
216    /// Determine the type of `expr`.
217    ///
218    /// The `past` argument must be a closure that can resolve the types of any
219    /// expressions that `expr` refers to. These can be gathered by caching the
220    /// results of prior calls to `resolve`, perhaps as done by the
221    /// [`front::Typifier`] utility type.
222    ///
223    /// Type resolution is a read-only process: this method takes `self` by
224    /// shared reference. However, this means that we cannot add anything to
225    /// `self.types` that we might need to describe `expr`. To work around this,
226    /// this method returns a [`TypeResolution`], rather than simply returning a
227    /// `Handle<Type>`; see the documentation for [`TypeResolution`] for
228    /// details.
229    ///
230    /// [`front::Typifier`]: crate::front::Typifier
231    pub fn resolve(
232        &self,
233        expr: &crate::Expression,
234        past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
235    ) -> Result<TypeResolution, ResolveError> {
236        use crate::TypeInner as Ti;
237        let types = self.types;
238        Ok(match *expr {
239            crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
240                // Arrays and matrices can only be indexed dynamically behind a
241                // pointer, but that's a validation error, not a type error, so
242                // go ahead provide a type here.
243                Ti::Array { base, .. } => TypeResolution::Handle(base),
244                Ti::Matrix { rows, scalar, .. } => {
245                    TypeResolution::Value(Ti::Vector { size: rows, scalar })
246                }
247                Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
248                Ti::ValuePointer {
249                    size: Some(_),
250                    scalar,
251                    space,
252                } => TypeResolution::Value(Ti::ValuePointer {
253                    size: None,
254                    scalar,
255                    space,
256                }),
257                Ti::Pointer { base, space } => {
258                    TypeResolution::Value(match types[base].inner {
259                        Ti::Array { base, .. } => Ti::Pointer { base, space },
260                        Ti::Vector { size: _, scalar } => Ti::ValuePointer {
261                            size: None,
262                            scalar,
263                            space,
264                        },
265                        // Matrices are only dynamically indexed behind a pointer
266                        Ti::Matrix {
267                            columns: _,
268                            rows,
269                            scalar,
270                        } => Ti::ValuePointer {
271                            size: Some(rows),
272                            scalar,
273                            space,
274                        },
275                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
276                        ref other => {
277                            log::error!("Access sub-type {:?}", other);
278                            return Err(ResolveError::InvalidSubAccess {
279                                ty: base,
280                                indexed: false,
281                            });
282                        }
283                    })
284                }
285                Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
286                ref other => {
287                    log::error!("Access type {:?}", other);
288                    return Err(ResolveError::InvalidAccess {
289                        expr: base,
290                        indexed: false,
291                    });
292                }
293            },
294            crate::Expression::AccessIndex { base, index } => {
295                match *past(base)?.inner_with(types) {
296                    Ti::Vector { size, scalar } => {
297                        if index >= size as u32 {
298                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
299                        }
300                        TypeResolution::Value(Ti::Scalar(scalar))
301                    }
302                    Ti::Matrix {
303                        columns,
304                        rows,
305                        scalar,
306                    } => {
307                        if index >= columns as u32 {
308                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
309                        }
310                        TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
311                    }
312                    Ti::Array { base, .. } => TypeResolution::Handle(base),
313                    Ti::Struct { ref members, .. } => {
314                        let member = members
315                            .get(index as usize)
316                            .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
317                        TypeResolution::Handle(member.ty)
318                    }
319                    Ti::ValuePointer {
320                        size: Some(size),
321                        scalar,
322                        space,
323                    } => {
324                        if index >= size as u32 {
325                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
326                        }
327                        TypeResolution::Value(Ti::ValuePointer {
328                            size: None,
329                            scalar,
330                            space,
331                        })
332                    }
333                    Ti::Pointer {
334                        base: ty_base,
335                        space,
336                    } => TypeResolution::Value(match types[ty_base].inner {
337                        Ti::Array { base, .. } => Ti::Pointer { base, space },
338                        Ti::Vector { size, scalar } => {
339                            if index >= size as u32 {
340                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
341                            }
342                            Ti::ValuePointer {
343                                size: None,
344                                scalar,
345                                space,
346                            }
347                        }
348                        Ti::Matrix {
349                            rows,
350                            columns,
351                            scalar,
352                        } => {
353                            if index >= columns as u32 {
354                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
355                            }
356                            Ti::ValuePointer {
357                                size: Some(rows),
358                                scalar,
359                                space,
360                            }
361                        }
362                        Ti::Struct { ref members, .. } => {
363                            let member = members
364                                .get(index as usize)
365                                .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
366                            Ti::Pointer {
367                                base: member.ty,
368                                space,
369                            }
370                        }
371                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
372                        ref other => {
373                            log::error!("Access index sub-type {:?}", other);
374                            return Err(ResolveError::InvalidSubAccess {
375                                ty: ty_base,
376                                indexed: true,
377                            });
378                        }
379                    }),
380                    Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
381                    ref other => {
382                        log::error!("Access index type {:?}", other);
383                        return Err(ResolveError::InvalidAccess {
384                            expr: base,
385                            indexed: true,
386                        });
387                    }
388                }
389            }
390            crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
391                Ti::Scalar(scalar) => TypeResolution::Value(Ti::Vector { size, scalar }),
392                ref other => {
393                    log::error!("Scalar type {:?}", other);
394                    return Err(ResolveError::InvalidScalar(value));
395                }
396            },
397            crate::Expression::Swizzle {
398                size,
399                vector,
400                pattern: _,
401            } => match *past(vector)?.inner_with(types) {
402                Ti::Vector { size: _, scalar } => {
403                    TypeResolution::Value(Ti::Vector { size, scalar })
404                }
405                ref other => {
406                    log::error!("Vector type {:?}", other);
407                    return Err(ResolveError::InvalidVector(vector));
408                }
409            },
410            crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
411            crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
412            crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty),
413            crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
414            crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
415            crate::Expression::FunctionArgument(index) => {
416                let arg = self
417                    .arguments
418                    .get(index as usize)
419                    .ok_or(ResolveError::FunctionArgumentNotFound(index))?;
420                TypeResolution::Handle(arg.ty)
421            }
422            crate::Expression::GlobalVariable(h) => {
423                let var = &self.global_vars[h];
424                if var.space == crate::AddressSpace::Handle {
425                    TypeResolution::Handle(var.ty)
426                } else {
427                    TypeResolution::Value(Ti::Pointer {
428                        base: var.ty,
429                        space: var.space,
430                    })
431                }
432            }
433            crate::Expression::LocalVariable(h) => {
434                let var = &self.local_vars[h];
435                TypeResolution::Value(Ti::Pointer {
436                    base: var.ty,
437                    space: crate::AddressSpace::Function,
438                })
439            }
440            crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
441                Ti::Pointer { base, space: _ } => {
442                    if let Ti::Atomic(scalar) = types[base].inner {
443                        TypeResolution::Value(Ti::Scalar(scalar))
444                    } else {
445                        TypeResolution::Handle(base)
446                    }
447                }
448                Ti::ValuePointer {
449                    size,
450                    scalar,
451                    space: _,
452                } => TypeResolution::Value(match size {
453                    Some(size) => Ti::Vector { size, scalar },
454                    None => Ti::Scalar(scalar),
455                }),
456                ref other => {
457                    log::error!("Pointer type {:?}", other);
458                    return Err(ResolveError::InvalidPointer(pointer));
459                }
460            },
461            crate::Expression::ImageSample {
462                image,
463                gather: Some(_),
464                ..
465            } => match *past(image)?.inner_with(types) {
466                Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
467                    scalar: crate::Scalar {
468                        kind: match class {
469                            crate::ImageClass::Sampled { kind, multi: _ } => kind,
470                            _ => crate::ScalarKind::Float,
471                        },
472                        width: 4,
473                    },
474                    size: crate::VectorSize::Quad,
475                }),
476                ref other => {
477                    log::error!("Image type {:?}", other);
478                    return Err(ResolveError::InvalidImage(image));
479                }
480            },
481            crate::Expression::ImageSample { image, .. }
482            | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
483                Ti::Image { class, .. } => TypeResolution::Value(match class {
484                    crate::ImageClass::Depth { multi: _ } => Ti::Scalar(crate::Scalar::F32),
485                    crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
486                        scalar: crate::Scalar { kind, width: 4 },
487                        size: crate::VectorSize::Quad,
488                    },
489                    crate::ImageClass::Storage { format, .. } => Ti::Vector {
490                        scalar: crate::Scalar {
491                            kind: format.into(),
492                            width: 4,
493                        },
494                        size: crate::VectorSize::Quad,
495                    },
496                }),
497                ref other => {
498                    log::error!("Image type {:?}", other);
499                    return Err(ResolveError::InvalidImage(image));
500                }
501            },
502            crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
503                crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
504                    Ti::Image { dim, .. } => match dim {
505                        crate::ImageDimension::D1 => Ti::Scalar(crate::Scalar::U32),
506                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
507                            size: crate::VectorSize::Bi,
508                            scalar: crate::Scalar::U32,
509                        },
510                        crate::ImageDimension::D3 => Ti::Vector {
511                            size: crate::VectorSize::Tri,
512                            scalar: crate::Scalar::U32,
513                        },
514                    },
515                    ref other => {
516                        log::error!("Image type {:?}", other);
517                        return Err(ResolveError::InvalidImage(image));
518                    }
519                },
520                crate::ImageQuery::NumLevels
521                | crate::ImageQuery::NumLayers
522                | crate::ImageQuery::NumSamples => Ti::Scalar(crate::Scalar::U32),
523            }),
524            crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
525            crate::Expression::Binary { op, left, right } => match op {
526                crate::BinaryOperator::Add
527                | crate::BinaryOperator::Subtract
528                | crate::BinaryOperator::Divide
529                | crate::BinaryOperator::Modulo => past(left)?.clone(),
530                crate::BinaryOperator::Multiply => {
531                    let (res_left, res_right) = (past(left)?, past(right)?);
532                    match (res_left.inner_with(types), res_right.inner_with(types)) {
533                        (
534                            &Ti::Matrix {
535                                columns: _,
536                                rows,
537                                scalar,
538                            },
539                            &Ti::Matrix { columns, .. },
540                        ) => TypeResolution::Value(Ti::Matrix {
541                            columns,
542                            rows,
543                            scalar,
544                        }),
545                        (
546                            &Ti::Matrix {
547                                columns: _,
548                                rows,
549                                scalar,
550                            },
551                            &Ti::Vector { .. },
552                        ) => TypeResolution::Value(Ti::Vector { size: rows, scalar }),
553                        (
554                            &Ti::Vector { .. },
555                            &Ti::Matrix {
556                                columns,
557                                rows: _,
558                                scalar,
559                            },
560                        ) => TypeResolution::Value(Ti::Vector {
561                            size: columns,
562                            scalar,
563                        }),
564                        (&Ti::Scalar { .. }, _) => res_right.clone(),
565                        (_, &Ti::Scalar { .. }) => res_left.clone(),
566                        (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
567                        (tl, tr) => {
568                            return Err(ResolveError::IncompatibleOperands(format!(
569                                "{tl:?} * {tr:?}"
570                            )))
571                        }
572                    }
573                }
574                crate::BinaryOperator::Equal
575                | crate::BinaryOperator::NotEqual
576                | crate::BinaryOperator::Less
577                | crate::BinaryOperator::LessEqual
578                | crate::BinaryOperator::Greater
579                | crate::BinaryOperator::GreaterEqual
580                | crate::BinaryOperator::LogicalAnd
581                | crate::BinaryOperator::LogicalOr => {
582                    let scalar = crate::Scalar::BOOL;
583                    let inner = match *past(left)?.inner_with(types) {
584                        Ti::Scalar { .. } => Ti::Scalar(scalar),
585                        Ti::Vector { size, .. } => Ti::Vector { size, scalar },
586                        ref other => {
587                            return Err(ResolveError::IncompatibleOperands(format!(
588                                "{op:?}({other:?}, _)"
589                            )))
590                        }
591                    };
592                    TypeResolution::Value(inner)
593                }
594                crate::BinaryOperator::And
595                | crate::BinaryOperator::ExclusiveOr
596                | crate::BinaryOperator::InclusiveOr
597                | crate::BinaryOperator::ShiftLeft
598                | crate::BinaryOperator::ShiftRight => past(left)?.clone(),
599            },
600            crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
601            crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
602            crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
603            crate::Expression::Select { accept, .. } => past(accept)?.clone(),
604            crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
605            crate::Expression::Relational { fun, argument } => match fun {
606                crate::RelationalFunction::All | crate::RelationalFunction::Any => {
607                    TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
608                }
609                crate::RelationalFunction::IsNan | crate::RelationalFunction::IsInf => {
610                    match *past(argument)?.inner_with(types) {
611                        Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)),
612                        Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
613                            scalar: crate::Scalar::BOOL,
614                            size,
615                        }),
616                        ref other => {
617                            return Err(ResolveError::IncompatibleOperands(format!(
618                                "{fun:?}({other:?})"
619                            )))
620                        }
621                    }
622                }
623            },
624            crate::Expression::Math {
625                fun,
626                arg,
627                arg1,
628                arg2: _,
629                arg3: _,
630            } => {
631                use crate::MathFunction as Mf;
632                let res_arg = past(arg)?;
633                match fun {
634                    // comparison
635                    Mf::Abs |
636                    Mf::Min |
637                    Mf::Max |
638                    Mf::Clamp |
639                    Mf::Saturate |
640                    // trigonometry
641                    Mf::Cos |
642                    Mf::Cosh |
643                    Mf::Sin |
644                    Mf::Sinh |
645                    Mf::Tan |
646                    Mf::Tanh |
647                    Mf::Acos |
648                    Mf::Asin |
649                    Mf::Atan |
650                    Mf::Atan2 |
651                    Mf::Asinh |
652                    Mf::Acosh |
653                    Mf::Atanh |
654                    Mf::Radians |
655                    Mf::Degrees |
656                    // decomposition
657                    Mf::Ceil |
658                    Mf::Floor |
659                    Mf::Round |
660                    Mf::Fract |
661                    Mf::Trunc |
662                    Mf::Ldexp |
663                    // exponent
664                    Mf::Exp |
665                    Mf::Exp2 |
666                    Mf::Log |
667                    Mf::Log2 |
668                    Mf::Pow => res_arg.clone(),
669                    Mf::Modf | Mf::Frexp => {
670                        let (size, width) = match res_arg.inner_with(types) {
671                            &Ti::Scalar(crate::Scalar {
672                                kind: crate::ScalarKind::Float,
673                                width,
674                            }) => (None, width),
675                            &Ti::Vector {
676                                scalar: crate::Scalar {
677                                    kind: crate::ScalarKind::Float,
678                                    width,
679                                },
680                                size,
681                            } => (Some(size), width),
682                            ref other =>
683                                return Err(ResolveError::IncompatibleOperands(format!("{fun:?}({other:?}, _)")))
684                        };
685                        let result = self
686                        .special_types
687                        .predeclared_types
688                        .get(&if fun == Mf::Modf {
689                            crate::PredeclaredType::ModfResult { size, width }
690                    } else {
691                            crate::PredeclaredType::FrexpResult { size, width }
692                    })
693                        .ok_or(ResolveError::MissingSpecialType)?;
694                        TypeResolution::Handle(*result)
695                    },
696                    // geometry
697                    Mf::Dot => match *res_arg.inner_with(types) {
698                        Ti::Vector {
699                            size: _,
700                            scalar,
701                        } => TypeResolution::Value(Ti::Scalar(scalar)),
702                        ref other =>
703                            return Err(ResolveError::IncompatibleOperands(
704                                format!("{fun:?}({other:?}, _)")
705                            )),
706                    },
707                    Mf::Outer => {
708                        let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands(
709                            format!("{fun:?}(_, None)")
710                        ))?;
711                        match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
712                            (
713                                &Ti::Vector { size: columns, scalar },
714                                &Ti::Vector{ size: rows, .. }
715                            ) => TypeResolution::Value(Ti::Matrix {
716                                columns,
717                                rows,
718                                scalar,
719                            }),
720                            (left, right) =>
721                                return Err(ResolveError::IncompatibleOperands(
722                                    format!("{fun:?}({left:?}, {right:?})")
723                                )),
724                        }
725                    },
726                    Mf::Cross => res_arg.clone(),
727                    Mf::Distance |
728                    Mf::Length => match *res_arg.inner_with(types) {
729                        Ti::Scalar(scalar) |
730                        Ti::Vector {scalar,size:_} => TypeResolution::Value(Ti::Scalar(scalar)),
731                        ref other => return Err(ResolveError::IncompatibleOperands(
732                                format!("{fun:?}({other:?})")
733                            )),
734                    },
735                    Mf::Normalize |
736                    Mf::FaceForward |
737                    Mf::Reflect |
738                    Mf::Refract => res_arg.clone(),
739                    // computational
740                    Mf::Sign |
741                    Mf::Fma |
742                    Mf::Mix |
743                    Mf::Step |
744                    Mf::SmoothStep |
745                    Mf::Sqrt |
746                    Mf::InverseSqrt => res_arg.clone(),
747                    Mf::Transpose => match *res_arg.inner_with(types) {
748                        Ti::Matrix {
749                            columns,
750                            rows,
751                            scalar,
752                        } => TypeResolution::Value(Ti::Matrix {
753                            columns: rows,
754                            rows: columns,
755                            scalar,
756                        }),
757                        ref other => return Err(ResolveError::IncompatibleOperands(
758                            format!("{fun:?}({other:?})")
759                        )),
760                    },
761                    Mf::Inverse => match *res_arg.inner_with(types) {
762                        Ti::Matrix {
763                            columns,
764                            rows,
765                            scalar,
766                        } if columns == rows => TypeResolution::Value(Ti::Matrix {
767                            columns,
768                            rows,
769                            scalar,
770                        }),
771                        ref other => return Err(ResolveError::IncompatibleOperands(
772                            format!("{fun:?}({other:?})")
773                        )),
774                    },
775                    Mf::Determinant => match *res_arg.inner_with(types) {
776                        Ti::Matrix {
777                            scalar,
778                            ..
779                        } => TypeResolution::Value(Ti::Scalar(scalar)),
780                        ref other => return Err(ResolveError::IncompatibleOperands(
781                            format!("{fun:?}({other:?})")
782                        )),
783                    },
784                    // bits
785                    Mf::CountTrailingZeros |
786                    Mf::CountLeadingZeros |
787                    Mf::CountOneBits |
788                    Mf::ReverseBits |
789                    Mf::ExtractBits |
790                    Mf::InsertBits |
791                    Mf::FindLsb |
792                    Mf::FindMsb => match *res_arg.inner_with(types)  {
793                        Ti::Scalar(scalar @ crate::Scalar {
794                            kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
795                            ..
796                        }) => TypeResolution::Value(Ti::Scalar(scalar)),
797                        Ti::Vector {
798                            size,
799                            scalar: scalar @ crate::Scalar {
800                                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
801                                ..
802                            }
803                        } => TypeResolution::Value(Ti::Vector { size, scalar }),
804                        ref other => return Err(ResolveError::IncompatibleOperands(
805                                format!("{fun:?}({other:?})")
806                            )),
807                    },
808                    // data packing
809                    Mf::Pack4x8snorm |
810                    Mf::Pack4x8unorm |
811                    Mf::Pack2x16snorm |
812                    Mf::Pack2x16unorm |
813                    Mf::Pack2x16float => TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)),
814                    // data unpacking
815                    Mf::Unpack4x8snorm |
816                    Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector {
817                        size: crate::VectorSize::Quad,
818                        scalar: crate::Scalar::F32
819                    }),
820                    Mf::Unpack2x16snorm |
821                    Mf::Unpack2x16unorm |
822                    Mf::Unpack2x16float => TypeResolution::Value(Ti::Vector {
823                        size: crate::VectorSize::Bi,
824                        scalar: crate::Scalar::F32
825                    }),
826                }
827            }
828            crate::Expression::As {
829                expr,
830                kind,
831                convert,
832            } => match *past(expr)?.inner_with(types) {
833                Ti::Scalar(crate::Scalar { width, .. }) => {
834                    TypeResolution::Value(Ti::Scalar(crate::Scalar {
835                        kind,
836                        width: convert.unwrap_or(width),
837                    }))
838                }
839                Ti::Vector {
840                    size,
841                    scalar: crate::Scalar { kind: _, width },
842                } => TypeResolution::Value(Ti::Vector {
843                    size,
844                    scalar: crate::Scalar {
845                        kind,
846                        width: convert.unwrap_or(width),
847                    },
848                }),
849                Ti::Matrix {
850                    columns,
851                    rows,
852                    mut scalar,
853                } => {
854                    if let Some(width) = convert {
855                        scalar.width = width;
856                    }
857                    TypeResolution::Value(Ti::Matrix {
858                        columns,
859                        rows,
860                        scalar,
861                    })
862                }
863                ref other => {
864                    return Err(ResolveError::IncompatibleOperands(format!(
865                        "{other:?} as {kind:?}"
866                    )))
867                }
868            },
869            crate::Expression::CallResult(function) => {
870                let result = self.functions[function]
871                    .result
872                    .as_ref()
873                    .ok_or(ResolveError::FunctionReturnsVoid)?;
874                TypeResolution::Handle(result.ty)
875            }
876            crate::Expression::ArrayLength(_) => {
877                TypeResolution::Value(Ti::Scalar(crate::Scalar::U32))
878            }
879            crate::Expression::RayQueryProceedResult => {
880                TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
881            }
882            crate::Expression::RayQueryGetIntersection { .. } => {
883                let result = self
884                    .special_types
885                    .ray_intersection
886                    .ok_or(ResolveError::MissingSpecialType)?;
887                TypeResolution::Handle(result)
888            }
889            crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
890                scalar: crate::Scalar::U32,
891                size: crate::VectorSize::Quad,
892            }),
893        })
894    }
895}
896
/// Pins the in-memory size of [`ResolveError`].
///
/// `ResolveError` is the error type produced by type resolution. This test
/// fails if its size changes, so any growth (for example, a variant gaining a
/// larger payload) gets noticed and reviewed on purpose.
/// NOTE(review): the motivation is presumably to keep `Result`s carrying this
/// error small; confirm against the project's history before changing the
/// expected value.
#[test]
fn test_error_size() {
    let actual_size = std::mem::size_of::<ResolveError>();
    assert_eq!(actual_size, 32);
}