naga/proc/
mod.rs

1/*!
2[`Module`](super::Module) processing functionality.
3*/
4
5mod constant_evaluator;
6mod emitter;
7pub mod index;
8mod layouter;
9mod namer;
10mod terminator;
11mod typifier;
12
13pub use constant_evaluator::{
14    ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
15};
16pub use emitter::Emitter;
17pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
18pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
19pub use namer::{EntryPointIndex, NameKey, Namer};
20pub use terminator::ensure_block_returns;
21pub use typifier::{ResolveContext, ResolveError, TypeResolution};
22
23impl From<super::StorageFormat> for super::ScalarKind {
24    fn from(format: super::StorageFormat) -> Self {
25        use super::{ScalarKind as Sk, StorageFormat as Sf};
26        match format {
27            Sf::R8Unorm => Sk::Float,
28            Sf::R8Snorm => Sk::Float,
29            Sf::R8Uint => Sk::Uint,
30            Sf::R8Sint => Sk::Sint,
31            Sf::R16Uint => Sk::Uint,
32            Sf::R16Sint => Sk::Sint,
33            Sf::R16Float => Sk::Float,
34            Sf::Rg8Unorm => Sk::Float,
35            Sf::Rg8Snorm => Sk::Float,
36            Sf::Rg8Uint => Sk::Uint,
37            Sf::Rg8Sint => Sk::Sint,
38            Sf::R32Uint => Sk::Uint,
39            Sf::R32Sint => Sk::Sint,
40            Sf::R32Float => Sk::Float,
41            Sf::Rg16Uint => Sk::Uint,
42            Sf::Rg16Sint => Sk::Sint,
43            Sf::Rg16Float => Sk::Float,
44            Sf::Rgba8Unorm => Sk::Float,
45            Sf::Rgba8Snorm => Sk::Float,
46            Sf::Rgba8Uint => Sk::Uint,
47            Sf::Rgba8Sint => Sk::Sint,
48            Sf::Bgra8Unorm => Sk::Float,
49            Sf::Rgb10a2Uint => Sk::Uint,
50            Sf::Rgb10a2Unorm => Sk::Float,
51            Sf::Rg11b10Float => Sk::Float,
52            Sf::Rg32Uint => Sk::Uint,
53            Sf::Rg32Sint => Sk::Sint,
54            Sf::Rg32Float => Sk::Float,
55            Sf::Rgba16Uint => Sk::Uint,
56            Sf::Rgba16Sint => Sk::Sint,
57            Sf::Rgba16Float => Sk::Float,
58            Sf::Rgba32Uint => Sk::Uint,
59            Sf::Rgba32Sint => Sk::Sint,
60            Sf::Rgba32Float => Sk::Float,
61            Sf::R16Unorm => Sk::Float,
62            Sf::R16Snorm => Sk::Float,
63            Sf::Rg16Unorm => Sk::Float,
64            Sf::Rg16Snorm => Sk::Float,
65            Sf::Rgba16Unorm => Sk::Float,
66            Sf::Rgba16Snorm => Sk::Float,
67        }
68    }
69}
70
71impl super::ScalarKind {
72    pub const fn is_numeric(self) -> bool {
73        match self {
74            crate::ScalarKind::Sint
75            | crate::ScalarKind::Uint
76            | crate::ScalarKind::Float
77            | crate::ScalarKind::AbstractInt
78            | crate::ScalarKind::AbstractFloat => true,
79            crate::ScalarKind::Bool => false,
80        }
81    }
82}
83
84impl super::Scalar {
85    pub const I32: Self = Self {
86        kind: crate::ScalarKind::Sint,
87        width: 4,
88    };
89    pub const U32: Self = Self {
90        kind: crate::ScalarKind::Uint,
91        width: 4,
92    };
93    pub const F32: Self = Self {
94        kind: crate::ScalarKind::Float,
95        width: 4,
96    };
97    pub const F64: Self = Self {
98        kind: crate::ScalarKind::Float,
99        width: 8,
100    };
101    pub const I64: Self = Self {
102        kind: crate::ScalarKind::Sint,
103        width: 8,
104    };
105    pub const U64: Self = Self {
106        kind: crate::ScalarKind::Uint,
107        width: 8,
108    };
109    pub const BOOL: Self = Self {
110        kind: crate::ScalarKind::Bool,
111        width: crate::BOOL_WIDTH,
112    };
113    pub const ABSTRACT_INT: Self = Self {
114        kind: crate::ScalarKind::AbstractInt,
115        width: crate::ABSTRACT_WIDTH,
116    };
117    pub const ABSTRACT_FLOAT: Self = Self {
118        kind: crate::ScalarKind::AbstractFloat,
119        width: crate::ABSTRACT_WIDTH,
120    };
121
122    pub const fn is_abstract(self) -> bool {
123        match self.kind {
124            crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => true,
125            crate::ScalarKind::Sint
126            | crate::ScalarKind::Uint
127            | crate::ScalarKind::Float
128            | crate::ScalarKind::Bool => false,
129        }
130    }
131
132    /// Construct a float `Scalar` with the given width.
133    ///
134    /// This is especially common when dealing with
135    /// `TypeInner::Matrix`, where the scalar kind is implicit.
136    pub const fn float(width: crate::Bytes) -> Self {
137        Self {
138            kind: crate::ScalarKind::Float,
139            width,
140        }
141    }
142
143    pub const fn to_inner_scalar(self) -> crate::TypeInner {
144        crate::TypeInner::Scalar(self)
145    }
146
147    pub const fn to_inner_vector(self, size: crate::VectorSize) -> crate::TypeInner {
148        crate::TypeInner::Vector { size, scalar: self }
149    }
150
151    pub const fn to_inner_atomic(self) -> crate::TypeInner {
152        crate::TypeInner::Atomic(self)
153    }
154}
155
156#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
157pub enum HashableLiteral {
158    F64(u64),
159    F32(u32),
160    U32(u32),
161    I32(i32),
162    U64(u64),
163    I64(i64),
164    Bool(bool),
165    AbstractInt(i64),
166    AbstractFloat(u64),
167}
168
169impl From<crate::Literal> for HashableLiteral {
170    fn from(l: crate::Literal) -> Self {
171        match l {
172            crate::Literal::F64(v) => Self::F64(v.to_bits()),
173            crate::Literal::F32(v) => Self::F32(v.to_bits()),
174            crate::Literal::U32(v) => Self::U32(v),
175            crate::Literal::I32(v) => Self::I32(v),
176            crate::Literal::U64(v) => Self::U64(v),
177            crate::Literal::I64(v) => Self::I64(v),
178            crate::Literal::Bool(v) => Self::Bool(v),
179            crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
180            crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
181        }
182    }
183}
184
185impl crate::Literal {
186    pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
187        match (value, scalar.kind, scalar.width) {
188            (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
189            (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
190            (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
191            (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
192            (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
193            (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
194            (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
195            (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
196            _ => None,
197        }
198    }
199
200    pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
201        Self::new(0, scalar)
202    }
203
204    pub const fn one(scalar: crate::Scalar) -> Option<Self> {
205        Self::new(1, scalar)
206    }
207
208    pub const fn width(&self) -> crate::Bytes {
209        match *self {
210            Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
211            Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
212            Self::Bool(_) => crate::BOOL_WIDTH,
213            Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
214        }
215    }
216    pub const fn scalar(&self) -> crate::Scalar {
217        match *self {
218            Self::F64(_) => crate::Scalar::F64,
219            Self::F32(_) => crate::Scalar::F32,
220            Self::U32(_) => crate::Scalar::U32,
221            Self::I32(_) => crate::Scalar::I32,
222            Self::U64(_) => crate::Scalar::U64,
223            Self::I64(_) => crate::Scalar::I64,
224            Self::Bool(_) => crate::Scalar::BOOL,
225            Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
226            Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
227        }
228    }
229    pub const fn scalar_kind(&self) -> crate::ScalarKind {
230        self.scalar().kind
231    }
232    pub const fn ty_inner(&self) -> crate::TypeInner {
233        crate::TypeInner::Scalar(self.scalar())
234    }
235}
236
237pub const POINTER_SPAN: u32 = 4;
238
239impl super::TypeInner {
240    /// Return the scalar type of `self`.
241    ///
242    /// If `inner` is a scalar, vector, or matrix type, return
243    /// its scalar type. Otherwise, return `None`.
244    pub const fn scalar(&self) -> Option<super::Scalar> {
245        use crate::TypeInner as Ti;
246        match *self {
247            Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar),
248            Ti::Matrix { scalar, .. } => Some(scalar),
249            _ => None,
250        }
251    }
252
253    pub fn scalar_kind(&self) -> Option<super::ScalarKind> {
254        self.scalar().map(|scalar| scalar.kind)
255    }
256
257    /// Returns the scalar width in bytes
258    pub fn scalar_width(&self) -> Option<u8> {
259        self.scalar().map(|scalar| scalar.width)
260    }
261
262    pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
263        match *self {
264            Self::Pointer { space, .. } => Some(space),
265            Self::ValuePointer { space, .. } => Some(space),
266            _ => None,
267        }
268    }
269
270    pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
271        match *self {
272            crate::TypeInner::Pointer { base, .. } => match types[base].inner {
273                crate::TypeInner::Atomic { .. } => true,
274                _ => false,
275            },
276            _ => false,
277        }
278    }
279
280    /// Get the size of this type.
281    pub fn size(&self, _gctx: GlobalCtx) -> u32 {
282        match *self {
283            Self::Scalar(scalar) | Self::Atomic(scalar) => scalar.width as u32,
284            Self::Vector { size, scalar } => size as u32 * scalar.width as u32,
285            // matrices are treated as arrays of aligned columns
286            Self::Matrix {
287                columns,
288                rows,
289                scalar,
290            } => Alignment::from(rows) * scalar.width as u32 * columns as u32,
291            Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN,
292            Self::Array {
293                base: _,
294                size,
295                stride,
296            } => {
297                let count = match size {
298                    super::ArraySize::Constant(count) => count.get(),
299                    // A dynamically-sized array has to have at least one element
300                    super::ArraySize::Dynamic => 1,
301                };
302                count * stride
303            }
304            Self::Struct { span, .. } => span,
305            Self::Image { .. }
306            | Self::Sampler { .. }
307            | Self::AccelerationStructure
308            | Self::RayQuery
309            | Self::BindingArray { .. } => 0,
310        }
311    }
312
313    /// Return the canonical form of `self`, or `None` if it's already in
314    /// canonical form.
315    ///
316    /// Certain types have multiple representations in `TypeInner`. This
317    /// function converts all forms of equivalent types to a single
318    /// representative of their class, so that simply applying `Eq` to the
319    /// result indicates whether the types are equivalent, as far as Naga IR is
320    /// concerned.
321    pub fn canonical_form(
322        &self,
323        types: &crate::UniqueArena<crate::Type>,
324    ) -> Option<crate::TypeInner> {
325        use crate::TypeInner as Ti;
326        match *self {
327            Ti::Pointer { base, space } => match types[base].inner {
328                Ti::Scalar(scalar) => Some(Ti::ValuePointer {
329                    size: None,
330                    scalar,
331                    space,
332                }),
333                Ti::Vector { size, scalar } => Some(Ti::ValuePointer {
334                    size: Some(size),
335                    scalar,
336                    space,
337                }),
338                _ => None,
339            },
340            _ => None,
341        }
342    }
343
344    /// Compare `self` and `rhs` as types.
345    ///
346    /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
347    /// `ValuePointer` and `Pointer` types as equivalent.
348    ///
349    /// When you know that one side of the comparison is never a pointer, it's
350    /// fine to not bother with canonicalization, and just compare `TypeInner`
351    /// values with `==`.
352    pub fn equivalent(
353        &self,
354        rhs: &crate::TypeInner,
355        types: &crate::UniqueArena<crate::Type>,
356    ) -> bool {
357        let left = self.canonical_form(types);
358        let right = rhs.canonical_form(types);
359        left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
360    }
361
362    pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
363        use crate::TypeInner as Ti;
364        match *self {
365            Ti::Array { size, .. } => size == crate::ArraySize::Dynamic,
366            Ti::Struct { ref members, .. } => members
367                .last()
368                .map(|last| types[last.ty].inner.is_dynamically_sized(types))
369                .unwrap_or(false),
370            _ => false,
371        }
372    }
373
374    pub fn components(&self) -> Option<u32> {
375        Some(match *self {
376            Self::Vector { size, .. } => size as u32,
377            Self::Matrix { columns, .. } => columns as u32,
378            Self::Array {
379                size: crate::ArraySize::Constant(len),
380                ..
381            } => len.get(),
382            Self::Struct { ref members, .. } => members.len() as u32,
383            _ => return None,
384        })
385    }
386
387    pub fn component_type(&self, index: usize) -> Option<TypeResolution> {
388        Some(match *self {
389            Self::Vector { scalar, .. } => TypeResolution::Value(crate::TypeInner::Scalar(scalar)),
390            Self::Matrix { rows, scalar, .. } => {
391                TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
392            }
393            Self::Array {
394                base,
395                size: crate::ArraySize::Constant(_),
396                ..
397            } => TypeResolution::Handle(base),
398            Self::Struct { ref members, .. } => TypeResolution::Handle(members[index].ty),
399            _ => return None,
400        })
401    }
402}
403
404impl super::AddressSpace {
405    pub fn access(self) -> crate::StorageAccess {
406        use crate::StorageAccess as Sa;
407        match self {
408            crate::AddressSpace::Function
409            | crate::AddressSpace::Private
410            | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
411            crate::AddressSpace::Uniform => Sa::LOAD,
412            crate::AddressSpace::Storage { access } => access,
413            crate::AddressSpace::Handle => Sa::LOAD,
414            crate::AddressSpace::PushConstant => Sa::LOAD,
415        }
416    }
417}
418
419impl super::MathFunction {
420    pub const fn argument_count(&self) -> usize {
421        match *self {
422            // comparison
423            Self::Abs => 1,
424            Self::Min => 2,
425            Self::Max => 2,
426            Self::Clamp => 3,
427            Self::Saturate => 1,
428            // trigonometry
429            Self::Cos => 1,
430            Self::Cosh => 1,
431            Self::Sin => 1,
432            Self::Sinh => 1,
433            Self::Tan => 1,
434            Self::Tanh => 1,
435            Self::Acos => 1,
436            Self::Asin => 1,
437            Self::Atan => 1,
438            Self::Atan2 => 2,
439            Self::Asinh => 1,
440            Self::Acosh => 1,
441            Self::Atanh => 1,
442            Self::Radians => 1,
443            Self::Degrees => 1,
444            // decomposition
445            Self::Ceil => 1,
446            Self::Floor => 1,
447            Self::Round => 1,
448            Self::Fract => 1,
449            Self::Trunc => 1,
450            Self::Modf => 1,
451            Self::Frexp => 1,
452            Self::Ldexp => 2,
453            // exponent
454            Self::Exp => 1,
455            Self::Exp2 => 1,
456            Self::Log => 1,
457            Self::Log2 => 1,
458            Self::Pow => 2,
459            // geometry
460            Self::Dot => 2,
461            Self::Outer => 2,
462            Self::Cross => 2,
463            Self::Distance => 2,
464            Self::Length => 1,
465            Self::Normalize => 1,
466            Self::FaceForward => 3,
467            Self::Reflect => 2,
468            Self::Refract => 3,
469            // computational
470            Self::Sign => 1,
471            Self::Fma => 3,
472            Self::Mix => 3,
473            Self::Step => 2,
474            Self::SmoothStep => 3,
475            Self::Sqrt => 1,
476            Self::InverseSqrt => 1,
477            Self::Inverse => 1,
478            Self::Transpose => 1,
479            Self::Determinant => 1,
480            // bits
481            Self::CountTrailingZeros => 1,
482            Self::CountLeadingZeros => 1,
483            Self::CountOneBits => 1,
484            Self::ReverseBits => 1,
485            Self::ExtractBits => 3,
486            Self::InsertBits => 4,
487            Self::FindLsb => 1,
488            Self::FindMsb => 1,
489            // data packing
490            Self::Pack4x8snorm => 1,
491            Self::Pack4x8unorm => 1,
492            Self::Pack2x16snorm => 1,
493            Self::Pack2x16unorm => 1,
494            Self::Pack2x16float => 1,
495            // data unpacking
496            Self::Unpack4x8snorm => 1,
497            Self::Unpack4x8unorm => 1,
498            Self::Unpack2x16snorm => 1,
499            Self::Unpack2x16unorm => 1,
500            Self::Unpack2x16float => 1,
501        }
502    }
503}
504
505impl crate::Expression {
506    /// Returns true if the expression is considered emitted at the start of a function.
507    pub const fn needs_pre_emit(&self) -> bool {
508        match *self {
509            Self::Literal(_)
510            | Self::Constant(_)
511            | Self::Override(_)
512            | Self::ZeroValue(_)
513            | Self::FunctionArgument(_)
514            | Self::GlobalVariable(_)
515            | Self::LocalVariable(_) => true,
516            _ => false,
517        }
518    }
519
520    /// Return true if this expression is a dynamic array index, for [`Access`].
521    ///
522    /// This method returns true if this expression is a dynamically computed
523    /// index, and as such can only be used to index matrices and arrays when
524    /// they appear behind a pointer. See the documentation for [`Access`] for
525    /// details.
526    ///
527    /// Note, this does not check the _type_ of the given expression. It's up to
528    /// the caller to establish that the `Access` expression is well-typed
529    /// through other means, like [`ResolveContext`].
530    ///
531    /// [`Access`]: crate::Expression::Access
532    /// [`ResolveContext`]: crate::proc::ResolveContext
533    pub const fn is_dynamic_index(&self) -> bool {
534        match *self {
535            Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
536            _ => true,
537        }
538    }
539}
540
541impl crate::Function {
542    /// Return the global variable being accessed by the expression `pointer`.
543    ///
544    /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
545    /// expressions that ultimately access some part of a `GlobalVariable`,
546    /// return a handle for that global.
547    ///
548    /// If the expression does not ultimately access a global variable, return
549    /// `None`.
550    pub fn originating_global(
551        &self,
552        mut pointer: crate::Handle<crate::Expression>,
553    ) -> Option<crate::Handle<crate::GlobalVariable>> {
554        loop {
555            pointer = match self.expressions[pointer] {
556                crate::Expression::Access { base, .. } => base,
557                crate::Expression::AccessIndex { base, .. } => base,
558                crate::Expression::GlobalVariable(handle) => return Some(handle),
559                crate::Expression::LocalVariable(_) => return None,
560                crate::Expression::FunctionArgument(_) => return None,
561                // There are no other expressions that produce pointer values.
562                _ => unreachable!(),
563            }
564        }
565    }
566}
567
568impl crate::SampleLevel {
569    pub const fn implicit_derivatives(&self) -> bool {
570        match *self {
571            Self::Auto | Self::Bias(_) => true,
572            Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
573        }
574    }
575}
576
577impl crate::Binding {
578    pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
579        match *self {
580            crate::Binding::BuiltIn(built_in) => Some(built_in),
581            Self::Location { .. } => None,
582        }
583    }
584}
585
586impl super::SwizzleComponent {
587    pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
588
589    pub const fn index(&self) -> u32 {
590        match *self {
591            Self::X => 0,
592            Self::Y => 1,
593            Self::Z => 2,
594            Self::W => 3,
595        }
596    }
597    pub const fn from_index(idx: u32) -> Self {
598        match idx {
599            0 => Self::X,
600            1 => Self::Y,
601            2 => Self::Z,
602            _ => Self::W,
603        }
604    }
605}
606
607impl super::ImageClass {
608    pub const fn is_multisampled(self) -> bool {
609        match self {
610            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
611            crate::ImageClass::Storage { .. } => false,
612        }
613    }
614
615    pub const fn is_mipmapped(self) -> bool {
616        match self {
617            crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
618            crate::ImageClass::Storage { .. } => false,
619        }
620    }
621}
622
623impl crate::Module {
624    pub const fn to_ctx(&self) -> GlobalCtx<'_> {
625        GlobalCtx {
626            types: &self.types,
627            constants: &self.constants,
628            overrides: &self.overrides,
629            global_expressions: &self.global_expressions,
630        }
631    }
632}
633
634#[derive(Debug)]
635pub(super) enum U32EvalError {
636    NonConst,
637    Negative,
638}
639
640#[derive(Clone, Copy)]
641pub struct GlobalCtx<'a> {
642    pub types: &'a crate::UniqueArena<crate::Type>,
643    pub constants: &'a crate::Arena<crate::Constant>,
644    pub overrides: &'a crate::Arena<crate::Override>,
645    pub global_expressions: &'a crate::Arena<crate::Expression>,
646}
647
648impl GlobalCtx<'_> {
649    /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
650    #[allow(dead_code)]
651    pub(super) fn eval_expr_to_u32(
652        &self,
653        handle: crate::Handle<crate::Expression>,
654    ) -> Result<u32, U32EvalError> {
655        self.eval_expr_to_u32_from(handle, self.global_expressions)
656    }
657
658    /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
659    pub(super) fn eval_expr_to_u32_from(
660        &self,
661        handle: crate::Handle<crate::Expression>,
662        arena: &crate::Arena<crate::Expression>,
663    ) -> Result<u32, U32EvalError> {
664        match self.eval_expr_to_literal_from(handle, arena) {
665            Some(crate::Literal::U32(value)) => Ok(value),
666            Some(crate::Literal::I32(value)) => {
667                value.try_into().map_err(|_| U32EvalError::Negative)
668            }
669            _ => Err(U32EvalError::NonConst),
670        }
671    }
672
673    #[allow(dead_code)]
674    pub(crate) fn eval_expr_to_literal(
675        &self,
676        handle: crate::Handle<crate::Expression>,
677    ) -> Option<crate::Literal> {
678        self.eval_expr_to_literal_from(handle, self.global_expressions)
679    }
680
681    fn eval_expr_to_literal_from(
682        &self,
683        handle: crate::Handle<crate::Expression>,
684        arena: &crate::Arena<crate::Expression>,
685    ) -> Option<crate::Literal> {
686        fn get(
687            gctx: GlobalCtx,
688            handle: crate::Handle<crate::Expression>,
689            arena: &crate::Arena<crate::Expression>,
690        ) -> Option<crate::Literal> {
691            match arena[handle] {
692                crate::Expression::Literal(literal) => Some(literal),
693                crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
694                    crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
695                    _ => None,
696                },
697                _ => None,
698            }
699        }
700        match arena[handle] {
701            crate::Expression::Constant(c) => {
702                get(*self, self.constants[c].init, self.global_expressions)
703            }
704            _ => get(*self, handle, arena),
705        }
706    }
707}
708
709/// Return an iterator over the individual components assembled by a
710/// `Compose` expression.
711///
712/// Given `ty` and `components` from an `Expression::Compose`, return an
713/// iterator over the components of the resulting value.
714///
715/// Normally, this would just be an iterator over `components`. However,
716/// `Compose` expressions can concatenate vectors, in which case the i'th
717/// value being composed is not generally the i'th element of `components`.
718/// This function consults `ty` to decide if this concatenation is occurring,
719/// and returns an iterator that produces the components of the result of
720/// the `Compose` expression in either case.
721pub fn flatten_compose<'arenas>(
722    ty: crate::Handle<crate::Type>,
723    components: &'arenas [crate::Handle<crate::Expression>],
724    expressions: &'arenas crate::Arena<crate::Expression>,
725    types: &'arenas crate::UniqueArena<crate::Type>,
726) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
727    // Returning `impl Iterator` is a bit tricky. We may or may not
728    // want to flatten the components, but we have to settle on a
729    // single concrete type to return. This function returns a single
730    // iterator chain that handles both the flattening and
731    // non-flattening cases.
732    let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
733        (size as usize, true)
734    } else {
735        (components.len(), false)
736    };
737
738    /// Flatten `Compose` expressions if `is_vector` is true.
739    fn flatten_compose<'c>(
740        component: &'c crate::Handle<crate::Expression>,
741        is_vector: bool,
742        expressions: &'c crate::Arena<crate::Expression>,
743    ) -> &'c [crate::Handle<crate::Expression>] {
744        if is_vector {
745            if let crate::Expression::Compose {
746                ty: _,
747                components: ref subcomponents,
748            } = expressions[*component]
749            {
750                return subcomponents;
751            }
752        }
753        std::slice::from_ref(component)
754    }
755
756    /// Flatten `Splat` expressions if `is_vector` is true.
757    fn flatten_splat<'c>(
758        component: &'c crate::Handle<crate::Expression>,
759        is_vector: bool,
760        expressions: &'c crate::Arena<crate::Expression>,
761    ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
762        let mut expr = *component;
763        let mut count = 1;
764        if is_vector {
765            if let crate::Expression::Splat { size, value } = expressions[expr] {
766                expr = value;
767                count = size as usize;
768            }
769        }
770        std::iter::repeat(expr).take(count)
771    }
772
773    // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
774    // flatten up to two levels of `Compose` expressions.
775    //
776    // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
777    // `Splat` expressions. Fortunately, the operand of a `Splat` must
778    // be a scalar, so we can stop there.
779    components
780        .iter()
781        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
782        .flat_map(move |component| flatten_compose(component, is_vector, expressions))
783        .flat_map(move |component| flatten_splat(component, is_vector, expressions))
784        .take(size)
785}
786
#[test]
fn test_matrix_size() {
    let module = crate::Module::default();
    let mat3x3_f32 = crate::TypeInner::Matrix {
        columns: crate::VectorSize::Tri,
        rows: crate::VectorSize::Tri,
        scalar: crate::Scalar::F32,
    };
    // Three columns of three f32s; each column is padded out to 16 bytes.
    let expected = 48;
    assert_eq!(mat3x3_f32.size(module.to_ctx()), expected);
}
797        48,
798    );
799}