naga/back/msl/
writer.rs

1use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
2use crate::{
3    arena::Handle,
4    back,
5    proc::index,
6    proc::{self, NameKey, TypeResolution},
7    valid, FastHashMap, FastHashSet,
8};
9use bit_set::BitSet;
10use std::{
11    fmt::{Display, Error as FmtError, Formatter, Write},
12    iter,
13};
14
/// Shorthand result used internally by the backend
type BackendResult = Result<(), Error>;

// Namespace prefix written in front of Metal standard library names.
const NAMESPACE: &str = "metal";
// The name of the array member of the Metal struct types we generate to
// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
// for details.
const WRAPPED_ARRAY_FIELD: &str = "inner";
// This is a hack: we need to pass a pointer to an atomic, but in general the
// backend does not put "&" in front of every pointer. More general handling
// of pointers still needs to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

// Identifiers used for ray query support. NOTE(review): their uses are outside
// this section of the file; presumably they name the generated helper type,
// its fields, and a helper function - confirm against the code that emits them.
const RT_NAMESPACE: &str = "metal::raytracing";
const RAY_QUERY_TYPE: &str = "_RayQuery";
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

// Names of helper functions, shared with the rest of the crate.
// NOTE(review): presumably generated wrappers for `modf`/`frexp` - confirm.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
37
38/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
39///
40/// The `sizes` slice determines whether this function writes a
41/// scalar, vector, or matrix type:
42///
43/// - An empty slice produces a scalar type.
44/// - A one-element slice produces a vector type.
45/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
46fn put_numeric_type(
47    out: &mut impl Write,
48    scalar: crate::Scalar,
49    sizes: &[crate::VectorSize],
50) -> Result<(), FmtError> {
51    match (scalar, sizes) {
52        (scalar, &[]) => {
53            write!(out, "{}", scalar.to_msl_name())
54        }
55        (scalar, &[rows]) => {
56            write!(
57                out,
58                "{}::{}{}",
59                NAMESPACE,
60                scalar.to_msl_name(),
61                back::vector_size_str(rows)
62            )
63        }
64        (scalar, &[rows, columns]) => {
65            write!(
66                out,
67                "{}::{}{}x{}",
68                NAMESPACE,
69                scalar.to_msl_name(),
70                back::vector_size_str(columns),
71                back::vector_size_str(rows)
72            )
73        }
74        (_, _) => Ok(()), // not meaningful
75    }
76}
77
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
///
/// The full variable name is this prefix followed by the index of the
/// expression handle (see `Writer::put_level_of_detail`).
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
80
/// Formatting helper: its `Display` impl writes the MSL spelling of a Naga type.
struct TypeContext<'a> {
    /// Handle of the type to write.
    handle: Handle<crate::Type>,
    /// Module context; its `types` arena is used to look up `handle`.
    gctx: proc::GlobalCtx<'a>,
    /// Generated names; consulted for types that are written as their alias.
    names: &'a FastHashMap<NameKey, String>,
    /// Storage access flags; they select the `access::` qualifier of storage images.
    access: crate::StorageAccess,
    /// Resolved binding, if any; it may carry a size override for binding arrays.
    binding: Option<&'a super::ResolvedBinding>,
    /// When `true`, the type is spelled out in full even if it would
    /// otherwise be written as its alias name.
    first_time: bool,
}
89
90impl<'a> Display for TypeContext<'a> {
91    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
92        let ty = &self.gctx.types[self.handle];
93        if ty.needs_alias() && !self.first_time {
94            let name = &self.names[&NameKey::Type(self.handle)];
95            return write!(out, "{name}");
96        }
97
98        match ty.inner {
99            crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
100            crate::TypeInner::Atomic(scalar) => {
101                write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
102            }
103            crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
104            crate::TypeInner::Matrix { columns, rows, .. } => {
105                put_numeric_type(out, crate::Scalar::F32, &[rows, columns])
106            }
107            crate::TypeInner::Pointer { base, space } => {
108                let sub = Self {
109                    handle: base,
110                    first_time: false,
111                    ..*self
112                };
113                let space_name = match space.to_msl_name() {
114                    Some(name) => name,
115                    None => return Ok(()),
116                };
117                write!(out, "{space_name} {sub}&")
118            }
119            crate::TypeInner::ValuePointer {
120                size,
121                scalar,
122                space,
123            } => {
124                match space.to_msl_name() {
125                    Some(name) => write!(out, "{name} ")?,
126                    None => return Ok(()),
127                };
128                match size {
129                    Some(rows) => put_numeric_type(out, scalar, &[rows])?,
130                    None => put_numeric_type(out, scalar, &[])?,
131                };
132
133                write!(out, "&")
134            }
135            crate::TypeInner::Array { base, .. } => {
136                let sub = Self {
137                    handle: base,
138                    first_time: false,
139                    ..*self
140                };
141                // Array lengths go at the end of the type definition,
142                // so just print the element type here.
143                write!(out, "{sub}")
144            }
145            crate::TypeInner::Struct { .. } => unreachable!(),
146            crate::TypeInner::Image {
147                dim,
148                arrayed,
149                class,
150            } => {
151                let dim_str = match dim {
152                    crate::ImageDimension::D1 => "1d",
153                    crate::ImageDimension::D2 => "2d",
154                    crate::ImageDimension::D3 => "3d",
155                    crate::ImageDimension::Cube => "cube",
156                };
157                let (texture_str, msaa_str, kind, access) = match class {
158                    crate::ImageClass::Sampled { kind, multi } => {
159                        let (msaa_str, access) = if multi {
160                            ("_ms", "read")
161                        } else {
162                            ("", "sample")
163                        };
164                        ("texture", msaa_str, kind, access)
165                    }
166                    crate::ImageClass::Depth { multi } => {
167                        let (msaa_str, access) = if multi {
168                            ("_ms", "read")
169                        } else {
170                            ("", "sample")
171                        };
172                        ("depth", msaa_str, crate::ScalarKind::Float, access)
173                    }
174                    crate::ImageClass::Storage { format, .. } => {
175                        let access = if self
176                            .access
177                            .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
178                        {
179                            "read_write"
180                        } else if self.access.contains(crate::StorageAccess::STORE) {
181                            "write"
182                        } else if self.access.contains(crate::StorageAccess::LOAD) {
183                            "read"
184                        } else {
185                            log::warn!(
186                                "Storage access for {:?} (name '{}'): {:?}",
187                                self.handle,
188                                ty.name.as_deref().unwrap_or_default(),
189                                self.access
190                            );
191                            unreachable!("module is not valid");
192                        };
193                        ("texture", "", format.into(), access)
194                    }
195                };
196                let base_name = crate::Scalar { kind, width: 4 }.to_msl_name();
197                let array_str = if arrayed { "_array" } else { "" };
198                write!(
199                    out,
200                    "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
201                )
202            }
203            crate::TypeInner::Sampler { comparison: _ } => {
204                write!(out, "{NAMESPACE}::sampler")
205            }
206            crate::TypeInner::AccelerationStructure => {
207                write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
208            }
209            crate::TypeInner::RayQuery => {
210                write!(out, "{RAY_QUERY_TYPE}")
211            }
212            crate::TypeInner::BindingArray { base, size } => {
213                let base_tyname = Self {
214                    handle: base,
215                    first_time: false,
216                    ..*self
217                };
218
219                if let Some(&super::ResolvedBinding::Resource(super::BindTarget {
220                    binding_array_size: Some(override_size),
221                    ..
222                })) = self.binding
223                {
224                    write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>")
225                } else if let crate::ArraySize::Constant(size) = size {
226                    write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>")
227                } else {
228                    unreachable!("metal requires all arrays be constant sized");
229                }
230            }
231        }
232    }
233}
234
/// Helper that writes a global variable's type, qualifiers and name; see `try_fmt`.
struct TypedGlobalVariable<'a> {
    /// Module that owns the variable and its type.
    module: &'a crate::Module,
    /// Generated names; indexed with the variable's `NameKey`.
    names: &'a FastHashMap<NameKey, String>,
    /// The global variable to write.
    handle: Handle<crate::GlobalVariable>,
    /// How the variable is used. When `WRITE` is absent and the address space
    /// takes an access qualifier, the variable is written as `const`.
    usage: valid::GlobalUse,
    /// Resolved binding, forwarded to the type writer.
    binding: Option<&'a super::ResolvedBinding>,
    /// Whether to write the variable as a reference (`&`) qualified with its
    /// address space.
    reference: bool,
}
243
244impl<'a> TypedGlobalVariable<'a> {
245    fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
246        let var = &self.module.global_variables[self.handle];
247        let name = &self.names[&NameKey::GlobalVariable(self.handle)];
248
249        let storage_access = match var.space {
250            crate::AddressSpace::Storage { access } => access,
251            _ => match self.module.types[var.ty].inner {
252                crate::TypeInner::Image {
253                    class: crate::ImageClass::Storage { access, .. },
254                    ..
255                } => access,
256                crate::TypeInner::BindingArray { base, .. } => {
257                    match self.module.types[base].inner {
258                        crate::TypeInner::Image {
259                            class: crate::ImageClass::Storage { access, .. },
260                            ..
261                        } => access,
262                        _ => crate::StorageAccess::default(),
263                    }
264                }
265                _ => crate::StorageAccess::default(),
266            },
267        };
268        let ty_name = TypeContext {
269            handle: var.ty,
270            gctx: self.module.to_ctx(),
271            names: self.names,
272            access: storage_access,
273            binding: self.binding,
274            first_time: false,
275        };
276
277        let (space, access, reference) = match var.space.to_msl_name() {
278            Some(space) if self.reference => {
279                let access = if var.space.needs_access_qualifier()
280                    && !self.usage.contains(valid::GlobalUse::WRITE)
281                {
282                    "const"
283                } else {
284                    ""
285                };
286                (space, access, "&")
287            }
288            _ => ("", "", ""),
289        };
290
291        Ok(write!(
292            out,
293            "{}{}{}{}{}{} {}",
294            space,
295            if space.is_empty() { "" } else { " " },
296            ty_name,
297            if access.is_empty() { "" } else { " " },
298            access,
299            reference,
300            name,
301        )?)
302    }
303}
304
/// Writer for the MSL backend; generated text is written to `out`.
pub struct Writer<W> {
    /// Output sink; returned to the caller by `finish`.
    out: W,
    /// Generated names, keyed by the item they belong to.
    names: FastHashMap<NameKey, String>,
    // NOTE(review): presumably names assigned to expressions that were
    // emitted as named temporaries - confirm against the statement writer.
    named_expressions: crate::NamedExpressions,
    /// Set of expressions that need to be baked to avoid unnecessary repetition in output
    need_bake_expressions: back::NeedBakeExpressions,
    // Name generator; starts out in its default state (see `new`).
    namer: proc::Namer,
    // Test-only bookkeeping. NOTE(review): presumably stack addresses recorded
    // by `put_expression` - confirm against the tests.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    // Test-only bookkeeping. NOTE(review): presumably stack addresses recorded
    // by `put_block` - confirm against the tests.
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// Set of (struct type, struct field index) denoting which fields require
    /// padding inserted **before** them (i.e. between fields at index - 1 and index)
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
320
321impl crate::Scalar {
322    fn to_msl_name(self) -> &'static str {
323        use crate::ScalarKind as Sk;
324        match self {
325            Self {
326                kind: Sk::Float,
327                width: _,
328            } => "float",
329            Self {
330                kind: Sk::Sint,
331                width: 4,
332            } => "int",
333            Self {
334                kind: Sk::Uint,
335                width: 4,
336            } => "uint",
337            Self {
338                kind: Sk::Sint,
339                width: 8,
340            } => "long",
341            Self {
342                kind: Sk::Uint,
343                width: 8,
344            } => "ulong",
345            Self {
346                kind: Sk::Bool,
347                width: _,
348            } => "bool",
349            Self {
350                kind: Sk::AbstractInt | Sk::AbstractFloat,
351                width: _,
352            } => unreachable!("Found Abstract scalar kind"),
353            _ => unreachable!("Unsupported scalar kind: {:?}", self),
354        }
355    }
356}
357
/// Return `","` when a separator is needed and the empty string otherwise.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
365
366fn should_pack_struct_member(
367    members: &[crate::StructMember],
368    span: u32,
369    index: usize,
370    module: &crate::Module,
371) -> Option<crate::Scalar> {
372    let member = &members[index];
373
374    let ty_inner = &module.types[member.ty].inner;
375    let last_offset = member.offset + ty_inner.size(module.to_ctx());
376    let next_offset = match members.get(index + 1) {
377        Some(next) => next.offset,
378        None => span,
379    };
380    let is_tight = next_offset == last_offset;
381
382    match *ty_inner {
383        crate::TypeInner::Vector {
384            size: crate::VectorSize::Tri,
385            scalar: scalar @ crate::Scalar { width: 4, .. },
386        } if is_tight => Some(scalar),
387        _ => None,
388    }
389}
390
391fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
392    match arena[ty].inner {
393        crate::TypeInner::Struct { ref members, .. } => {
394            if let Some(member) = members.last() {
395                if let crate::TypeInner::Array {
396                    size: crate::ArraySize::Dynamic,
397                    ..
398                } = arena[member.ty].inner
399                {
400                    return true;
401                }
402            }
403            false
404        }
405        crate::TypeInner::Array {
406            size: crate::ArraySize::Dynamic,
407            ..
408        } => true,
409        _ => false,
410    }
411}
412
413impl crate::AddressSpace {
414    /// Returns true if global variables in this address space are
415    /// passed in function arguments. These arguments need to be
416    /// passed through any functions called from the entry point.
417    const fn needs_pass_through(&self) -> bool {
418        match *self {
419            Self::Uniform
420            | Self::Storage { .. }
421            | Self::Private
422            | Self::WorkGroup
423            | Self::PushConstant
424            | Self::Handle => true,
425            Self::Function => false,
426        }
427    }
428
429    /// Returns true if the address space may need a "const" qualifier.
430    const fn needs_access_qualifier(&self) -> bool {
431        match *self {
432            //Note: we are ignoring the storage access here, and instead
433            // rely on the actual use of a global by functions. This means we
434            // may end up with "const" even if the binding is read-write,
435            // and that should be OK.
436            Self::Storage { .. } => true,
437            // These should always be read-write.
438            Self::Private | Self::WorkGroup => false,
439            // These translate to `constant` address space, no need for qualifiers.
440            Self::Uniform | Self::PushConstant => false,
441            // Not applicable.
442            Self::Handle | Self::Function => false,
443        }
444    }
445
446    const fn to_msl_name(self) -> Option<&'static str> {
447        match self {
448            Self::Handle => None,
449            Self::Uniform | Self::PushConstant => Some("constant"),
450            Self::Storage { .. } => Some("device"),
451            Self::Private | Self::Function => Some("thread"),
452            Self::WorkGroup => Some("threadgroup"),
453        }
454    }
455}
456
457impl crate::Type {
458    // Returns `true` if we need to emit an alias for this type.
459    const fn needs_alias(&self) -> bool {
460        use crate::TypeInner as Ti;
461
462        match self.inner {
463            // value types are concise enough, we only alias them if they are named
464            Ti::Scalar(_)
465            | Ti::Vector { .. }
466            | Ti::Matrix { .. }
467            | Ti::Atomic(_)
468            | Ti::Pointer { .. }
469            | Ti::ValuePointer { .. } => self.name.is_some(),
470            // composite types are better to be aliased, regardless of the name
471            Ti::Struct { .. } | Ti::Array { .. } => true,
472            // handle types may be different, depending on the global var access, so we always inline them
473            Ti::Image { .. }
474            | Ti::Sampler { .. }
475            | Ti::AccelerationStructure
476            | Ti::RayQuery
477            | Ti::BindingArray { .. } => false,
478        }
479    }
480}
481
/// Identifies a function either by handle or by entry point index.
enum FunctionOrigin {
    /// A regular function, identified by its handle.
    Handle(Handle<crate::Function>),
    /// An entry point, identified by its index.
    EntryPoint(proc::EntryPointIndex),
}
486
/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
/// save the clamped level of detail in a temporary variable whose name is based
/// on the handle of the `ImageLoad` expression. But for other policies, we just
/// use the expression directly.
///
/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
/// [`ImageLoad`]: crate::Expression::ImageLoad
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write the given expression itself as the level of detail.
    Direct(Handle<crate::Expression>),
    /// Write the name of the cached, clamped level of detail, which is
    /// derived from the index of the given expression handle.
    Restricted(Handle<crate::Expression>),
}
501
/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
///
/// When this is used in code paths unconcerned with the `Restrict` bounds check
/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
/// be too awkward. If that changes, we can revisit.
///
/// [`ImageLoad`]: crate::Expression::ImageLoad
/// [`ImageStore`]: crate::Statement::ImageStore
struct TexelAddress {
    /// Expression holding the texel coordinate.
    coordinate: Handle<crate::Expression>,
    /// Expression holding the array index, if one is supplied.
    array_index: Option<Handle<crate::Expression>>,
    /// Expression holding the sample index, if one is supplied.
    sample: Option<Handle<crate::Expression>>,
    /// Level of detail, if one is supplied.
    level: Option<LevelOfDetail>,
}
517
/// Everything needed to write the expressions of one function.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Identifies `function` as a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`; `resolve_type` reads expression types from it.
    info: &'a valid::FunctionInfo,
    /// The module that contains `function`.
    module: &'a crate::Module,
    // NOTE(review): presumably the validation info for the whole module - not
    // used in this section of the file, confirm against the callers.
    mod_info: &'a valid::ModuleInfo,
    // Pipeline options supplied to the backend; not used in this section.
    pipeline_options: &'a PipelineOptions,
    // NOTE(review): presumably the MSL version as (major, minor) - confirm.
    lang_version: (u8, u8),
    /// Bounds check policies; see `choose_bounds_check_policy`.
    policies: index::BoundsCheckPolicies,

    /// A bitset containing the `Expression` handle indexes of expressions used
    /// as indices in `ReadZeroSkipWrite`-policy accesses. These may need to be
    /// cached in temporary variables. See `index::find_checked_indexes` for
    /// details.
    guarded_indices: BitSet,
}
534
535impl<'a> ExpressionContext<'a> {
536    fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
537        self.info[handle].ty.inner_with(&self.module.types)
538    }
539
540    /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
541    ///
542    /// Only mipmapped images need to specify a level of detail. Since 1D
543    /// textures cannot have mipmaps, MSL requires that the level argument to
544    /// texture1d queries and accesses must be a constexpr 0. It's easiest
545    /// just to omit the level entirely for 1D textures.
546    fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
547        let image_ty = self.resolve_type(image);
548        if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
549            class.is_mipmapped() && dim != crate::ImageDimension::D1
550        } else {
551            false
552        }
553    }
554
555    fn choose_bounds_check_policy(
556        &self,
557        pointer: Handle<crate::Expression>,
558    ) -> index::BoundsCheckPolicy {
559        self.policies
560            .choose_policy(pointer, &self.module.types, self.info)
561    }
562
563    fn access_needs_check(
564        &self,
565        base: Handle<crate::Expression>,
566        index: index::GuardedIndex,
567    ) -> Option<index::IndexableLength> {
568        index::access_needs_check(base, index, self.module, self.function, self.info)
569    }
570
571    fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
572        match self.function.expressions[expr_handle] {
573            crate::Expression::AccessIndex { base, index } => {
574                let ty = match *self.resolve_type(base) {
575                    crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
576                    ref ty => ty,
577                };
578                match *ty {
579                    crate::TypeInner::Struct {
580                        ref members, span, ..
581                    } => should_pack_struct_member(members, span, index as usize, self.module),
582                    _ => None,
583                }
584            }
585            _ => None,
586        }
587    }
588}
589
/// Context for writing statements: an expression context plus statement-only data.
struct StatementContext<'a> {
    /// Context used for the expressions that appear inside statements.
    expression: ExpressionContext<'a>,
    // NOTE(review): presumably the name of the struct type used for the
    // function's result, when there is one - confirm against the statement writer.
    result_struct: Option<&'a str>,
}
594
595impl<W: Write> Writer<W> {
596    /// Creates a new `Writer` instance.
597    pub fn new(out: W) -> Self {
598        Writer {
599            out,
600            names: FastHashMap::default(),
601            named_expressions: Default::default(),
602            need_bake_expressions: Default::default(),
603            namer: proc::Namer::default(),
604            #[cfg(test)]
605            put_expression_stack_pointers: Default::default(),
606            #[cfg(test)]
607            put_block_stack_pointers: Default::default(),
608            struct_member_pads: FastHashSet::default(),
609        }
610    }
611
612    /// Finishes writing and returns the output.
613    // See https://github.com/rust-lang/rust-clippy/issues/4979.
614    #[allow(clippy::missing_const_for_fn)]
615    pub fn finish(self) -> W {
616        self.out
617    }
618
619    fn put_call_parameters(
620        &mut self,
621        parameters: impl Iterator<Item = Handle<crate::Expression>>,
622        context: &ExpressionContext,
623    ) -> BackendResult {
624        self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
625            writer.put_expression(expr, context, true)
626        })
627    }
628
629    fn put_call_parameters_impl<C, E>(
630        &mut self,
631        parameters: impl Iterator<Item = Handle<crate::Expression>>,
632        ctx: &C,
633        put_expression: E,
634    ) -> BackendResult
635    where
636        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
637    {
638        write!(self.out, "(")?;
639        for (i, handle) in parameters.enumerate() {
640            if i != 0 {
641                write!(self.out, ", ")?;
642            }
643            put_expression(self, ctx, handle)?;
644        }
645        write!(self.out, ")")?;
646        Ok(())
647    }
648
649    fn put_level_of_detail(
650        &mut self,
651        level: LevelOfDetail,
652        context: &ExpressionContext,
653    ) -> BackendResult {
654        match level {
655            LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
656            LevelOfDetail::Restricted(load) => {
657                write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())?
658            }
659        }
660        Ok(())
661    }
662
663    fn put_image_query(
664        &mut self,
665        image: Handle<crate::Expression>,
666        query: &str,
667        level: Option<LevelOfDetail>,
668        context: &ExpressionContext,
669    ) -> BackendResult {
670        self.put_expression(image, context, false)?;
671        write!(self.out, ".get_{query}(")?;
672        if let Some(level) = level {
673            self.put_level_of_detail(level, context)?;
674        }
675        write!(self.out, ")")?;
676        Ok(())
677    }
678
679    fn put_image_size_query(
680        &mut self,
681        image: Handle<crate::Expression>,
682        level: Option<LevelOfDetail>,
683        kind: crate::ScalarKind,
684        context: &ExpressionContext,
685    ) -> BackendResult {
686        //Note: MSL only has separate width/height/depth queries,
687        // so compose the result of them.
688        let dim = match *context.resolve_type(image) {
689            crate::TypeInner::Image { dim, .. } => dim,
690            ref other => unreachable!("Unexpected type {:?}", other),
691        };
692        let scalar = crate::Scalar { kind, width: 4 };
693        let coordinate_type = scalar.to_msl_name();
694        match dim {
695            crate::ImageDimension::D1 => {
696                // Since 1D textures never have mipmaps, MSL requires that the
697                // `level` argument be a constexpr 0. It's simplest for us just
698                // to pass `None` and omit the level entirely.
699                if kind == crate::ScalarKind::Uint {
700                    // No need to construct a vector. No cast needed.
701                    self.put_image_query(image, "width", None, context)?;
702                } else {
703                    // There's no definition for `int` in the `metal` namespace.
704                    write!(self.out, "int(")?;
705                    self.put_image_query(image, "width", None, context)?;
706                    write!(self.out, ")")?;
707                }
708            }
709            crate::ImageDimension::D2 => {
710                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
711                self.put_image_query(image, "width", level, context)?;
712                write!(self.out, ", ")?;
713                self.put_image_query(image, "height", level, context)?;
714                write!(self.out, ")")?;
715            }
716            crate::ImageDimension::D3 => {
717                write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
718                self.put_image_query(image, "width", level, context)?;
719                write!(self.out, ", ")?;
720                self.put_image_query(image, "height", level, context)?;
721                write!(self.out, ", ")?;
722                self.put_image_query(image, "depth", level, context)?;
723                write!(self.out, ")")?;
724            }
725            crate::ImageDimension::Cube => {
726                write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
727                self.put_image_query(image, "width", level, context)?;
728                write!(self.out, ")")?;
729            }
730        }
731        Ok(())
732    }
733
734    fn put_cast_to_uint_scalar_or_vector(
735        &mut self,
736        expr: Handle<crate::Expression>,
737        context: &ExpressionContext,
738    ) -> BackendResult {
739        // coordinates in IR are int, but Metal expects uint
740        match *context.resolve_type(expr) {
741            crate::TypeInner::Scalar(_) => {
742                put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
743            }
744            crate::TypeInner::Vector { size, .. } => {
745                put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
746            }
747            _ => {
748                return Err(Error::GenericValidation(
749                    "Invalid type for image coordinate".into(),
750                ))
751            }
752        };
753
754        write!(self.out, "(")?;
755        self.put_expression(expr, context, true)?;
756        write!(self.out, ")")?;
757        Ok(())
758    }
759
760    fn put_image_sample_level(
761        &mut self,
762        image: Handle<crate::Expression>,
763        level: crate::SampleLevel,
764        context: &ExpressionContext,
765    ) -> BackendResult {
766        let has_levels = context.image_needs_lod(image);
767        match level {
768            crate::SampleLevel::Auto => {}
769            crate::SampleLevel::Zero => {
770                //TODO: do we support Zero on `Sampled` image classes?
771            }
772            _ if !has_levels => {
773                log::warn!("1D image can't be sampled with level {:?}", level);
774            }
775            crate::SampleLevel::Exact(h) => {
776                write!(self.out, ", {NAMESPACE}::level(")?;
777                self.put_expression(h, context, true)?;
778                write!(self.out, ")")?;
779            }
780            crate::SampleLevel::Bias(h) => {
781                write!(self.out, ", {NAMESPACE}::bias(")?;
782                self.put_expression(h, context, true)?;
783                write!(self.out, ")")?;
784            }
785            crate::SampleLevel::Gradient { x, y } => {
786                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
787                self.put_expression(x, context, true)?;
788                write!(self.out, ", ")?;
789                self.put_expression(y, context, true)?;
790                write!(self.out, ")")?;
791            }
792        }
793        Ok(())
794    }
795
796    fn put_image_coordinate_limits(
797        &mut self,
798        image: Handle<crate::Expression>,
799        level: Option<LevelOfDetail>,
800        context: &ExpressionContext,
801    ) -> BackendResult {
802        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
803        write!(self.out, " - 1")?;
804        Ok(())
805    }
806
807    /// General function for writing restricted image indexes.
808    ///
809    /// This is used to produce restricted mip levels, array indices, and sample
810    /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
811    /// [`Restrict`] bounds check policy.
812    ///
813    /// This function writes an expression of the form:
814    ///
815    /// ```ignore
816    ///
817    ///     metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
818    ///
819    /// ```
820    ///
821    /// [`ImageLoad`]: crate::Expression::ImageLoad
822    /// [`ImageStore`]: crate::Statement::ImageStore
823    /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
824    fn put_restricted_scalar_image_index(
825        &mut self,
826        image: Handle<crate::Expression>,
827        index: Handle<crate::Expression>,
828        limit_method: &str,
829        context: &ExpressionContext,
830    ) -> BackendResult {
831        write!(self.out, "{NAMESPACE}::min(uint(")?;
832        self.put_expression(index, context, true)?;
833        write!(self.out, "), ")?;
834        self.put_expression(image, context, false)?;
835        write!(self.out, ".{limit_method}() - 1)")?;
836        Ok(())
837    }
838
839    fn put_restricted_texel_address(
840        &mut self,
841        image: Handle<crate::Expression>,
842        address: &TexelAddress,
843        context: &ExpressionContext,
844    ) -> BackendResult {
845        // Write the coordinate.
846        write!(self.out, "{NAMESPACE}::min(")?;
847        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
848        write!(self.out, ", ")?;
849        self.put_image_coordinate_limits(image, address.level, context)?;
850        write!(self.out, ")")?;
851
852        // Write the array index, if present.
853        if let Some(array_index) = address.array_index {
854            write!(self.out, ", ")?;
855            self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
856        }
857
858        // Write the sample index, if present.
859        if let Some(sample) = address.sample {
860            write!(self.out, ", ")?;
861            self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
862        }
863
864        // The level of detail should be clamped and cached by
865        // `put_cache_restricted_level`, so we don't need to clamp it here.
866        if let Some(level) = address.level {
867            write!(self.out, ", ")?;
868            self.put_level_of_detail(level, context)?;
869        }
870
871        Ok(())
872    }
873
874    /// Write an expression that is true if the given image access is in bounds.
875    fn put_image_access_bounds_check(
876        &mut self,
877        image: Handle<crate::Expression>,
878        address: &TexelAddress,
879        context: &ExpressionContext,
880    ) -> BackendResult {
881        let mut conjunction = "";
882
883        // First, check the level of detail. Only if that is in bounds can we
884        // use it to find the appropriate bounds for the coordinates.
885        let level = if let Some(level) = address.level {
886            write!(self.out, "uint(")?;
887            self.put_level_of_detail(level, context)?;
888            write!(self.out, ") < ")?;
889            self.put_expression(image, context, true)?;
890            write!(self.out, ".get_num_mip_levels()")?;
891            conjunction = " && ";
892            Some(level)
893        } else {
894            None
895        };
896
897        // Check sample index, if present.
898        if let Some(sample) = address.sample {
899            write!(self.out, "uint(")?;
900            self.put_expression(sample, context, true)?;
901            write!(self.out, ") < ")?;
902            self.put_expression(image, context, true)?;
903            write!(self.out, ".get_num_samples()")?;
904            conjunction = " && ";
905        }
906
907        // Check array index, if present.
908        if let Some(array_index) = address.array_index {
909            write!(self.out, "{conjunction}uint(")?;
910            self.put_expression(array_index, context, true)?;
911            write!(self.out, ") < ")?;
912            self.put_expression(image, context, true)?;
913            write!(self.out, ".get_array_size()")?;
914            conjunction = " && ";
915        }
916
917        // Finally, check if the coordinates are within bounds.
918        let coord_is_vector = match *context.resolve_type(address.coordinate) {
919            crate::TypeInner::Vector { .. } => true,
920            _ => false,
921        };
922        write!(self.out, "{conjunction}")?;
923        if coord_is_vector {
924            write!(self.out, "{NAMESPACE}::all(")?;
925        }
926        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
927        write!(self.out, " < ")?;
928        self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
929        if coord_is_vector {
930            write!(self.out, ")")?;
931        }
932
933        Ok(())
934    }
935
936    fn put_image_load(
937        &mut self,
938        load: Handle<crate::Expression>,
939        image: Handle<crate::Expression>,
940        mut address: TexelAddress,
941        context: &ExpressionContext,
942    ) -> BackendResult {
943        match context.policies.image_load {
944            proc::BoundsCheckPolicy::Restrict => {
945                // Use the cached restricted level of detail, if any. Omit the
946                // level altogether for 1D textures.
947                if address.level.is_some() {
948                    address.level = if context.image_needs_lod(image) {
949                        Some(LevelOfDetail::Restricted(load))
950                    } else {
951                        None
952                    }
953                }
954
955                self.put_expression(image, context, false)?;
956                write!(self.out, ".read(")?;
957                self.put_restricted_texel_address(image, &address, context)?;
958                write!(self.out, ")")?;
959            }
960            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
961                write!(self.out, "(")?;
962                self.put_image_access_bounds_check(image, &address, context)?;
963                write!(self.out, " ? ")?;
964                self.put_unchecked_image_load(image, &address, context)?;
965                write!(self.out, ": DefaultConstructible())")?;
966            }
967            proc::BoundsCheckPolicy::Unchecked => {
968                self.put_unchecked_image_load(image, &address, context)?;
969            }
970        }
971
972        Ok(())
973    }
974
975    fn put_unchecked_image_load(
976        &mut self,
977        image: Handle<crate::Expression>,
978        address: &TexelAddress,
979        context: &ExpressionContext,
980    ) -> BackendResult {
981        self.put_expression(image, context, false)?;
982        write!(self.out, ".read(")?;
983        // coordinates in IR are int, but Metal expects uint
984        self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
985        if let Some(expr) = address.array_index {
986            write!(self.out, ", ")?;
987            self.put_expression(expr, context, true)?;
988        }
989        if let Some(sample) = address.sample {
990            write!(self.out, ", ")?;
991            self.put_expression(sample, context, true)?;
992        }
993        if let Some(level) = address.level {
994            if context.image_needs_lod(image) {
995                write!(self.out, ", ")?;
996                self.put_level_of_detail(level, context)?;
997            }
998        }
999        write!(self.out, ")")?;
1000
1001        Ok(())
1002    }
1003
1004    fn put_image_store(
1005        &mut self,
1006        level: back::Level,
1007        image: Handle<crate::Expression>,
1008        address: &TexelAddress,
1009        value: Handle<crate::Expression>,
1010        context: &StatementContext,
1011    ) -> BackendResult {
1012        match context.expression.policies.image_store {
1013            proc::BoundsCheckPolicy::Restrict => {
1014                // We don't have a restricted level value, because we don't
1015                // support writes to mipmapped textures.
1016                debug_assert!(address.level.is_none());
1017
1018                write!(self.out, "{level}")?;
1019                self.put_expression(image, &context.expression, false)?;
1020                write!(self.out, ".write(")?;
1021                self.put_expression(value, &context.expression, true)?;
1022                write!(self.out, ", ")?;
1023                self.put_restricted_texel_address(image, address, &context.expression)?;
1024                writeln!(self.out, ");")?;
1025            }
1026            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
1027                write!(self.out, "{level}if (")?;
1028                self.put_image_access_bounds_check(image, address, &context.expression)?;
1029                writeln!(self.out, ") {{")?;
1030                self.put_unchecked_image_store(level.next(), image, address, value, context)?;
1031                writeln!(self.out, "{level}}}")?;
1032            }
1033            proc::BoundsCheckPolicy::Unchecked => {
1034                self.put_unchecked_image_store(level, image, address, value, context)?;
1035            }
1036        }
1037
1038        Ok(())
1039    }
1040
1041    fn put_unchecked_image_store(
1042        &mut self,
1043        level: back::Level,
1044        image: Handle<crate::Expression>,
1045        address: &TexelAddress,
1046        value: Handle<crate::Expression>,
1047        context: &StatementContext,
1048    ) -> BackendResult {
1049        write!(self.out, "{level}")?;
1050        self.put_expression(image, &context.expression, false)?;
1051        write!(self.out, ".write(")?;
1052        self.put_expression(value, &context.expression, true)?;
1053        write!(self.out, ", ")?;
1054        // coordinates in IR are int, but Metal expects uint
1055        self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1056        if let Some(expr) = address.array_index {
1057            write!(self.out, ", ")?;
1058            self.put_expression(expr, &context.expression, true)?;
1059        }
1060        writeln!(self.out, ");")?;
1061
1062        Ok(())
1063    }
1064
1065    /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
1066    ///
1067    /// The 'maximum valid index' is simply one less than the array's length.
1068    ///
1069    /// This emits an expression of the form `a / b`, so the caller must
1070    /// parenthesize its output if it will be applying operators of higher
1071    /// precedence.
1072    ///
1073    /// `handle` must be the handle of a global variable whose final member is a
1074    /// dynamically sized array.
1075    fn put_dynamic_array_max_index(
1076        &mut self,
1077        handle: Handle<crate::GlobalVariable>,
1078        context: &ExpressionContext,
1079    ) -> BackendResult {
1080        let global = &context.module.global_variables[handle];
1081        let (offset, array_ty) = match context.module.types[global.ty].inner {
1082            crate::TypeInner::Struct { ref members, .. } => match members.last() {
1083                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
1084                None => return Err(Error::GenericValidation("Struct has no members".into())),
1085            },
1086            crate::TypeInner::Array {
1087                size: crate::ArraySize::Dynamic,
1088                ..
1089            } => (0, global.ty),
1090            ref ty => {
1091                return Err(Error::GenericValidation(format!(
1092                    "Expected type with dynamic array, got {ty:?}"
1093                )))
1094            }
1095        };
1096
1097        let (size, stride) = match context.module.types[array_ty].inner {
1098            crate::TypeInner::Array { base, stride, .. } => (
1099                context.module.types[base]
1100                    .inner
1101                    .size(context.module.to_ctx()),
1102                stride,
1103            ),
1104            ref ty => {
1105                return Err(Error::GenericValidation(format!(
1106                    "Expected array type, got {ty:?}"
1107                )))
1108            }
1109        };
1110
1111        // When the stride length is larger than the size, the final element's stride of
1112        // bytes would have padding following the value. But the buffer size in
1113        // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
1114        // enough to hold the actual values' bytes.
1115        //
1116        // So subtract off the size to get a byte size that falls at the start or within
1117        // the final element. Then divide by the stride size, to get one less than the
1118        // length, and then add one. This works even if the buffer size does include the
1119        // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
1120        // if there are zero elements in the array, but the WebGPU `validating shader binding`
1121        // rules, together with draw-time validation when `minBindingSize` is zero,
1122        // prevent that.
1123        write!(
1124            self.out,
1125            "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}",
1126            idx = handle.index(),
1127            offset = offset,
1128            size = size,
1129            stride = stride,
1130        )?;
1131        Ok(())
1132    }
1133
1134    fn put_atomic_operation(
1135        &mut self,
1136        pointer: Handle<crate::Expression>,
1137        key: &str,
1138        value: Handle<crate::Expression>,
1139        context: &ExpressionContext,
1140    ) -> BackendResult {
1141        // If the pointer we're passing to the atomic operation needs to be conditional
1142        // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
1143        // the pointer operand should be unchecked.
1144        let policy = context.choose_bounds_check_policy(pointer);
1145        let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1146            && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
1147
1148        // If requested and successfully put bounds checks, continue the ternary expression.
1149        if checked {
1150            write!(self.out, " ? ")?;
1151        }
1152
1153        write!(
1154            self.out,
1155            "{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
1156        )?;
1157        self.put_access_chain(pointer, policy, context)?;
1158        write!(self.out, ", ")?;
1159        self.put_expression(value, context, true)?;
1160        write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
1161
1162        // Finish the ternary expression.
1163        if checked {
1164            write!(self.out, " : DefaultConstructible()")?;
1165        }
1166
1167        Ok(())
1168    }
1169
1170    /// Emit code for the arithmetic expression of the dot product.
1171    ///
1172    fn put_dot_product(
1173        &mut self,
1174        arg: Handle<crate::Expression>,
1175        arg1: Handle<crate::Expression>,
1176        size: usize,
1177        context: &ExpressionContext,
1178    ) -> BackendResult {
1179        // Write parentheses around the dot product expression to prevent operators
1180        // with different precedences from applying earlier.
1181        write!(self.out, "(")?;
1182
1183        // Cycle trough all the components of the vector
1184        for index in 0..size {
1185            let component = back::COMPONENTS[index];
1186            // Write the addition to the previous product
1187            // This will print an extra '+' at the beginning but that is fine in msl
1188            write!(self.out, " + ")?;
1189            // Write the first vector expression, this expression is marked to be
1190            // cached so unless it can't be cached (for example, it's a Constant)
1191            // it shouldn't produce large expressions.
1192            self.put_expression(arg, context, true)?;
1193            // Access the current component on the first vector
1194            write!(self.out, ".{component} * ")?;
1195            // Write the second vector expression, this expression is marked to be
1196            // cached so unless it can't be cached (for example, it's a Constant)
1197            // it shouldn't produce large expressions.
1198            self.put_expression(arg1, context, true)?;
1199            // Access the current component on the second vector
1200            write!(self.out, ".{component}")?;
1201        }
1202
1203        write!(self.out, ")")?;
1204        Ok(())
1205    }
1206
1207    /// Emit code for the sign(i32) expression.
1208    ///
1209    fn put_isign(
1210        &mut self,
1211        arg: Handle<crate::Expression>,
1212        context: &ExpressionContext,
1213    ) -> BackendResult {
1214        write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1215        match context.resolve_type(arg) {
1216            &crate::TypeInner::Vector { size, .. } => {
1217                let size = back::vector_size_str(size);
1218                write!(self.out, "int{size}(-1), int{size}(1)")?;
1219            }
1220            _ => {
1221                write!(self.out, "-1, 1")?;
1222            }
1223        }
1224        write!(self.out, ", (")?;
1225        self.put_expression(arg, context, true)?;
1226        write!(self.out, " > 0)), 0, (")?;
1227        self.put_expression(arg, context, true)?;
1228        write!(self.out, " == 0))")?;
1229        Ok(())
1230    }
1231
1232    fn put_const_expression(
1233        &mut self,
1234        expr_handle: Handle<crate::Expression>,
1235        module: &crate::Module,
1236        mod_info: &valid::ModuleInfo,
1237    ) -> BackendResult {
1238        self.put_possibly_const_expression(
1239            expr_handle,
1240            &module.global_expressions,
1241            module,
1242            mod_info,
1243            &(module, mod_info),
1244            |&(_, mod_info), expr| &mod_info[expr],
1245            |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info),
1246        )
1247    }
1248
1249    #[allow(clippy::too_many_arguments)]
1250    fn put_possibly_const_expression<C, I, E>(
1251        &mut self,
1252        expr_handle: Handle<crate::Expression>,
1253        expressions: &crate::Arena<crate::Expression>,
1254        module: &crate::Module,
1255        mod_info: &valid::ModuleInfo,
1256        ctx: &C,
1257        get_expr_ty: I,
1258        put_expression: E,
1259    ) -> BackendResult
1260    where
1261        I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1262        E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1263    {
1264        match expressions[expr_handle] {
1265            crate::Expression::Literal(literal) => match literal {
1266                crate::Literal::F64(_) => {
1267                    return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1268                }
1269                crate::Literal::F32(value) => {
1270                    if value.is_infinite() {
1271                        let sign = if value.is_sign_negative() { "-" } else { "" };
1272                        write!(self.out, "{sign}INFINITY")?;
1273                    } else if value.is_nan() {
1274                        write!(self.out, "NAN")?;
1275                    } else {
1276                        let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1277                        write!(self.out, "{value}{suffix}")?;
1278                    }
1279                }
1280                crate::Literal::U32(value) => {
1281                    write!(self.out, "{value}u")?;
1282                }
1283                crate::Literal::I32(value) => {
1284                    write!(self.out, "{value}")?;
1285                }
1286                crate::Literal::U64(value) => {
1287                    write!(self.out, "{value}uL")?;
1288                }
1289                crate::Literal::I64(value) => {
1290                    write!(self.out, "{value}L")?;
1291                }
1292                crate::Literal::Bool(value) => {
1293                    write!(self.out, "{value}")?;
1294                }
1295                crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1296                    return Err(Error::GenericValidation(
1297                        "Unsupported abstract literal".into(),
1298                    ));
1299                }
1300            },
1301            crate::Expression::Constant(handle) => {
1302                let constant = &module.constants[handle];
1303                if constant.name.is_some() {
1304                    write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1305                } else {
1306                    self.put_const_expression(constant.init, module, mod_info)?;
1307                }
1308            }
1309            crate::Expression::ZeroValue(ty) => {
1310                let ty_name = TypeContext {
1311                    handle: ty,
1312                    gctx: module.to_ctx(),
1313                    names: &self.names,
1314                    access: crate::StorageAccess::empty(),
1315                    binding: None,
1316                    first_time: false,
1317                };
1318                write!(self.out, "{ty_name} {{}}")?;
1319            }
1320            crate::Expression::Compose { ty, ref components } => {
1321                let ty_name = TypeContext {
1322                    handle: ty,
1323                    gctx: module.to_ctx(),
1324                    names: &self.names,
1325                    access: crate::StorageAccess::empty(),
1326                    binding: None,
1327                    first_time: false,
1328                };
1329                write!(self.out, "{ty_name}")?;
1330                match module.types[ty].inner {
1331                    crate::TypeInner::Scalar(_)
1332                    | crate::TypeInner::Vector { .. }
1333                    | crate::TypeInner::Matrix { .. } => {
1334                        self.put_call_parameters_impl(
1335                            components.iter().copied(),
1336                            ctx,
1337                            put_expression,
1338                        )?;
1339                    }
1340                    crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
1341                        write!(self.out, " {{")?;
1342                        for (index, &component) in components.iter().enumerate() {
1343                            if index != 0 {
1344                                write!(self.out, ", ")?;
1345                            }
1346                            // insert padding initialization, if needed
1347                            if self.struct_member_pads.contains(&(ty, index as u32)) {
1348                                write!(self.out, "{{}}, ")?;
1349                            }
1350                            put_expression(self, ctx, component)?;
1351                        }
1352                        write!(self.out, "}}")?;
1353                    }
1354                    _ => return Err(Error::UnsupportedCompose(ty)),
1355                }
1356            }
1357            crate::Expression::Splat { size, value } => {
1358                let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1359                    crate::TypeInner::Scalar(scalar) => scalar,
1360                    ref ty => {
1361                        return Err(Error::GenericValidation(format!(
1362                            "Expected splat value type must be a scalar, got {ty:?}",
1363                        )))
1364                    }
1365                };
1366                put_numeric_type(&mut self.out, scalar, &[size])?;
1367                write!(self.out, "(")?;
1368                put_expression(self, ctx, value)?;
1369                write!(self.out, ")")?;
1370            }
1371            _ => unreachable!(),
1372        }
1373
1374        Ok(())
1375    }
1376
1377    /// Emit code for the expression `expr_handle`.
1378    ///
1379    /// The `is_scoped` argument is true if the surrounding operators have the
1380    /// precedence of the comma operator, or lower. So, for example:
1381    ///
1382    /// - Pass `true` for `is_scoped` when writing function arguments, an
1383    ///   expression statement, an initializer expression, or anything already
1384    ///   wrapped in parenthesis.
1385    ///
1386    /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
1387    ///   almost anything else.
1388    fn put_expression(
1389        &mut self,
1390        expr_handle: Handle<crate::Expression>,
1391        context: &ExpressionContext,
1392        is_scoped: bool,
1393    ) -> BackendResult {
1394        // Add to the set in order to track the stack size.
1395        #[cfg(test)]
1396        #[allow(trivial_casts)]
1397        self.put_expression_stack_pointers
1398            .insert(&expr_handle as *const _ as *const ());
1399
1400        if let Some(name) = self.named_expressions.get(&expr_handle) {
1401            write!(self.out, "{name}")?;
1402            return Ok(());
1403        }
1404
1405        let expression = &context.function.expressions[expr_handle];
1406        log::trace!("expression {:?} = {:?}", expr_handle, expression);
1407        match *expression {
1408            crate::Expression::Literal(_)
1409            | crate::Expression::Constant(_)
1410            | crate::Expression::ZeroValue(_)
1411            | crate::Expression::Compose { .. }
1412            | crate::Expression::Splat { .. } => {
1413                self.put_possibly_const_expression(
1414                    expr_handle,
1415                    &context.function.expressions,
1416                    context.module,
1417                    context.mod_info,
1418                    context,
1419                    |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1420                    |writer, context, expr| writer.put_expression(expr, context, true),
1421                )?;
1422            }
1423            crate::Expression::Override(_) => return Err(Error::Override),
1424            crate::Expression::Access { base, .. }
1425            | crate::Expression::AccessIndex { base, .. } => {
1426                // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
1427                // Since `put_bounds_checks` and `put_access_chain` handle an entire
1428                // access chain at a time, recursing back through `put_expression` only
1429                // for index expressions and the base object, we will never see intermediate
1430                // `Access` or `AccessIndex` expressions here.
1431                let policy = context.choose_bounds_check_policy(base);
1432                if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1433                    && self.put_bounds_checks(
1434                        expr_handle,
1435                        context,
1436                        back::Level(0),
1437                        if is_scoped { "" } else { "(" },
1438                    )?
1439                {
1440                    write!(self.out, " ? ")?;
1441                    self.put_access_chain(expr_handle, policy, context)?;
1442                    write!(self.out, " : DefaultConstructible()")?;
1443
1444                    if !is_scoped {
1445                        write!(self.out, ")")?;
1446                    }
1447                } else {
1448                    self.put_access_chain(expr_handle, policy, context)?;
1449                }
1450            }
1451            crate::Expression::Swizzle {
1452                size,
1453                vector,
1454                pattern,
1455            } => {
1456                self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?;
1457                write!(self.out, ".")?;
1458                for &sc in pattern[..size as usize].iter() {
1459                    write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
1460                }
1461            }
1462            crate::Expression::FunctionArgument(index) => {
1463                let name_key = match context.origin {
1464                    FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
1465                    FunctionOrigin::EntryPoint(ep_index) => {
1466                        NameKey::EntryPointArgument(ep_index, index)
1467                    }
1468                };
1469                let name = &self.names[&name_key];
1470                write!(self.out, "{name}")?;
1471            }
1472            crate::Expression::GlobalVariable(handle) => {
1473                let name = &self.names[&NameKey::GlobalVariable(handle)];
1474                write!(self.out, "{name}")?;
1475            }
1476            crate::Expression::LocalVariable(handle) => {
1477                let name_key = match context.origin {
1478                    FunctionOrigin::Handle(fun_handle) => {
1479                        NameKey::FunctionLocal(fun_handle, handle)
1480                    }
1481                    FunctionOrigin::EntryPoint(ep_index) => {
1482                        NameKey::EntryPointLocal(ep_index, handle)
1483                    }
1484                };
1485                let name = &self.names[&name_key];
1486                write!(self.out, "{name}")?;
1487            }
1488            crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
1489            crate::Expression::ImageSample {
1490                image,
1491                sampler,
1492                gather,
1493                coordinate,
1494                array_index,
1495                offset,
1496                level,
1497                depth_ref,
1498            } => {
1499                let main_op = match gather {
1500                    Some(_) => "gather",
1501                    None => "sample",
1502                };
1503                let comparison_op = match depth_ref {
1504                    Some(_) => "_compare",
1505                    None => "",
1506                };
1507                self.put_expression(image, context, false)?;
1508                write!(self.out, ".{main_op}{comparison_op}(")?;
1509                self.put_expression(sampler, context, true)?;
1510                write!(self.out, ", ")?;
1511                self.put_expression(coordinate, context, true)?;
1512                if let Some(expr) = array_index {
1513                    write!(self.out, ", ")?;
1514                    self.put_expression(expr, context, true)?;
1515                }
1516                if let Some(dref) = depth_ref {
1517                    write!(self.out, ", ")?;
1518                    self.put_expression(dref, context, true)?;
1519                }
1520
1521                self.put_image_sample_level(image, level, context)?;
1522
1523                if let Some(offset) = offset {
1524                    write!(self.out, ", ")?;
1525                    self.put_const_expression(offset, context.module, context.mod_info)?;
1526                }
1527
1528                match gather {
1529                    None | Some(crate::SwizzleComponent::X) => {}
1530                    Some(component) => {
1531                        let is_cube_map = match *context.resolve_type(image) {
1532                            crate::TypeInner::Image {
1533                                dim: crate::ImageDimension::Cube,
1534                                ..
1535                            } => true,
1536                            _ => false,
1537                        };
1538                        // Offset always comes before the gather, except
1539                        // in cube maps where it's not applicable
1540                        if offset.is_none() && !is_cube_map {
1541                            write!(self.out, ", {NAMESPACE}::int2(0)")?;
1542                        }
1543                        let letter = back::COMPONENTS[component as usize];
1544                        write!(self.out, ", {NAMESPACE}::component::{letter}")?;
1545                    }
1546                }
1547                write!(self.out, ")")?;
1548            }
1549            crate::Expression::ImageLoad {
1550                image,
1551                coordinate,
1552                array_index,
1553                sample,
1554                level,
1555            } => {
1556                let address = TexelAddress {
1557                    coordinate,
1558                    array_index,
1559                    sample,
1560                    level: level.map(LevelOfDetail::Direct),
1561                };
1562                self.put_image_load(expr_handle, image, address, context)?;
1563            }
1564            //Note: for all the queries, the signed integers are expected,
1565            // so a conversion is needed.
1566            crate::Expression::ImageQuery { image, query } => match query {
1567                crate::ImageQuery::Size { level } => {
1568                    self.put_image_size_query(
1569                        image,
1570                        level.map(LevelOfDetail::Direct),
1571                        crate::ScalarKind::Uint,
1572                        context,
1573                    )?;
1574                }
1575                crate::ImageQuery::NumLevels => {
1576                    self.put_expression(image, context, false)?;
1577                    write!(self.out, ".get_num_mip_levels()")?;
1578                }
1579                crate::ImageQuery::NumLayers => {
1580                    self.put_expression(image, context, false)?;
1581                    write!(self.out, ".get_array_size()")?;
1582                }
1583                crate::ImageQuery::NumSamples => {
1584                    self.put_expression(image, context, false)?;
1585                    write!(self.out, ".get_num_samples()")?;
1586                }
1587            },
1588            crate::Expression::Unary { op, expr } => {
1589                let op_str = match op {
1590                    crate::UnaryOperator::Negate => "-",
1591                    crate::UnaryOperator::LogicalNot => "!",
1592                    crate::UnaryOperator::BitwiseNot => "~",
1593                };
1594                write!(self.out, "{op_str}(")?;
1595                self.put_expression(expr, context, false)?;
1596                write!(self.out, ")")?;
1597            }
1598            crate::Expression::Binary { op, left, right } => {
1599                let op_str = crate::back::binary_operation_str(op);
1600                let kind = context
1601                    .resolve_type(left)
1602                    .scalar_kind()
1603                    .ok_or(Error::UnsupportedBinaryOp(op))?;
1604
1605                // TODO: handle undefined behavior of BinaryOperator::Modulo
1606                //
1607                // sint:
1608                // if right == 0 return 0
1609                // if left == min(type_of(left)) && right == -1 return 0
1610                // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
1611                //
1612                // uint:
1613                // if right == 0 return 0
1614                //
1615                // float:
1616                // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
1617
1618                if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
1619                    write!(self.out, "{NAMESPACE}::fmod(")?;
1620                    self.put_expression(left, context, true)?;
1621                    write!(self.out, ", ")?;
1622                    self.put_expression(right, context, true)?;
1623                    write!(self.out, ")")?;
1624                } else {
1625                    if !is_scoped {
1626                        write!(self.out, "(")?;
1627                    }
1628
1629                    // Cast packed vector if necessary
1630                    // Packed vector - matrix multiplications are not supported in MSL
1631                    if op == crate::BinaryOperator::Multiply
1632                        && matches!(
1633                            context.resolve_type(right),
1634                            &crate::TypeInner::Matrix { .. }
1635                        )
1636                    {
1637                        self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?;
1638                    } else {
1639                        self.put_expression(left, context, false)?;
1640                    }
1641
1642                    write!(self.out, " {op_str} ")?;
1643
1644                    // See comment above
1645                    if op == crate::BinaryOperator::Multiply
1646                        && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
1647                    {
1648                        self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?;
1649                    } else {
1650                        self.put_expression(right, context, false)?;
1651                    }
1652
1653                    if !is_scoped {
1654                        write!(self.out, ")")?;
1655                    }
1656                }
1657            }
1658            crate::Expression::Select {
1659                condition,
1660                accept,
1661                reject,
1662            } => match *context.resolve_type(condition) {
1663                crate::TypeInner::Scalar(crate::Scalar {
1664                    kind: crate::ScalarKind::Bool,
1665                    ..
1666                }) => {
1667                    if !is_scoped {
1668                        write!(self.out, "(")?;
1669                    }
1670                    self.put_expression(condition, context, false)?;
1671                    write!(self.out, " ? ")?;
1672                    self.put_expression(accept, context, false)?;
1673                    write!(self.out, " : ")?;
1674                    self.put_expression(reject, context, false)?;
1675                    if !is_scoped {
1676                        write!(self.out, ")")?;
1677                    }
1678                }
1679                crate::TypeInner::Vector {
1680                    scalar:
1681                        crate::Scalar {
1682                            kind: crate::ScalarKind::Bool,
1683                            ..
1684                        },
1685                    ..
1686                } => {
1687                    write!(self.out, "{NAMESPACE}::select(")?;
1688                    self.put_expression(reject, context, true)?;
1689                    write!(self.out, ", ")?;
1690                    self.put_expression(accept, context, true)?;
1691                    write!(self.out, ", ")?;
1692                    self.put_expression(condition, context, true)?;
1693                    write!(self.out, ")")?;
1694                }
1695                ref ty => {
1696                    return Err(Error::GenericValidation(format!(
1697                        "Expected select condition to be a non-bool type, got {ty:?}",
1698                    )))
1699                }
1700            },
1701            crate::Expression::Derivative { axis, expr, .. } => {
1702                use crate::DerivativeAxis as Axis;
1703                let op = match axis {
1704                    Axis::X => "dfdx",
1705                    Axis::Y => "dfdy",
1706                    Axis::Width => "fwidth",
1707                };
1708                write!(self.out, "{NAMESPACE}::{op}")?;
1709                self.put_call_parameters(iter::once(expr), context)?;
1710            }
1711            crate::Expression::Relational { fun, argument } => {
1712                let op = match fun {
1713                    crate::RelationalFunction::Any => "any",
1714                    crate::RelationalFunction::All => "all",
1715                    crate::RelationalFunction::IsNan => "isnan",
1716                    crate::RelationalFunction::IsInf => "isinf",
1717                };
1718                write!(self.out, "{NAMESPACE}::{op}")?;
1719                self.put_call_parameters(iter::once(argument), context)?;
1720            }
1721            crate::Expression::Math {
1722                fun,
1723                arg,
1724                arg1,
1725                arg2,
1726                arg3,
1727            } => {
1728                use crate::MathFunction as Mf;
1729
1730                let arg_type = context.resolve_type(arg);
1731                let scalar_argument = match arg_type {
1732                    &crate::TypeInner::Scalar(_) => true,
1733                    _ => false,
1734                };
1735
1736                let fun_name = match fun {
1737                    // comparison
1738                    Mf::Abs => "abs",
1739                    Mf::Min => "min",
1740                    Mf::Max => "max",
1741                    Mf::Clamp => "clamp",
1742                    Mf::Saturate => "saturate",
1743                    // trigonometry
1744                    Mf::Cos => "cos",
1745                    Mf::Cosh => "cosh",
1746                    Mf::Sin => "sin",
1747                    Mf::Sinh => "sinh",
1748                    Mf::Tan => "tan",
1749                    Mf::Tanh => "tanh",
1750                    Mf::Acos => "acos",
1751                    Mf::Asin => "asin",
1752                    Mf::Atan => "atan",
1753                    Mf::Atan2 => "atan2",
1754                    Mf::Asinh => "asinh",
1755                    Mf::Acosh => "acosh",
1756                    Mf::Atanh => "atanh",
1757                    Mf::Radians => "",
1758                    Mf::Degrees => "",
1759                    // decomposition
1760                    Mf::Ceil => "ceil",
1761                    Mf::Floor => "floor",
1762                    Mf::Round => "rint",
1763                    Mf::Fract => "fract",
1764                    Mf::Trunc => "trunc",
1765                    Mf::Modf => MODF_FUNCTION,
1766                    Mf::Frexp => FREXP_FUNCTION,
1767                    Mf::Ldexp => "ldexp",
1768                    // exponent
1769                    Mf::Exp => "exp",
1770                    Mf::Exp2 => "exp2",
1771                    Mf::Log => "log",
1772                    Mf::Log2 => "log2",
1773                    Mf::Pow => "pow",
1774                    // geometry
1775                    Mf::Dot => match *context.resolve_type(arg) {
1776                        crate::TypeInner::Vector {
1777                            scalar:
1778                                crate::Scalar {
1779                                    kind: crate::ScalarKind::Float,
1780                                    ..
1781                                },
1782                            ..
1783                        } => "dot",
1784                        crate::TypeInner::Vector { size, .. } => {
1785                            return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
1786                        }
1787                        _ => unreachable!(
1788                            "Correct TypeInner for dot product should be already validated"
1789                        ),
1790                    },
1791                    Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
1792                    Mf::Cross => "cross",
1793                    Mf::Distance => "distance",
1794                    Mf::Length if scalar_argument => "abs",
1795                    Mf::Length => "length",
1796                    Mf::Normalize => "normalize",
1797                    Mf::FaceForward => "faceforward",
1798                    Mf::Reflect => "reflect",
1799                    Mf::Refract => "refract",
1800                    // computational
1801                    Mf::Sign => match arg_type.scalar_kind() {
1802                        Some(crate::ScalarKind::Sint) => {
1803                            return self.put_isign(arg, context);
1804                        }
1805                        _ => "sign",
1806                    },
1807                    Mf::Fma => "fma",
1808                    Mf::Mix => "mix",
1809                    Mf::Step => "step",
1810                    Mf::SmoothStep => "smoothstep",
1811                    Mf::Sqrt => "sqrt",
1812                    Mf::InverseSqrt => "rsqrt",
1813                    Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
1814                    Mf::Transpose => "transpose",
1815                    Mf::Determinant => "determinant",
1816                    // bits
1817                    Mf::CountTrailingZeros => "ctz",
1818                    Mf::CountLeadingZeros => "clz",
1819                    Mf::CountOneBits => "popcount",
1820                    Mf::ReverseBits => "reverse_bits",
1821                    Mf::ExtractBits => "",
1822                    Mf::InsertBits => "",
1823                    Mf::FindLsb => "",
1824                    Mf::FindMsb => "",
1825                    // data packing
1826                    Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
1827                    Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
1828                    Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
1829                    Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
1830                    Mf::Pack2x16float => "",
1831                    // data unpacking
1832                    Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
1833                    Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
1834                    Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
1835                    Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
1836                    Mf::Unpack2x16float => "",
1837                };
1838
1839                match fun {
1840                    Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
1841                        // reverse_bits is listed as requiring MSL 2.1 but that
1842                        // is a copy/paste error. Looking at previous snapshots
1843                        // on web.archive.org it's present in MSL 1.2.
1844                        //
1845                        // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
1846                        // also talks about MSL 1.2 adding "New integer
1847                        // functions to extract, insert, and reverse bits, as
1848                        // described in Integer Functions."
1849                        if context.lang_version < (1, 2) {
1850                            return Err(Error::UnsupportedFunction(fun_name.to_string()));
1851                        }
1852                    }
1853                    _ => {}
1854                }
1855
1856                if fun == Mf::Distance && scalar_argument {
1857                    write!(self.out, "{NAMESPACE}::abs(")?;
1858                    self.put_expression(arg, context, false)?;
1859                    write!(self.out, " - ")?;
1860                    self.put_expression(arg1.unwrap(), context, false)?;
1861                    write!(self.out, ")")?;
1862                } else if fun == Mf::FindLsb {
1863                    let scalar = context.resolve_type(arg).scalar().unwrap();
1864                    let constant = scalar.width * 8 + 1;
1865
1866                    write!(self.out, "((({NAMESPACE}::ctz(")?;
1867                    self.put_expression(arg, context, true)?;
1868                    write!(self.out, ") + 1) % {constant}) - 1)")?;
1869                } else if fun == Mf::FindMsb {
1870                    let inner = context.resolve_type(arg);
1871                    let scalar = inner.scalar().unwrap();
1872                    let constant = scalar.width * 8 - 1;
1873
1874                    write!(
1875                        self.out,
1876                        "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
1877                    )?;
1878
1879                    if scalar.kind == crate::ScalarKind::Sint {
1880                        write!(self.out, "{NAMESPACE}::select(")?;
1881                        self.put_expression(arg, context, true)?;
1882                        write!(self.out, ", ~")?;
1883                        self.put_expression(arg, context, true)?;
1884                        write!(self.out, ", ")?;
1885                        self.put_expression(arg, context, true)?;
1886                        write!(self.out, " < 0)")?;
1887                    } else {
1888                        self.put_expression(arg, context, true)?;
1889                    }
1890
1891                    write!(self.out, "), ")?;
1892
1893                    // or metal will complain that select is ambiguous
1894                    match *inner {
1895                        crate::TypeInner::Vector { size, scalar } => {
1896                            let size = back::vector_size_str(size);
1897                            let name = scalar.to_msl_name();
1898                            write!(self.out, "{name}{size}")?;
1899                        }
1900                        crate::TypeInner::Scalar(scalar) => {
1901                            let name = scalar.to_msl_name();
1902                            write!(self.out, "{name}")?;
1903                        }
1904                        _ => (),
1905                    }
1906
1907                    write!(self.out, "(-1), ")?;
1908                    self.put_expression(arg, context, true)?;
1909                    write!(self.out, " == 0 || ")?;
1910                    self.put_expression(arg, context, true)?;
1911                    write!(self.out, " == -1)")?;
1912                } else if fun == Mf::Unpack2x16float {
1913                    write!(self.out, "float2(as_type<half2>(")?;
1914                    self.put_expression(arg, context, false)?;
1915                    write!(self.out, "))")?;
1916                } else if fun == Mf::Pack2x16float {
1917                    write!(self.out, "as_type<uint>(half2(")?;
1918                    self.put_expression(arg, context, false)?;
1919                    write!(self.out, "))")?;
1920                } else if fun == Mf::ExtractBits {
1921                    // The behavior of ExtractBits is undefined when offset + count > bit_width. We need
1922                    // to first sanitize the offset and count first. If we don't do this, Apple chips
1923                    // will return out-of-spec values if the extracted range is not within the bit width.
1924                    //
1925                    // This encodes the exact formula specified by the wgsl spec, without temporary values:
1926                    // https://gpuweb.github.io/gpuweb/wgsl/#extractBits-unsigned-builtin
1927                    //
1928                    // w = sizeof(x) * 8
1929                    // o = min(offset, w)
1930                    // tmp = w - o
1931                    // c = min(count, tmp)
1932                    //
1933                    // bitfieldExtract(x, o, c)
1934                    //
1935                    // extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
1936
1937                    let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
1938
1939                    write!(self.out, "{NAMESPACE}::extract_bits(")?;
1940                    self.put_expression(arg, context, true)?;
1941                    write!(self.out, ", {NAMESPACE}::min(")?;
1942                    self.put_expression(arg1.unwrap(), context, true)?;
1943                    write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
1944                    self.put_expression(arg2.unwrap(), context, true)?;
1945                    write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
1946                    self.put_expression(arg1.unwrap(), context, true)?;
1947                    write!(self.out, ", {scalar_bits}u)))")?;
1948                } else if fun == Mf::InsertBits {
1949                    // The behavior of InsertBits has the same issue as ExtractBits.
1950                    //
1951                    // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
1952
1953                    let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
1954
1955                    write!(self.out, "{NAMESPACE}::insert_bits(")?;
1956                    self.put_expression(arg, context, true)?;
1957                    write!(self.out, ", ")?;
1958                    self.put_expression(arg1.unwrap(), context, true)?;
1959                    write!(self.out, ", {NAMESPACE}::min(")?;
1960                    self.put_expression(arg2.unwrap(), context, true)?;
1961                    write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
1962                    self.put_expression(arg3.unwrap(), context, true)?;
1963                    write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
1964                    self.put_expression(arg2.unwrap(), context, true)?;
1965                    write!(self.out, ", {scalar_bits}u)))")?;
1966                } else if fun == Mf::Radians {
1967                    write!(self.out, "((")?;
1968                    self.put_expression(arg, context, false)?;
1969                    write!(self.out, ") * 0.017453292519943295474)")?;
1970                } else if fun == Mf::Degrees {
1971                    write!(self.out, "((")?;
1972                    self.put_expression(arg, context, false)?;
1973                    write!(self.out, ") * 57.295779513082322865)")?;
1974                } else if fun == Mf::Modf || fun == Mf::Frexp {
1975                    write!(self.out, "{fun_name}")?;
1976                    self.put_call_parameters(iter::once(arg), context)?;
1977                } else {
1978                    write!(self.out, "{NAMESPACE}::{fun_name}")?;
1979                    self.put_call_parameters(
1980                        iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
1981                        context,
1982                    )?;
1983                }
1984            }
1985            crate::Expression::As {
1986                expr,
1987                kind,
1988                convert,
1989            } => match *context.resolve_type(expr) {
1990                crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
1991                    let target_scalar = crate::Scalar {
1992                        kind,
1993                        width: convert.unwrap_or(src.width),
1994                    };
1995                    let op = match convert {
1996                        Some(_) => "static_cast",
1997                        None => "as_type",
1998                    };
1999                    write!(self.out, "{op}<")?;
2000                    match *context.resolve_type(expr) {
2001                        crate::TypeInner::Vector { size, .. } => {
2002                            put_numeric_type(&mut self.out, target_scalar, &[size])?
2003                        }
2004                        _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2005                    };
2006                    write!(self.out, ">(")?;
2007                    self.put_expression(expr, context, true)?;
2008                    write!(self.out, ")")?;
2009                }
2010                crate::TypeInner::Matrix {
2011                    columns,
2012                    rows,
2013                    scalar,
2014                } => {
2015                    let target_scalar = crate::Scalar {
2016                        kind,
2017                        width: convert.unwrap_or(scalar.width),
2018                    };
2019                    put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2020                    write!(self.out, "(")?;
2021                    self.put_expression(expr, context, true)?;
2022                    write!(self.out, ")")?;
2023                }
2024                ref ty => {
2025                    return Err(Error::GenericValidation(format!(
2026                        "Unsupported type for As: {ty:?}"
2027                    )))
2028                }
2029            },
2030            // has to be a named expression
2031            crate::Expression::CallResult(_)
2032            | crate::Expression::AtomicResult { .. }
2033            | crate::Expression::WorkGroupUniformLoadResult { .. }
2034            | crate::Expression::SubgroupBallotResult
2035            | crate::Expression::SubgroupOperationResult { .. }
2036            | crate::Expression::RayQueryProceedResult => {
2037                unreachable!()
2038            }
2039            crate::Expression::ArrayLength(expr) => {
2040                // Find the global to which the array belongs.
2041                let global = match context.function.expressions[expr] {
2042                    crate::Expression::AccessIndex { base, .. } => {
2043                        match context.function.expressions[base] {
2044                            crate::Expression::GlobalVariable(handle) => handle,
2045                            ref ex => {
2046                                return Err(Error::GenericValidation(format!(
2047                                    "Expected global variable in AccessIndex, got {ex:?}"
2048                                )))
2049                            }
2050                        }
2051                    }
2052                    crate::Expression::GlobalVariable(handle) => handle,
2053                    ref ex => {
2054                        return Err(Error::GenericValidation(format!(
2055                            "Unexpected expression in ArrayLength, got {ex:?}"
2056                        )))
2057                    }
2058                };
2059
2060                if !is_scoped {
2061                    write!(self.out, "(")?;
2062                }
2063                write!(self.out, "1 + ")?;
2064                self.put_dynamic_array_max_index(global, context)?;
2065                if !is_scoped {
2066                    write!(self.out, ")")?;
2067                }
2068            }
2069            crate::Expression::RayQueryGetIntersection { query, committed } => {
2070                if context.lang_version < (2, 4) {
2071                    return Err(Error::UnsupportedRayTracing);
2072                }
2073
2074                if !committed {
2075                    unimplemented!()
2076                }
2077                let ty = context.module.special_types.ray_intersection.unwrap();
2078                let type_name = &self.names[&NameKey::Type(ty)];
2079                write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2080                self.put_expression(query, context, true)?;
2081                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2082                let fields = [
2083                    "distance",
2084                    "user_instance_id", // req Metal 2.4
2085                    "instance_id",
2086                    "", // SBT offset
2087                    "geometry_id",
2088                    "primitive_id",
2089                    "triangle_barycentric_coord",
2090                    "triangle_front_facing",
2091                    "",                          // padding
2092                    "object_to_world_transform", // req Metal 2.4
2093                    "world_to_object_transform", // req Metal 2.4
2094                ];
2095                for field in fields {
2096                    write!(self.out, ", ")?;
2097                    if field.is_empty() {
2098                        write!(self.out, "{{}}")?;
2099                    } else {
2100                        self.put_expression(query, context, true)?;
2101                        write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2102                    }
2103                }
2104                write!(self.out, "}}")?;
2105            }
2106        }
2107        Ok(())
2108    }
2109
2110    /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
2111    fn put_wrapped_expression_for_packed_vec3_access(
2112        &mut self,
2113        expr_handle: Handle<crate::Expression>,
2114        context: &ExpressionContext,
2115        is_scoped: bool,
2116    ) -> BackendResult {
2117        if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2118            write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2119            self.put_expression(expr_handle, context, is_scoped)?;
2120            write!(self.out, ")")?;
2121        } else {
2122            self.put_expression(expr_handle, context, is_scoped)?;
2123        }
2124        Ok(())
2125    }
2126
2127    /// Write a `GuardedIndex` as a Metal expression.
2128    fn put_index(
2129        &mut self,
2130        index: index::GuardedIndex,
2131        context: &ExpressionContext,
2132        is_scoped: bool,
2133    ) -> BackendResult {
2134        match index {
2135            index::GuardedIndex::Expression(expr) => {
2136                self.put_expression(expr, context, is_scoped)?
2137            }
2138            index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
2139        }
2140        Ok(())
2141    }
2142
2143    /// Emit an index bounds check condition for `chain`, if required.
2144    ///
2145    /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
2146    /// operating either on a pointer to a value, or on a value directly. If we cannot
2147    /// statically determine that all indexing operations in `chain` are within
2148    /// bounds, then write a conditional expression to check them dynamically,
2149    /// and return true. All accesses in the chain are checked by the generated
2150    /// expression.
2151    ///
2152    /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
2153    ///
2154    /// The text written is of the form:
2155    ///
2156    /// ```ignore
2157    /// {level}{prefix}uint(i) < 4 && uint(j) < 10
2158    /// ```
2159    ///
2160    /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
2161    /// statements, presumably these arguments start an indented `if` statement; for
2162    /// [`Load`] expressions, the caller is probably building up a ternary `?:`
2163    /// expression. In either case, what is written is not a complete syntactic structure
2164    /// in its own right, and the caller will have to finish it off if we return `true`.
2165    ///
2166    /// If no expression is written, return false.
2167    ///
2168    /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
2169    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
2170    /// [`Store`]: crate::Statement::Store
2171    /// [`Load`]: crate::Expression::Load
2172    #[allow(unused_variables)]
2173    fn put_bounds_checks(
2174        &mut self,
2175        mut chain: Handle<crate::Expression>,
2176        context: &ExpressionContext,
2177        level: back::Level,
2178        prefix: &'static str,
2179    ) -> Result<bool, Error> {
2180        let mut check_written = false;
2181
2182        // Iterate over the access chain, handling each expression.
2183        loop {
2184            // Produce a `GuardedIndex`, so we can shared code between the
2185            // `Access` and `AccessIndex` cases.
2186            let (base, guarded_index) = match context.function.expressions[chain] {
2187                crate::Expression::Access { base, index } => {
2188                    (base, Some(index::GuardedIndex::Expression(index)))
2189                }
2190                crate::Expression::AccessIndex { base, index } => {
2191                    // Don't try to check indices into structs. Validation already took
2192                    // care of them, and index::needs_guard doesn't handle that case.
2193                    let mut base_inner = context.resolve_type(base);
2194                    if let crate::TypeInner::Pointer { base, .. } = *base_inner {
2195                        base_inner = &context.module.types[base].inner;
2196                    }
2197                    match *base_inner {
2198                        crate::TypeInner::Struct { .. } => (base, None),
2199                        _ => (base, Some(index::GuardedIndex::Known(index))),
2200                    }
2201                }
2202                _ => break,
2203            };
2204
2205            if let Some(index) = guarded_index {
2206                if let Some(length) = context.access_needs_check(base, index) {
2207                    if check_written {
2208                        write!(self.out, " && ")?;
2209                    } else {
2210                        write!(self.out, "{level}{prefix}")?;
2211                        check_written = true;
2212                    }
2213
2214                    // Check that the index falls within bounds. Do this with a single
2215                    // comparison, by casting the index to `uint` first, so that negative
2216                    // indices become large positive values.
2217                    write!(self.out, "uint(")?;
2218                    self.put_index(index, context, true)?;
2219                    self.out.write_str(") < ")?;
2220                    match length {
2221                        index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
2222                        index::IndexableLength::Dynamic => {
2223                            let global =
2224                                context.function.originating_global(base).ok_or_else(|| {
2225                                    Error::GenericValidation(
2226                                        "Could not find originating global".into(),
2227                                    )
2228                                })?;
2229                            write!(self.out, "1 + ")?;
2230                            self.put_dynamic_array_max_index(global, context)?
2231                        }
2232                    }
2233                }
2234            }
2235
2236            chain = base
2237        }
2238
2239        Ok(check_written)
2240    }
2241
2242    /// Write the access chain `chain`.
2243    ///
2244    /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
2245    /// operating either on a pointer to a value, or on a value directly.
2246    ///
2247    /// Generate bounds checks code only if `policy` is [`Restrict`]. The
2248    /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
2249    /// that must be handled in the caller.
2250    ///
2251    /// Handle the entire chain, recursing back into `put_expression` only for index
2252    /// expressions and the base expression that originates the pointer or composite value
2253    /// being accessed. This allows `put_expression` to assume that any `Access` or
2254    /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
2255    /// `ReadZeroSkipWrite` checks.
2256    ///
2257    /// [`Access`]: crate::Expression::Access
2258    /// [`AccessIndex`]: crate::Expression::AccessIndex
2259    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
2260    /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
2261    fn put_access_chain(
2262        &mut self,
2263        chain: Handle<crate::Expression>,
2264        policy: index::BoundsCheckPolicy,
2265        context: &ExpressionContext,
2266    ) -> BackendResult {
2267        match context.function.expressions[chain] {
2268            crate::Expression::Access { base, index } => {
2269                let mut base_ty = context.resolve_type(base);
2270
2271                // Look through any pointers to see what we're really indexing.
2272                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
2273                    base_ty = &context.module.types[base].inner;
2274                }
2275
2276                self.put_subscripted_access_chain(
2277                    base,
2278                    base_ty,
2279                    index::GuardedIndex::Expression(index),
2280                    policy,
2281                    context,
2282                )?;
2283            }
2284            crate::Expression::AccessIndex { base, index } => {
2285                let base_resolution = &context.info[base].ty;
2286                let mut base_ty = base_resolution.inner_with(&context.module.types);
2287                let mut base_ty_handle = base_resolution.handle();
2288
2289                // Look through any pointers to see what we're really indexing.
2290                if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
2291                    base_ty = &context.module.types[base].inner;
2292                    base_ty_handle = Some(base);
2293                }
2294
2295                // Handle structs and anything else that can use `.x` syntax here, so
2296                // `put_subscripted_access_chain` won't have to handle the absurd case of
2297                // indexing a struct with an expression.
2298                match *base_ty {
2299                    crate::TypeInner::Struct { .. } => {
2300                        let base_ty = base_ty_handle.unwrap();
2301                        self.put_access_chain(base, policy, context)?;
2302                        let name = &self.names[&NameKey::StructMember(base_ty, index)];
2303                        write!(self.out, ".{name}")?;
2304                    }
2305                    crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
2306                        self.put_access_chain(base, policy, context)?;
2307                        // Prior to Metal v2.1 component access for packed vectors wasn't available
2308                        // however array indexing is
2309                        if context.get_packed_vec_kind(base).is_some() {
2310                            write!(self.out, "[{index}]")?;
2311                        } else {
2312                            write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
2313                        }
2314                    }
2315                    _ => {
2316                        self.put_subscripted_access_chain(
2317                            base,
2318                            base_ty,
2319                            index::GuardedIndex::Known(index),
2320                            policy,
2321                            context,
2322                        )?;
2323                    }
2324                }
2325            }
2326            _ => self.put_expression(chain, context, false)?,
2327        }
2328
2329        Ok(())
2330    }
2331
2332    /// Write a `[]`-style access of `base` by `index`.
2333    ///
2334    /// If `policy` is [`Restrict`], then generate code as needed to force all index
2335    /// values within bounds.
2336    ///
2337    /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
2338    /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
2339    /// removed. Our callers often already have this handy.
2340    ///
2341    /// This only emits `[]` expressions; it doesn't handle struct member accesses or
2342    /// referencing vector components by name.
2343    ///
2344    /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
2345    /// [`Array`]: crate::TypeInner::Array
2346    /// [`Vector`]: crate::TypeInner::Vector
2347    /// [`Pointer`]: crate::TypeInner::Pointer
2348    fn put_subscripted_access_chain(
2349        &mut self,
2350        base: Handle<crate::Expression>,
2351        base_ty: &crate::TypeInner,
2352        index: index::GuardedIndex,
2353        policy: index::BoundsCheckPolicy,
2354        context: &ExpressionContext,
2355    ) -> BackendResult {
2356        let accessing_wrapped_array = match *base_ty {
2357            crate::TypeInner::Array {
2358                size: crate::ArraySize::Constant(_),
2359                ..
2360            } => true,
2361            _ => false,
2362        };
2363
2364        self.put_access_chain(base, policy, context)?;
2365        if accessing_wrapped_array {
2366            write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
2367        }
2368        write!(self.out, "[")?;
2369
2370        // Decide whether this index needs to be clamped to fall within range.
2371        let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
2372            context.access_needs_check(base, index)
2373        } else {
2374            None
2375        };
2376        if let Some(limit) = restriction_needed {
2377            write!(self.out, "{NAMESPACE}::min(unsigned(")?;
2378            self.put_index(index, context, true)?;
2379            write!(self.out, "), ")?;
2380            match limit {
2381                index::IndexableLength::Known(limit) => {
2382                    write!(self.out, "{}u", limit - 1)?;
2383                }
2384                index::IndexableLength::Dynamic => {
2385                    let global = context.function.originating_global(base).ok_or_else(|| {
2386                        Error::GenericValidation("Could not find originating global".into())
2387                    })?;
2388                    self.put_dynamic_array_max_index(global, context)?;
2389                }
2390            }
2391            write!(self.out, ")")?;
2392        } else {
2393            self.put_index(index, context, true)?;
2394        }
2395
2396        write!(self.out, "]")?;
2397
2398        Ok(())
2399    }
2400
2401    fn put_load(
2402        &mut self,
2403        pointer: Handle<crate::Expression>,
2404        context: &ExpressionContext,
2405        is_scoped: bool,
2406    ) -> BackendResult {
2407        // Since access chains never cross between address spaces, we can just
2408        // check the index bounds check policy once at the top.
2409        let policy = context.choose_bounds_check_policy(pointer);
2410        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2411            && self.put_bounds_checks(
2412                pointer,
2413                context,
2414                back::Level(0),
2415                if is_scoped { "" } else { "(" },
2416            )?
2417        {
2418            write!(self.out, " ? ")?;
2419            self.put_unchecked_load(pointer, policy, context)?;
2420            write!(self.out, " : DefaultConstructible()")?;
2421
2422            if !is_scoped {
2423                write!(self.out, ")")?;
2424            }
2425        } else {
2426            self.put_unchecked_load(pointer, policy, context)?;
2427        }
2428
2429        Ok(())
2430    }
2431
2432    fn put_unchecked_load(
2433        &mut self,
2434        pointer: Handle<crate::Expression>,
2435        policy: index::BoundsCheckPolicy,
2436        context: &ExpressionContext,
2437    ) -> BackendResult {
2438        let is_atomic_pointer = context
2439            .resolve_type(pointer)
2440            .is_atomic_pointer(&context.module.types);
2441
2442        if is_atomic_pointer {
2443            write!(
2444                self.out,
2445                "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
2446            )?;
2447            self.put_access_chain(pointer, policy, context)?;
2448            write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
2449        } else {
2450            // We don't do any dereferencing with `*` here as pointer arguments to functions
2451            // are done by `&` references and not `*` pointers. These do not need to be
2452            // dereferenced.
2453            self.put_access_chain(pointer, policy, context)?;
2454        }
2455
2456        Ok(())
2457    }
2458
2459    fn put_return_value(
2460        &mut self,
2461        level: back::Level,
2462        expr_handle: Handle<crate::Expression>,
2463        result_struct: Option<&str>,
2464        context: &ExpressionContext,
2465    ) -> BackendResult {
2466        match result_struct {
2467            Some(struct_name) => {
2468                let mut has_point_size = false;
2469                let result_ty = context.function.result.as_ref().unwrap().ty;
2470                match context.module.types[result_ty].inner {
2471                    crate::TypeInner::Struct { ref members, .. } => {
2472                        let tmp = "_tmp";
2473                        write!(self.out, "{level}const auto {tmp} = ")?;
2474                        self.put_expression(expr_handle, context, true)?;
2475                        writeln!(self.out, ";")?;
2476                        write!(self.out, "{level}return {struct_name} {{")?;
2477
2478                        let mut is_first = true;
2479
2480                        for (index, member) in members.iter().enumerate() {
2481                            if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
2482                                member.binding
2483                            {
2484                                has_point_size = true;
2485                                if !context.pipeline_options.allow_and_force_point_size {
2486                                    continue;
2487                                }
2488                            }
2489
2490                            let comma = if is_first { "" } else { "," };
2491                            is_first = false;
2492                            let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
2493                            // HACK: we are forcefully deduplicating the expression here
2494                            // to convert from a wrapped struct to a raw array, e.g.
2495                            // `float gl_ClipDistance1 [[clip_distance]] [1];`.
2496                            if let crate::TypeInner::Array {
2497                                size: crate::ArraySize::Constant(size),
2498                                ..
2499                            } = context.module.types[member.ty].inner
2500                            {
2501                                write!(self.out, "{comma} {{")?;
2502                                for j in 0..size.get() {
2503                                    if j != 0 {
2504                                        write!(self.out, ",")?;
2505                                    }
2506                                    write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
2507                                }
2508                                write!(self.out, "}}")?;
2509                            } else {
2510                                write!(self.out, "{comma} {tmp}.{name}")?;
2511                            }
2512                        }
2513                    }
2514                    _ => {
2515                        write!(self.out, "{level}return {struct_name} {{ ")?;
2516                        self.put_expression(expr_handle, context, true)?;
2517                    }
2518                }
2519
2520                if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
2521                    let stage = context.module.entry_points[ep_index as usize].stage;
2522                    if context.pipeline_options.allow_and_force_point_size
2523                        && stage == crate::ShaderStage::Vertex
2524                        && !has_point_size
2525                    {
2526                        // point size was injected and comes last
2527                        write!(self.out, ", 1.0")?;
2528                    }
2529                }
2530                write!(self.out, " }}")?;
2531            }
2532            None => {
2533                write!(self.out, "{level}return ")?;
2534                self.put_expression(expr_handle, context, true)?;
2535            }
2536        }
2537        writeln!(self.out, ";")?;
2538        Ok(())
2539    }
2540
2541    /// Helper method used to find which expressions of a given function require baking
2542    ///
2543    /// # Notes
2544    /// This function overwrites the contents of `self.need_bake_expressions`
2545    fn update_expressions_to_bake(
2546        &mut self,
2547        func: &crate::Function,
2548        info: &valid::FunctionInfo,
2549        context: &ExpressionContext,
2550    ) {
2551        use crate::Expression;
2552        self.need_bake_expressions.clear();
2553
2554        for (expr_handle, expr) in func.expressions.iter() {
2555            // Expressions whose reference count is above the
2556            // threshold should always be stored in temporaries.
2557            let expr_info = &info[expr_handle];
2558            let min_ref_count = func.expressions[expr_handle].bake_ref_count();
2559            if min_ref_count <= expr_info.ref_count {
2560                self.need_bake_expressions.insert(expr_handle);
2561            } else {
2562                match expr_info.ty {
2563                    // force ray desc to be baked: it's used multiple times internally
2564                    TypeResolution::Handle(h)
2565                        if Some(h) == context.module.special_types.ray_desc =>
2566                    {
2567                        self.need_bake_expressions.insert(expr_handle);
2568                    }
2569                    _ => {}
2570                }
2571            }
2572
2573            if let Expression::Math {
2574                fun,
2575                arg,
2576                arg1,
2577                arg2,
2578                ..
2579            } = *expr
2580            {
2581                match fun {
2582                    crate::MathFunction::Dot => {
2583                        // WGSL's `dot` function works on any `vecN` type, but Metal's only
2584                        // works on floating-point vectors, so we emit inline code for
2585                        // integer vector `dot` calls. But that code uses each argument `N`
2586                        // times, once for each component (see `put_dot_product`), so to
2587                        // avoid duplicated evaluation, we must bake integer operands.
2588
2589                        // check what kind of product this is depending
2590                        // on the resolve type of the Dot function itself
2591                        let inner = context.resolve_type(expr_handle);
2592                        if let crate::TypeInner::Scalar(scalar) = *inner {
2593                            match scalar.kind {
2594                                crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
2595                                    self.need_bake_expressions.insert(arg);
2596                                    self.need_bake_expressions.insert(arg1.unwrap());
2597                                }
2598                                _ => {}
2599                            }
2600                        }
2601                    }
2602                    crate::MathFunction::FindMsb => {
2603                        self.need_bake_expressions.insert(arg);
2604                    }
2605                    crate::MathFunction::ExtractBits => {
2606                        // Only argument 1 is re-used.
2607                        self.need_bake_expressions.insert(arg1.unwrap());
2608                    }
2609                    crate::MathFunction::InsertBits => {
2610                        // Only argument 2 is re-used.
2611                        self.need_bake_expressions.insert(arg2.unwrap());
2612                    }
2613                    crate::MathFunction::Sign => {
2614                        // WGSL's `sign` function works also on signed ints, but Metal's only
2615                        // works on floating points, so we emit inline code for integer `sign`
2616                        // calls. But that code uses each argument 2 times (see `put_isign`),
2617                        // so to avoid duplicated evaluation, we must bake the argument.
2618                        let inner = context.resolve_type(expr_handle);
2619                        if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
2620                            self.need_bake_expressions.insert(arg);
2621                        }
2622                    }
2623                    _ => {}
2624                }
2625            }
2626        }
2627    }
2628
2629    fn start_baking_expression(
2630        &mut self,
2631        handle: Handle<crate::Expression>,
2632        context: &ExpressionContext,
2633        name: &str,
2634    ) -> BackendResult {
2635        match context.info[handle].ty {
2636            TypeResolution::Handle(ty_handle) => {
2637                let ty_name = TypeContext {
2638                    handle: ty_handle,
2639                    gctx: context.module.to_ctx(),
2640                    names: &self.names,
2641                    access: crate::StorageAccess::empty(),
2642                    binding: None,
2643                    first_time: false,
2644                };
2645                write!(self.out, "{ty_name}")?;
2646            }
2647            TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
2648                put_numeric_type(&mut self.out, scalar, &[])?;
2649            }
2650            TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
2651                put_numeric_type(&mut self.out, scalar, &[size])?;
2652            }
2653            TypeResolution::Value(crate::TypeInner::Matrix {
2654                columns,
2655                rows,
2656                scalar,
2657            }) => {
2658                put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
2659            }
2660            TypeResolution::Value(ref other) => {
2661                log::warn!("Type {:?} isn't a known local", other); //TEMP!
2662                return Err(Error::FeatureNotImplemented("weird local type".to_string()));
2663            }
2664        }
2665
2666        //TODO: figure out the naming scheme that wouldn't collide with user names.
2667        write!(self.out, " {name} = ")?;
2668
2669        Ok(())
2670    }
2671
2672    /// Cache a clamped level of detail value, if necessary.
2673    ///
2674    /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
2675    /// properly clamped level of detail value both in the access itself, and
2676    /// for fetching the size of the requested MIP level, needed to clamp the
2677    /// coordinates. To avoid recomputing this clamped level of detail, we cache
2678    /// it in a temporary variable, as part of the [`Emit`] statement covering
2679    /// the [`ImageLoad`] expression.
2680    ///
2681    /// [`ImageLoad`]: crate::Expression::ImageLoad
2682    /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
2683    /// [`Emit`]: crate::Statement::Emit
2684    fn put_cache_restricted_level(
2685        &mut self,
2686        load: Handle<crate::Expression>,
2687        image: Handle<crate::Expression>,
2688        mip_level: Option<Handle<crate::Expression>>,
2689        indent: back::Level,
2690        context: &StatementContext,
2691    ) -> BackendResult {
2692        // Does this image access actually require (or even permit) a
2693        // level-of-detail, and does the policy require us to restrict it?
2694        let level_of_detail = match mip_level {
2695            Some(level) => level,
2696            None => return Ok(()),
2697        };
2698
2699        if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
2700            || !context.expression.image_needs_lod(image)
2701        {
2702            return Ok(());
2703        }
2704
2705        write!(
2706            self.out,
2707            "{}uint {}{} = ",
2708            indent,
2709            CLAMPED_LOD_LOAD_PREFIX,
2710            load.index(),
2711        )?;
2712        self.put_restricted_scalar_image_index(
2713            image,
2714            level_of_detail,
2715            "get_num_mip_levels",
2716            &context.expression,
2717        )?;
2718        writeln!(self.out, ";")?;
2719
2720        Ok(())
2721    }
2722
2723    fn put_block(
2724        &mut self,
2725        level: back::Level,
2726        statements: &[crate::Statement],
2727        context: &StatementContext,
2728    ) -> BackendResult {
2729        // Add to the set in order to track the stack size.
2730        #[cfg(test)]
2731        #[allow(trivial_casts)]
2732        self.put_block_stack_pointers
2733            .insert(&level as *const _ as *const ());
2734
2735        for statement in statements {
2736            log::trace!("statement[{}] {:?}", level.0, statement);
2737            match *statement {
2738                crate::Statement::Emit(ref range) => {
2739                    for handle in range.clone() {
2740                        // `ImageLoad` expressions covered by the `Restrict` bounds check policy
2741                        // may need to cache a clamped version of their level-of-detail argument.
2742                        if let crate::Expression::ImageLoad {
2743                            image,
2744                            level: mip_level,
2745                            ..
2746                        } = context.expression.function.expressions[handle]
2747                        {
2748                            self.put_cache_restricted_level(
2749                                handle, image, mip_level, level, context,
2750                            )?;
2751                        }
2752
2753                        let ptr_class = context.expression.resolve_type(handle).pointer_space();
2754                        let expr_name = if ptr_class.is_some() {
2755                            None // don't bake pointer expressions (just yet)
2756                        } else if let Some(name) =
2757                            context.expression.function.named_expressions.get(&handle)
2758                        {
2759                            // The `crate::Function::named_expressions` table holds
2760                            // expressions that should be saved in temporaries once they
2761                            // are `Emit`ted. We only add them to `self.named_expressions`
2762                            // when we reach the `Emit` that covers them, so that we don't
2763                            // try to use their names before we've actually initialized
2764                            // the temporary that holds them.
2765                            //
2766                            // Don't assume the names in `named_expressions` are unique,
2767                            // or even valid. Use the `Namer`.
2768                            Some(self.namer.call(name))
2769                        } else {
2770                            // If this expression is an index that we're going to first compare
2771                            // against a limit, and then actually use as an index, then we may
2772                            // want to cache it in a temporary, to avoid evaluating it twice.
2773                            let bake =
2774                                if context.expression.guarded_indices.contains(handle.index()) {
2775                                    true
2776                                } else {
2777                                    self.need_bake_expressions.contains(&handle)
2778                                };
2779
2780                            if bake {
2781                                Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
2782                            } else {
2783                                None
2784                            }
2785                        };
2786
2787                        if let Some(name) = expr_name {
2788                            write!(self.out, "{level}")?;
2789                            self.start_baking_expression(handle, &context.expression, &name)?;
2790                            self.put_expression(handle, &context.expression, true)?;
2791                            self.named_expressions.insert(handle, name);
2792                            writeln!(self.out, ";")?;
2793                        }
2794                    }
2795                }
2796                crate::Statement::Block(ref block) => {
2797                    if !block.is_empty() {
2798                        writeln!(self.out, "{level}{{")?;
2799                        self.put_block(level.next(), block, context)?;
2800                        writeln!(self.out, "{level}}}")?;
2801                    }
2802                }
2803                crate::Statement::If {
2804                    condition,
2805                    ref accept,
2806                    ref reject,
2807                } => {
2808                    write!(self.out, "{level}if (")?;
2809                    self.put_expression(condition, &context.expression, true)?;
2810                    writeln!(self.out, ") {{")?;
2811                    self.put_block(level.next(), accept, context)?;
2812                    if !reject.is_empty() {
2813                        writeln!(self.out, "{level}}} else {{")?;
2814                        self.put_block(level.next(), reject, context)?;
2815                    }
2816                    writeln!(self.out, "{level}}}")?;
2817                }
2818                crate::Statement::Switch {
2819                    selector,
2820                    ref cases,
2821                } => {
2822                    write!(self.out, "{level}switch(")?;
2823                    self.put_expression(selector, &context.expression, true)?;
2824                    writeln!(self.out, ") {{")?;
2825                    let lcase = level.next();
2826                    for case in cases.iter() {
2827                        match case.value {
2828                            crate::SwitchValue::I32(value) => {
2829                                write!(self.out, "{lcase}case {value}:")?;
2830                            }
2831                            crate::SwitchValue::U32(value) => {
2832                                write!(self.out, "{lcase}case {value}u:")?;
2833                            }
2834                            crate::SwitchValue::Default => {
2835                                write!(self.out, "{lcase}default:")?;
2836                            }
2837                        }
2838
2839                        let write_block_braces = !(case.fall_through && case.body.is_empty());
2840                        if write_block_braces {
2841                            writeln!(self.out, " {{")?;
2842                        } else {
2843                            writeln!(self.out)?;
2844                        }
2845
2846                        self.put_block(lcase.next(), &case.body, context)?;
2847                        if !case.fall_through
2848                            && case.body.last().map_or(true, |s| !s.is_terminator())
2849                        {
2850                            writeln!(self.out, "{}break;", lcase.next())?;
2851                        }
2852
2853                        if write_block_braces {
2854                            writeln!(self.out, "{lcase}}}")?;
2855                        }
2856                    }
2857                    writeln!(self.out, "{level}}}")?;
2858                }
2859                crate::Statement::Loop {
2860                    ref body,
2861                    ref continuing,
2862                    break_if,
2863                } => {
2864                    if !continuing.is_empty() || break_if.is_some() {
2865                        let gate_name = self.namer.call("loop_init");
2866                        writeln!(self.out, "{level}bool {gate_name} = true;")?;
2867                        writeln!(self.out, "{level}while(true) {{")?;
2868                        let lif = level.next();
2869                        let lcontinuing = lif.next();
2870                        writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
2871                        self.put_block(lcontinuing, continuing, context)?;
2872                        if let Some(condition) = break_if {
2873                            write!(self.out, "{lcontinuing}if (")?;
2874                            self.put_expression(condition, &context.expression, true)?;
2875                            writeln!(self.out, ") {{")?;
2876                            writeln!(self.out, "{}break;", lcontinuing.next())?;
2877                            writeln!(self.out, "{lcontinuing}}}")?;
2878                        }
2879                        writeln!(self.out, "{lif}}}")?;
2880                        writeln!(self.out, "{lif}{gate_name} = false;")?;
2881                    } else {
2882                        writeln!(self.out, "{level}while(true) {{")?;
2883                    }
2884                    self.put_block(level.next(), body, context)?;
2885                    writeln!(self.out, "{level}}}")?;
2886                }
2887                crate::Statement::Break => {
2888                    writeln!(self.out, "{level}break;")?;
2889                }
2890                crate::Statement::Continue => {
2891                    writeln!(self.out, "{level}continue;")?;
2892                }
2893                crate::Statement::Return {
2894                    value: Some(expr_handle),
2895                } => {
2896                    self.put_return_value(
2897                        level,
2898                        expr_handle,
2899                        context.result_struct,
2900                        &context.expression,
2901                    )?;
2902                }
2903                crate::Statement::Return { value: None } => {
2904                    writeln!(self.out, "{level}return;")?;
2905                }
2906                crate::Statement::Kill => {
2907                    writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
2908                }
2909                crate::Statement::Barrier(flags) => {
2910                    self.write_barrier(flags, level)?;
2911                }
2912                crate::Statement::Store { pointer, value } => {
2913                    self.put_store(pointer, value, level, context)?
2914                }
2915                crate::Statement::ImageStore {
2916                    image,
2917                    coordinate,
2918                    array_index,
2919                    value,
2920                } => {
2921                    let address = TexelAddress {
2922                        coordinate,
2923                        array_index,
2924                        sample: None,
2925                        level: None,
2926                    };
2927                    self.put_image_store(level, image, &address, value, context)?
2928                }
2929                crate::Statement::Call {
2930                    function,
2931                    ref arguments,
2932                    result,
2933                } => {
2934                    write!(self.out, "{level}")?;
2935                    if let Some(expr) = result {
2936                        let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
2937                        self.start_baking_expression(expr, &context.expression, &name)?;
2938                        self.named_expressions.insert(expr, name);
2939                    }
2940                    let fun_name = &self.names[&NameKey::Function(function)];
2941                    write!(self.out, "{fun_name}(")?;
2942                    // first, write down the actual arguments
2943                    for (i, &handle) in arguments.iter().enumerate() {
2944                        if i != 0 {
2945                            write!(self.out, ", ")?;
2946                        }
2947                        self.put_expression(handle, &context.expression, true)?;
2948                    }
2949                    // follow-up with any global resources used
2950                    let mut separate = !arguments.is_empty();
2951                    let fun_info = &context.expression.mod_info[function];
2952                    let mut supports_array_length = false;
2953                    for (handle, var) in context.expression.module.global_variables.iter() {
2954                        if fun_info[handle].is_empty() {
2955                            continue;
2956                        }
2957                        if var.space.needs_pass_through() {
2958                            let name = &self.names[&NameKey::GlobalVariable(handle)];
2959                            if separate {
2960                                write!(self.out, ", ")?;
2961                            } else {
2962                                separate = true;
2963                            }
2964                            write!(self.out, "{name}")?;
2965                        }
2966                        supports_array_length |=
2967                            needs_array_length(var.ty, &context.expression.module.types);
2968                    }
2969                    if supports_array_length {
2970                        if separate {
2971                            write!(self.out, ", ")?;
2972                        }
2973                        write!(self.out, "_buffer_sizes")?;
2974                    }
2975
2976                    // done
2977                    writeln!(self.out, ");")?;
2978                }
2979                crate::Statement::Atomic {
2980                    pointer,
2981                    ref fun,
2982                    value,
2983                    result,
2984                } => {
2985                    write!(self.out, "{level}")?;
2986                    let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
2987                    self.start_baking_expression(result, &context.expression, &res_name)?;
2988                    self.named_expressions.insert(result, res_name);
2989                    let fun_str = fun.to_msl()?;
2990                    self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
2991                    // done
2992                    writeln!(self.out, ";")?;
2993                }
2994                crate::Statement::WorkGroupUniformLoad { pointer, result } => {
2995                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
2996
2997                    write!(self.out, "{level}")?;
2998                    let name = self.namer.call("");
2999                    self.start_baking_expression(result, &context.expression, &name)?;
3000                    self.put_load(pointer, &context.expression, true)?;
3001                    self.named_expressions.insert(result, name);
3002
3003                    writeln!(self.out, ";")?;
3004                    self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3005                }
3006                crate::Statement::RayQuery { query, ref fun } => {
3007                    if context.expression.lang_version < (2, 4) {
3008                        return Err(Error::UnsupportedRayTracing);
3009                    }
3010
3011                    match *fun {
3012                        crate::RayQueryFunction::Initialize {
3013                            acceleration_structure,
3014                            descriptor,
3015                        } => {
3016                            //TODO: how to deal with winding?
3017                            write!(self.out, "{level}")?;
3018                            self.put_expression(query, &context.expression, true)?;
3019                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
3020                            {
3021                                let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
3022                                let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
3023                                write!(self.out, "{level}")?;
3024                                self.put_expression(query, &context.expression, true)?;
3025                                write!(
3026                                    self.out,
3027                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
3028                                )?;
3029                                self.put_expression(descriptor, &context.expression, true)?;
3030                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
3031                                self.put_expression(descriptor, &context.expression, true)?;
3032                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
3033                                writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
3034                            }
3035                            {
3036                                let f_opaque = back::RayFlag::OPAQUE.bits();
3037                                let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
3038                                write!(self.out, "{level}")?;
3039                                self.put_expression(query, &context.expression, true)?;
3040                                write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
3041                                self.put_expression(descriptor, &context.expression, true)?;
3042                                write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
3043                                self.put_expression(descriptor, &context.expression, true)?;
3044                                write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
3045                                writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
3046                            }
3047                            {
3048                                let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
3049                                write!(self.out, "{level}")?;
3050                                self.put_expression(query, &context.expression, true)?;
3051                                write!(
3052                                    self.out,
3053                                    ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
3054                                )?;
3055                                self.put_expression(descriptor, &context.expression, true)?;
3056                                writeln!(self.out, ".flags & {flag}) != 0);")?;
3057                            }
3058
3059                            write!(self.out, "{level}")?;
3060                            self.put_expression(query, &context.expression, true)?;
3061                            write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
3062                            self.put_expression(query, &context.expression, true)?;
3063                            write!(
3064                                self.out,
3065                                ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
3066                            )?;
3067                            self.put_expression(descriptor, &context.expression, true)?;
3068                            write!(self.out, ".origin, ")?;
3069                            self.put_expression(descriptor, &context.expression, true)?;
3070                            write!(self.out, ".dir, ")?;
3071                            self.put_expression(descriptor, &context.expression, true)?;
3072                            write!(self.out, ".tmin, ")?;
3073                            self.put_expression(descriptor, &context.expression, true)?;
3074                            write!(self.out, ".tmax), ")?;
3075                            self.put_expression(acceleration_structure, &context.expression, true)?;
3076                            write!(self.out, ", ")?;
3077                            self.put_expression(descriptor, &context.expression, true)?;
3078                            write!(self.out, ".cull_mask);")?;
3079
3080                            write!(self.out, "{level}")?;
3081                            self.put_expression(query, &context.expression, true)?;
3082                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
3083                        }
3084                        crate::RayQueryFunction::Proceed { result } => {
3085                            write!(self.out, "{level}")?;
3086                            let name = format!("{}{}", back::BAKE_PREFIX, result.index());
3087                            self.start_baking_expression(result, &context.expression, &name)?;
3088                            self.named_expressions.insert(result, name);
3089                            self.put_expression(query, &context.expression, true)?;
3090                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
3091                            //TODO: actually proceed?
3092
3093                            write!(self.out, "{level}")?;
3094                            self.put_expression(query, &context.expression, true)?;
3095                            writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
3096                        }
3097                        crate::RayQueryFunction::Terminate => {
3098                            write!(self.out, "{level}")?;
3099                            self.put_expression(query, &context.expression, true)?;
3100                            writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?;
3101                        }
3102                    }
3103                }
3104                crate::Statement::SubgroupBallot { result, predicate } => {
3105                    write!(self.out, "{level}")?;
3106                    let name = self.namer.call("");
3107                    self.start_baking_expression(result, &context.expression, &name)?;
3108                    self.named_expressions.insert(result, name);
3109                    write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?;
3110                    if let Some(predicate) = predicate {
3111                        self.put_expression(predicate, &context.expression, true)?;
3112                    } else {
3113                        write!(self.out, "true")?;
3114                    }
3115                    writeln!(self.out, "), 0, 0, 0);")?;
3116                }
3117                crate::Statement::SubgroupCollectiveOperation {
3118                    op,
3119                    collective_op,
3120                    argument,
3121                    result,
3122                } => {
3123                    write!(self.out, "{level}")?;
3124                    let name = self.namer.call("");
3125                    self.start_baking_expression(result, &context.expression, &name)?;
3126                    self.named_expressions.insert(result, name);
3127                    match (collective_op, op) {
3128                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
3129                            write!(self.out, "{NAMESPACE}::simd_all(")?
3130                        }
3131                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
3132                            write!(self.out, "{NAMESPACE}::simd_any(")?
3133                        }
3134                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
3135                            write!(self.out, "{NAMESPACE}::simd_sum(")?
3136                        }
3137                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
3138                            write!(self.out, "{NAMESPACE}::simd_product(")?
3139                        }
3140                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
3141                            write!(self.out, "{NAMESPACE}::simd_max(")?
3142                        }
3143                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
3144                            write!(self.out, "{NAMESPACE}::simd_min(")?
3145                        }
3146                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
3147                            write!(self.out, "{NAMESPACE}::simd_and(")?
3148                        }
3149                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
3150                            write!(self.out, "{NAMESPACE}::simd_or(")?
3151                        }
3152                        (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
3153                            write!(self.out, "{NAMESPACE}::simd_xor(")?
3154                        }
3155                        (
3156                            crate::CollectiveOperation::ExclusiveScan,
3157                            crate::SubgroupOperation::Add,
3158                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
3159                        (
3160                            crate::CollectiveOperation::ExclusiveScan,
3161                            crate::SubgroupOperation::Mul,
3162                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
3163                        (
3164                            crate::CollectiveOperation::InclusiveScan,
3165                            crate::SubgroupOperation::Add,
3166                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
3167                        (
3168                            crate::CollectiveOperation::InclusiveScan,
3169                            crate::SubgroupOperation::Mul,
3170                        ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
3171                        _ => unimplemented!(),
3172                    }
3173                    self.put_expression(argument, &context.expression, true)?;
3174                    writeln!(self.out, ");")?;
3175                }
3176                crate::Statement::SubgroupGather {
3177                    mode,
3178                    argument,
3179                    result,
3180                } => {
3181                    write!(self.out, "{level}")?;
3182                    let name = self.namer.call("");
3183                    self.start_baking_expression(result, &context.expression, &name)?;
3184                    self.named_expressions.insert(result, name);
3185                    match mode {
3186                        crate::GatherMode::BroadcastFirst => {
3187                            write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
3188                        }
3189                        crate::GatherMode::Broadcast(_) => {
3190                            write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
3191                        }
3192                        crate::GatherMode::Shuffle(_) => {
3193                            write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
3194                        }
3195                        crate::GatherMode::ShuffleDown(_) => {
3196                            write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
3197                        }
3198                        crate::GatherMode::ShuffleUp(_) => {
3199                            write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
3200                        }
3201                        crate::GatherMode::ShuffleXor(_) => {
3202                            write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
3203                        }
3204                    }
3205                    self.put_expression(argument, &context.expression, true)?;
3206                    match mode {
3207                        crate::GatherMode::BroadcastFirst => {}
3208                        crate::GatherMode::Broadcast(index)
3209                        | crate::GatherMode::Shuffle(index)
3210                        | crate::GatherMode::ShuffleDown(index)
3211                        | crate::GatherMode::ShuffleUp(index)
3212                        | crate::GatherMode::ShuffleXor(index) => {
3213                            write!(self.out, ", ")?;
3214                            self.put_expression(index, &context.expression, true)?;
3215                        }
3216                    }
3217                    writeln!(self.out, ");")?;
3218                }
3219            }
3220        }
3221
3222        // un-emit expressions
3223        //TODO: take care of loop/continuing?
3224        for statement in statements {
3225            if let crate::Statement::Emit(ref range) = *statement {
3226                for handle in range.clone() {
3227                    self.named_expressions.shift_remove(&handle);
3228                }
3229            }
3230        }
3231        Ok(())
3232    }
3233
3234    fn put_store(
3235        &mut self,
3236        pointer: Handle<crate::Expression>,
3237        value: Handle<crate::Expression>,
3238        level: back::Level,
3239        context: &StatementContext,
3240    ) -> BackendResult {
3241        let policy = context.expression.choose_bounds_check_policy(pointer);
3242        if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3243            && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
3244        {
3245            writeln!(self.out, ") {{")?;
3246            self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
3247            writeln!(self.out, "{level}}}")?;
3248        } else {
3249            self.put_unchecked_store(pointer, value, policy, level, context)?;
3250        }
3251
3252        Ok(())
3253    }
3254
3255    fn put_unchecked_store(
3256        &mut self,
3257        pointer: Handle<crate::Expression>,
3258        value: Handle<crate::Expression>,
3259        policy: index::BoundsCheckPolicy,
3260        level: back::Level,
3261        context: &StatementContext,
3262    ) -> BackendResult {
3263        let is_atomic_pointer = context
3264            .expression
3265            .resolve_type(pointer)
3266            .is_atomic_pointer(&context.expression.module.types);
3267
3268        if is_atomic_pointer {
3269            write!(
3270                self.out,
3271                "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
3272            )?;
3273            self.put_access_chain(pointer, policy, &context.expression)?;
3274            write!(self.out, ", ")?;
3275            self.put_expression(value, &context.expression, true)?;
3276            writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
3277        } else {
3278            write!(self.out, "{level}")?;
3279            self.put_access_chain(pointer, policy, &context.expression)?;
3280            write!(self.out, " = ")?;
3281            self.put_expression(value, &context.expression, true)?;
3282            writeln!(self.out, ";")?;
3283        }
3284
3285        Ok(())
3286    }
3287
3288    pub fn write(
3289        &mut self,
3290        module: &crate::Module,
3291        info: &valid::ModuleInfo,
3292        options: &Options,
3293        pipeline_options: &PipelineOptions,
3294    ) -> Result<TranslationInfo, Error> {
3295        if !module.overrides.is_empty() {
3296            return Err(Error::Override);
3297        }
3298
3299        self.names.clear();
3300        self.namer.reset(
3301            module,
3302            super::keywords::RESERVED,
3303            &[],
3304            &[],
3305            &[CLAMPED_LOD_LOAD_PREFIX],
3306            &mut self.names,
3307        );
3308        self.struct_member_pads.clear();
3309
3310        writeln!(
3311            self.out,
3312            "// language: metal{}.{}",
3313            options.lang_version.0, options.lang_version.1
3314        )?;
3315        writeln!(self.out, "#include <metal_stdlib>")?;
3316        writeln!(self.out, "#include <simd/simd.h>")?;
3317        writeln!(self.out)?;
3318        // Work around Metal bug where `uint` is not available by default
3319        writeln!(self.out, "using {NAMESPACE}::uint;")?;
3320
3321        let mut uses_ray_query = false;
3322        for (_, ty) in module.types.iter() {
3323            match ty.inner {
3324                crate::TypeInner::AccelerationStructure => {
3325                    if options.lang_version < (2, 4) {
3326                        return Err(Error::UnsupportedRayTracing);
3327                    }
3328                }
3329                crate::TypeInner::RayQuery => {
3330                    if options.lang_version < (2, 4) {
3331                        return Err(Error::UnsupportedRayTracing);
3332                    }
3333                    uses_ray_query = true;
3334                }
3335                _ => (),
3336            }
3337        }
3338
3339        if module.special_types.ray_desc.is_some()
3340            || module.special_types.ray_intersection.is_some()
3341        {
3342            if options.lang_version < (2, 4) {
3343                return Err(Error::UnsupportedRayTracing);
3344            }
3345        }
3346
3347        if uses_ray_query {
3348            self.put_ray_query_type()?;
3349        }
3350
3351        if options
3352            .bounds_check_policies
3353            .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
3354        {
3355            self.put_default_constructible()?;
3356        }
3357        writeln!(self.out)?;
3358
3359        {
3360            let mut indices = vec![];
3361            for (handle, var) in module.global_variables.iter() {
3362                if needs_array_length(var.ty, &module.types) {
3363                    let idx = handle.index();
3364                    indices.push(idx);
3365                }
3366            }
3367
3368            if !indices.is_empty() {
3369                writeln!(self.out, "struct _mslBufferSizes {{")?;
3370
3371                for idx in indices {
3372                    writeln!(self.out, "{}uint size{};", back::INDENT, idx)?;
3373                }
3374
3375                writeln!(self.out, "}};")?;
3376                writeln!(self.out)?;
3377            }
3378        };
3379
3380        self.write_type_defs(module)?;
3381        self.write_global_constants(module, info)?;
3382        self.write_functions(module, info, options, pipeline_options)
3383    }
3384
3385    /// Write the definition for the `DefaultConstructible` class.
3386    ///
3387    /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
3388    /// produce 'zero' values for any type, including structs, arrays, and so
3389    /// on. We could do this by emitting default constructor applications, but
3390    /// that would entail printing the name of the type, which is more trouble
3391    /// than you'd think. Instead, we just construct this magic C++14 class that
3392    /// can be converted to any type that can be default constructed, using
3393    /// template parameter inference to detect which type is needed, so we don't
3394    /// have to figure out the name.
3395    ///
3396    /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
3397    fn put_default_constructible(&mut self) -> BackendResult {
3398        let tab = back::INDENT;
3399        writeln!(self.out, "struct DefaultConstructible {{")?;
3400        writeln!(self.out, "{tab}template<typename T>")?;
3401        writeln!(self.out, "{tab}operator T() && {{")?;
3402        writeln!(self.out, "{tab}{tab}return T {{}};")?;
3403        writeln!(self.out, "{tab}}}")?;
3404        writeln!(self.out, "}};")?;
3405        Ok(())
3406    }
3407
3408    fn put_ray_query_type(&mut self) -> BackendResult {
3409        let tab = back::INDENT;
3410        writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
3411        let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
3412        writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
3413        writeln!(
3414            self.out,
3415            "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
3416        )?;
3417        writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
3418        writeln!(self.out, "}};")?;
3419        writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
3420        let v_triangle = back::RayIntersectionType::Triangle as u32;
3421        let v_bbox = back::RayIntersectionType::BoundingBox as u32;
3422        writeln!(
3423            self.out,
3424            "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
3425        )?;
3426        writeln!(
3427            self.out,
3428            "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
3429        )?;
3430        writeln!(self.out, "}}")?;
3431        Ok(())
3432    }
3433
3434    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
3435        for (handle, ty) in module.types.iter() {
3436            if !ty.needs_alias() {
3437                continue;
3438            }
3439            let name = &self.names[&NameKey::Type(handle)];
3440            match ty.inner {
3441                // Naga IR can pass around arrays by value, but Metal, following
3442                // C++, performs an array-to-pointer conversion (C++ [conv.array])
3443                // on expressions of array type, so assigning the array by value
3444                // isn't possible. However, Metal *does* assign structs by
3445                // value. So in our Metal output, we wrap all array types in
3446                // synthetic struct types:
3447                //
3448                //     struct type1 {
3449                //         float inner[10]
3450                //     };
3451                //
3452                // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
3453                // any expression that actually wants access to the array.
3454                crate::TypeInner::Array {
3455                    base,
3456                    size,
3457                    stride: _,
3458                } => {
3459                    let base_name = TypeContext {
3460                        handle: base,
3461                        gctx: module.to_ctx(),
3462                        names: &self.names,
3463                        access: crate::StorageAccess::empty(),
3464                        binding: None,
3465                        first_time: false,
3466                    };
3467
3468                    match size {
3469                        crate::ArraySize::Constant(size) => {
3470                            writeln!(self.out, "struct {name} {{")?;
3471                            writeln!(
3472                                self.out,
3473                                "{}{} {}[{}];",
3474                                back::INDENT,
3475                                base_name,
3476                                WRAPPED_ARRAY_FIELD,
3477                                size
3478                            )?;
3479                            writeln!(self.out, "}};")?;
3480                        }
3481                        crate::ArraySize::Dynamic => {
3482                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
3483                        }
3484                    }
3485                }
3486                crate::TypeInner::Struct {
3487                    ref members, span, ..
3488                } => {
3489                    writeln!(self.out, "struct {name} {{")?;
3490                    let mut last_offset = 0;
3491                    for (index, member) in members.iter().enumerate() {
3492                        if member.offset > last_offset {
3493                            self.struct_member_pads.insert((handle, index as u32));
3494                            let pad = member.offset - last_offset;
3495                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
3496                        }
3497                        let ty_inner = &module.types[member.ty].inner;
3498                        last_offset = member.offset + ty_inner.size(module.to_ctx());
3499
3500                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
3501
3502                        // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
3503                        match should_pack_struct_member(members, span, index, module) {
3504                            Some(scalar) => {
3505                                writeln!(
3506                                    self.out,
3507                                    "{}{}::packed_{}3 {};",
3508                                    back::INDENT,
3509                                    NAMESPACE,
3510                                    scalar.to_msl_name(),
3511                                    member_name
3512                                )?;
3513                            }
3514                            None => {
3515                                let base_name = TypeContext {
3516                                    handle: member.ty,
3517                                    gctx: module.to_ctx(),
3518                                    names: &self.names,
3519                                    access: crate::StorageAccess::empty(),
3520                                    binding: None,
3521                                    first_time: false,
3522                                };
3523                                writeln!(
3524                                    self.out,
3525                                    "{}{} {};",
3526                                    back::INDENT,
3527                                    base_name,
3528                                    member_name
3529                                )?;
3530
3531                                // for 3-component vectors, add one component
3532                                if let crate::TypeInner::Vector {
3533                                    size: crate::VectorSize::Tri,
3534                                    scalar,
3535                                } = *ty_inner
3536                                {
3537                                    last_offset += scalar.width as u32;
3538                                }
3539                            }
3540                        }
3541                    }
3542                    writeln!(self.out, "}};")?;
3543                }
3544                _ => {
3545                    let ty_name = TypeContext {
3546                        handle,
3547                        gctx: module.to_ctx(),
3548                        names: &self.names,
3549                        access: crate::StorageAccess::empty(),
3550                        binding: None,
3551                        first_time: true,
3552                    };
3553                    writeln!(self.out, "typedef {ty_name} {name};")?;
3554                }
3555            }
3556        }
3557
3558        // Write functions to create special types.
3559        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
3560            match type_key {
3561                &crate::PredeclaredType::ModfResult { size, width }
3562                | &crate::PredeclaredType::FrexpResult { size, width } => {
3563                    let arg_type_name_owner;
3564                    let arg_type_name = if let Some(size) = size {
3565                        arg_type_name_owner = format!(
3566                            "{NAMESPACE}::{}{}",
3567                            if width == 8 { "double" } else { "float" },
3568                            size as u8
3569                        );
3570                        &arg_type_name_owner
3571                    } else if width == 8 {
3572                        "double"
3573                    } else {
3574                        "float"
3575                    };
3576
3577                    let other_type_name_owner;
3578                    let (defined_func_name, called_func_name, other_type_name) =
3579                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
3580                            (MODF_FUNCTION, "modf", arg_type_name)
3581                        } else {
3582                            let other_type_name = if let Some(size) = size {
3583                                other_type_name_owner = format!("int{}", size as u8);
3584                                &other_type_name_owner
3585                            } else {
3586                                "int"
3587                            };
3588                            (FREXP_FUNCTION, "frexp", other_type_name)
3589                        };
3590
3591                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
3592
3593                    writeln!(self.out)?;
3594                    writeln!(
3595                        self.out,
3596                        "{} {defined_func_name}({arg_type_name} arg) {{
3597    {other_type_name} other;
3598    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
3599    return {}{{ fract, other }};
3600}}",
3601                        struct_name, struct_name
3602                    )?;
3603                }
3604                &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
3605            }
3606        }
3607
3608        Ok(())
3609    }
3610
3611    /// Writes all named constants
3612    fn write_global_constants(
3613        &mut self,
3614        module: &crate::Module,
3615        mod_info: &valid::ModuleInfo,
3616    ) -> BackendResult {
3617        let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
3618
3619        for (handle, constant) in constants {
3620            let ty_name = TypeContext {
3621                handle: constant.ty,
3622                gctx: module.to_ctx(),
3623                names: &self.names,
3624                access: crate::StorageAccess::empty(),
3625                binding: None,
3626                first_time: false,
3627            };
3628            let name = &self.names[&NameKey::Constant(handle)];
3629            write!(self.out, "constant {ty_name} {name} = ")?;
3630            self.put_const_expression(constant.init, module, mod_info)?;
3631            writeln!(self.out, ";")?;
3632        }
3633
3634        Ok(())
3635    }
3636
3637    fn put_inline_sampler_properties(
3638        &mut self,
3639        level: back::Level,
3640        sampler: &sm::InlineSampler,
3641    ) -> BackendResult {
3642        for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
3643            writeln!(
3644                self.out,
3645                "{}{}::{}_address::{},",
3646                level,
3647                NAMESPACE,
3648                letter,
3649                address.as_str(),
3650            )?;
3651        }
3652        writeln!(
3653            self.out,
3654            "{}{}::mag_filter::{},",
3655            level,
3656            NAMESPACE,
3657            sampler.mag_filter.as_str(),
3658        )?;
3659        writeln!(
3660            self.out,
3661            "{}{}::min_filter::{},",
3662            level,
3663            NAMESPACE,
3664            sampler.min_filter.as_str(),
3665        )?;
3666        if let Some(filter) = sampler.mip_filter {
3667            writeln!(
3668                self.out,
3669                "{}{}::mip_filter::{},",
3670                level,
3671                NAMESPACE,
3672                filter.as_str(),
3673            )?;
3674        }
3675        // avoid setting it on platforms that don't support it
3676        if sampler.border_color != sm::BorderColor::TransparentBlack {
3677            writeln!(
3678                self.out,
3679                "{}{}::border_color::{},",
3680                level,
3681                NAMESPACE,
3682                sampler.border_color.as_str(),
3683            )?;
3684        }
3685        //TODO: I'm not able to feed this in a way that MSL likes:
3686        //>error: use of undeclared identifier 'lod_clamp'
3687        //>error: no member named 'max_anisotropy' in namespace 'metal'
3688        if false {
3689            if let Some(ref lod) = sampler.lod_clamp {
3690                writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
3691            }
3692            if let Some(aniso) = sampler.max_anisotropy {
3693                writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
3694            }
3695        }
3696        if sampler.compare_func != sm::CompareFunc::Never {
3697            writeln!(
3698                self.out,
3699                "{}{}::compare_func::{},",
3700                level,
3701                NAMESPACE,
3702                sampler.compare_func.as_str(),
3703            )?;
3704        }
3705        writeln!(
3706            self.out,
3707            "{}{}::coord::{}",
3708            level,
3709            NAMESPACE,
3710            sampler.coord.as_str()
3711        )?;
3712        Ok(())
3713    }
3714
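    // Writes the module's functions and entry points. Returns a
    // `TranslationInfo` whose `entry_point_names` holds one result per entry
    // point: the mapped name, or the error that caused it to be skipped.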
3716    fn write_functions(
3717        &mut self,
3718        module: &crate::Module,
3719        mod_info: &valid::ModuleInfo,
3720        options: &Options,
3721        pipeline_options: &PipelineOptions,
3722    ) -> Result<TranslationInfo, Error> {
3723        let mut pass_through_globals = Vec::new();
3724        for (fun_handle, fun) in module.functions.iter() {
3725            log::trace!(
3726                "function {:?}, handle {:?}",
3727                fun.name.as_deref().unwrap_or("(anonymous)"),
3728                fun_handle
3729            );
3730
3731            let fun_info = &mod_info[fun_handle];
3732            pass_through_globals.clear();
3733            let mut supports_array_length = false;
3734            for (handle, var) in module.global_variables.iter() {
3735                if !fun_info[handle].is_empty() {
3736                    if var.space.needs_pass_through() {
3737                        pass_through_globals.push(handle);
3738                    }
3739                    supports_array_length |= needs_array_length(var.ty, &module.types);
3740                }
3741            }
3742
3743            writeln!(self.out)?;
3744            let fun_name = &self.names[&NameKey::Function(fun_handle)];
3745            match fun.result {
3746                Some(ref result) => {
3747                    let ty_name = TypeContext {
3748                        handle: result.ty,
3749                        gctx: module.to_ctx(),
3750                        names: &self.names,
3751                        access: crate::StorageAccess::empty(),
3752                        binding: None,
3753                        first_time: false,
3754                    };
3755                    write!(self.out, "{ty_name}")?;
3756                }
3757                None => {
3758                    write!(self.out, "void")?;
3759                }
3760            }
3761            writeln!(self.out, " {fun_name}(")?;
3762
3763            for (index, arg) in fun.arguments.iter().enumerate() {
3764                let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
3765                let param_type_name = TypeContext {
3766                    handle: arg.ty,
3767                    gctx: module.to_ctx(),
3768                    names: &self.names,
3769                    access: crate::StorageAccess::empty(),
3770                    binding: None,
3771                    first_time: false,
3772                };
3773                let separator = separate(
3774                    !pass_through_globals.is_empty()
3775                        || index + 1 != fun.arguments.len()
3776                        || supports_array_length,
3777                );
3778                writeln!(
3779                    self.out,
3780                    "{}{} {}{}",
3781                    back::INDENT,
3782                    param_type_name,
3783                    name,
3784                    separator
3785                )?;
3786            }
3787            for (index, &handle) in pass_through_globals.iter().enumerate() {
3788                let tyvar = TypedGlobalVariable {
3789                    module,
3790                    names: &self.names,
3791                    handle,
3792                    usage: fun_info[handle],
3793                    binding: None,
3794                    reference: true,
3795                };
3796                let separator =
3797                    separate(index + 1 != pass_through_globals.len() || supports_array_length);
3798                write!(self.out, "{}", back::INDENT)?;
3799                tyvar.try_fmt(&mut self.out)?;
3800                writeln!(self.out, "{separator}")?;
3801            }
3802
3803            if supports_array_length {
3804                writeln!(
3805                    self.out,
3806                    "{}constant _mslBufferSizes& _buffer_sizes",
3807                    back::INDENT
3808                )?;
3809            }
3810
3811            writeln!(self.out, ") {{")?;
3812
3813            let guarded_indices =
3814                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
3815
3816            let context = StatementContext {
3817                expression: ExpressionContext {
3818                    function: fun,
3819                    origin: FunctionOrigin::Handle(fun_handle),
3820                    info: fun_info,
3821                    lang_version: options.lang_version,
3822                    policies: options.bounds_check_policies,
3823                    guarded_indices,
3824                    module,
3825                    mod_info,
3826                    pipeline_options,
3827                },
3828                result_struct: None,
3829            };
3830
3831            for (local_handle, local) in fun.local_variables.iter() {
3832                let ty_name = TypeContext {
3833                    handle: local.ty,
3834                    gctx: module.to_ctx(),
3835                    names: &self.names,
3836                    access: crate::StorageAccess::empty(),
3837                    binding: None,
3838                    first_time: false,
3839                };
3840                let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
3841                write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?;
3842                match local.init {
3843                    Some(value) => {
3844                        write!(self.out, " = ")?;
3845                        self.put_expression(value, &context.expression, true)?;
3846                    }
3847                    None => {
3848                        write!(self.out, " = {{}}")?;
3849                    }
3850                };
3851                writeln!(self.out, ";")?;
3852            }
3853
3854            self.update_expressions_to_bake(fun, fun_info, &context.expression);
3855            self.put_block(back::Level(1), &fun.body, &context)?;
3856            writeln!(self.out, "}}")?;
3857            self.named_expressions.clear();
3858        }
3859
3860        let mut info = TranslationInfo {
3861            entry_point_names: Vec::with_capacity(module.entry_points.len()),
3862        };
3863        for (ep_index, ep) in module.entry_points.iter().enumerate() {
3864            let fun = &ep.function;
3865            let fun_info = mod_info.get_entry_point(ep_index);
3866            let mut ep_error = None;
3867
3868            log::trace!(
3869                "entry point {:?}, index {:?}",
3870                fun.name.as_deref().unwrap_or("(anonymous)"),
3871                ep_index
3872            );
3873
3874            // Is any global variable used by this entry point dynamically sized?
3875            let supports_array_length = module
3876                .global_variables
3877                .iter()
3878                .filter(|&(handle, _)| !fun_info[handle].is_empty())
3879                .any(|(_, var)| needs_array_length(var.ty, &module.types));
3880
3881            // skip this entry point if any global bindings are missing,
3882            // or their types are incompatible.
3883            if !options.fake_missing_bindings {
3884                for (var_handle, var) in module.global_variables.iter() {
3885                    if fun_info[var_handle].is_empty() {
3886                        continue;
3887                    }
3888                    match var.space {
3889                        crate::AddressSpace::Uniform
3890                        | crate::AddressSpace::Storage { .. }
3891                        | crate::AddressSpace::Handle => {
3892                            let br = match var.binding {
3893                                Some(ref br) => br,
3894                                None => {
3895                                    let var_name = var.name.clone().unwrap_or_default();
3896                                    ep_error =
3897                                        Some(super::EntryPointError::MissingBinding(var_name));
3898                                    break;
3899                                }
3900                            };
3901                            let target = options.get_resource_binding_target(ep, br);
3902                            let good = match target {
3903                                Some(target) => {
3904                                    let binding_ty = match module.types[var.ty].inner {
3905                                        crate::TypeInner::BindingArray { base, .. } => {
3906                                            &module.types[base].inner
3907                                        }
3908                                        ref ty => ty,
3909                                    };
3910                                    match *binding_ty {
3911                                        crate::TypeInner::Image { .. } => target.texture.is_some(),
3912                                        crate::TypeInner::Sampler { .. } => {
3913                                            target.sampler.is_some()
3914                                        }
3915                                        _ => target.buffer.is_some(),
3916                                    }
3917                                }
3918                                None => false,
3919                            };
3920                            if !good {
3921                                ep_error =
3922                                    Some(super::EntryPointError::MissingBindTarget(br.clone()));
3923                                break;
3924                            }
3925                        }
3926                        crate::AddressSpace::PushConstant => {
3927                            if let Err(e) = options.resolve_push_constants(ep) {
3928                                ep_error = Some(e);
3929                                break;
3930                            }
3931                        }
3932                        crate::AddressSpace::Function
3933                        | crate::AddressSpace::Private
3934                        | crate::AddressSpace::WorkGroup => {}
3935                    }
3936                }
3937                if supports_array_length {
3938                    if let Err(err) = options.resolve_sizes_buffer(ep) {
3939                        ep_error = Some(err);
3940                    }
3941                }
3942            }
3943
3944            if let Some(err) = ep_error {
3945                info.entry_point_names.push(Err(err));
3946                continue;
3947            }
3948            let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
3949            info.entry_point_names.push(Ok(fun_name.clone()));
3950
3951            writeln!(self.out)?;
3952
3953            let (em_str, in_mode, out_mode) = match ep.stage {
3954                crate::ShaderStage::Vertex => (
3955                    "vertex",
3956                    LocationMode::VertexInput,
3957                    LocationMode::VertexOutput,
3958                ),
3959                crate::ShaderStage::Fragment { .. } => (
3960                    "fragment",
3961                    LocationMode::FragmentInput,
3962                    LocationMode::FragmentOutput,
3963                ),
3964                crate::ShaderStage::Compute { .. } => {
3965                    ("kernel", LocationMode::Uniform, LocationMode::Uniform)
3966                }
3967            };
3968
3969            // Since `Namer.reset` wasn't expecting struct members to be
3970            // suddenly injected into another namespace like this,
3971            // `self.names` doesn't keep them distinct from other variables.
3972            // Generate fresh names for these arguments, and remember the
3973            // mapping.
3974            let mut flattened_member_names = FastHashMap::default();
3975            // Varyings' members get their own namespace
3976            let mut varyings_namer = crate::proc::Namer::default();
3977
3978            // List all the Naga `EntryPoint`'s `Function`'s arguments,
3979            // flattening structs into their members. In Metal, we will pass
3980            // each of these values to the entry point as a separate argument—
3981            // except for the varyings, handled next.
3982            let mut flattened_arguments = Vec::new();
3983            for (arg_index, arg) in fun.arguments.iter().enumerate() {
3984                match module.types[arg.ty].inner {
3985                    crate::TypeInner::Struct { ref members, .. } => {
3986                        for (member_index, member) in members.iter().enumerate() {
3987                            let member_index = member_index as u32;
3988                            flattened_arguments.push((
3989                                NameKey::StructMember(arg.ty, member_index),
3990                                member.ty,
3991                                member.binding.as_ref(),
3992                            ));
3993                            let name_key = NameKey::StructMember(arg.ty, member_index);
3994                            let name = match member.binding {
3995                                Some(crate::Binding::Location { .. }) => {
3996                                    varyings_namer.call(&self.names[&name_key])
3997                                }
3998                                _ => self.namer.call(&self.names[&name_key]),
3999                            };
4000                            flattened_member_names.insert(name_key, name);
4001                        }
4002                    }
4003                    _ => flattened_arguments.push((
4004                        NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
4005                        arg.ty,
4006                        arg.binding.as_ref(),
4007                    )),
4008                }
4009            }
4010
4011            // Identify the varyings among the argument values, and emit a
4012            // struct type named `<fun>Input` to hold them.
4013            let stage_in_name = format!("{fun_name}Input");
4014            let varyings_member_name = self.namer.call("varyings");
4015            let mut has_varyings = false;
4016            if !flattened_arguments.is_empty() {
4017                writeln!(self.out, "struct {stage_in_name} {{")?;
4018                for &(ref name_key, ty, binding) in flattened_arguments.iter() {
4019                    let binding = match binding {
4020                        Some(ref binding @ &crate::Binding::Location { .. }) => binding,
4021                        _ => continue,
4022                    };
4023                    has_varyings = true;
4024                    let name = match *name_key {
4025                        NameKey::StructMember(..) => &flattened_member_names[name_key],
4026                        _ => &self.names[name_key],
4027                    };
4028                    let ty_name = TypeContext {
4029                        handle: ty,
4030                        gctx: module.to_ctx(),
4031                        names: &self.names,
4032                        access: crate::StorageAccess::empty(),
4033                        binding: None,
4034                        first_time: false,
4035                    };
4036                    let resolved = options.resolve_local_binding(binding, in_mode)?;
4037                    write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
4038                    resolved.try_fmt(&mut self.out)?;
4039                    writeln!(self.out, ";")?;
4040                }
4041                writeln!(self.out, "}};")?;
4042            }
4043
            // If the entry point returns a value, define a struct type
            // named `<fun>Output` to hold it.
4046            let stage_out_name = format!("{fun_name}Output");
4047            let result_member_name = self.namer.call("member");
4048            let result_type_name = match fun.result {
4049                Some(ref result) => {
4050                    let mut result_members = Vec::new();
4051                    if let crate::TypeInner::Struct { ref members, .. } =
4052                        module.types[result.ty].inner
4053                    {
4054                        for (member_index, member) in members.iter().enumerate() {
4055                            result_members.push((
4056                                &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
4057                                member.ty,
4058                                member.binding.as_ref(),
4059                            ));
4060                        }
4061                    } else {
4062                        result_members.push((
4063                            &result_member_name,
4064                            result.ty,
4065                            result.binding.as_ref(),
4066                        ));
4067                    }
4068
4069                    writeln!(self.out, "struct {stage_out_name} {{")?;
4070                    let mut has_point_size = false;
4071                    for (name, ty, binding) in result_members {
4072                        let ty_name = TypeContext {
4073                            handle: ty,
4074                            gctx: module.to_ctx(),
4075                            names: &self.names,
4076                            access: crate::StorageAccess::empty(),
4077                            binding: None,
4078                            first_time: true,
4079                        };
4080                        let binding = binding.ok_or_else(|| {
4081                            Error::GenericValidation("Expected binding, got None".into())
4082                        })?;
4083
4084                        if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
4085                            has_point_size = true;
4086                            if !pipeline_options.allow_and_force_point_size {
4087                                continue;
4088                            }
4089                        }
4090
4091                        let array_len = match module.types[ty].inner {
4092                            crate::TypeInner::Array {
4093                                size: crate::ArraySize::Constant(size),
4094                                ..
4095                            } => Some(size),
4096                            _ => None,
4097                        };
4098                        let resolved = options.resolve_local_binding(binding, out_mode)?;
4099                        write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
4100                        if let Some(array_len) = array_len {
4101                            write!(self.out, " [{array_len}]")?;
4102                        }
4103                        resolved.try_fmt(&mut self.out)?;
4104                        writeln!(self.out, ";")?;
4105                    }
4106
4107                    if pipeline_options.allow_and_force_point_size
4108                        && ep.stage == crate::ShaderStage::Vertex
4109                        && !has_point_size
4110                    {
4111                        // inject the point size output last
4112                        writeln!(
4113                            self.out,
4114                            "{}float _point_size [[point_size]];",
4115                            back::INDENT
4116                        )?;
4117                    }
4118                    writeln!(self.out, "}};")?;
4119                    &stage_out_name
4120                }
4121                None => "void",
4122            };
4123
4124            // Write the entry point function's name, and begin its argument list.
4125            writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
4126            let mut is_first_argument = true;
4127
4128            // If we have produced a struct holding the `EntryPoint`'s
4129            // `Function`'s arguments' varyings, pass that struct first.
4130            if has_varyings {
4131                writeln!(
4132                    self.out,
4133                    "  {stage_in_name} {varyings_member_name} [[stage_in]]"
4134                )?;
4135                is_first_argument = false;
4136            }
4137
4138            let mut local_invocation_id = None;
4139
4140            // Then pass the remaining arguments not included in the varyings
4141            // struct.
4142            for &(ref name_key, ty, binding) in flattened_arguments.iter() {
4143                let binding = match binding {
4144                    Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
4145                    _ => continue,
4146                };
4147                let name = match *name_key {
4148                    NameKey::StructMember(..) => &flattened_member_names[name_key],
4149                    _ => &self.names[name_key],
4150                };
4151
4152                if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
4153                    local_invocation_id = Some(name_key);
4154                }
4155
4156                let ty_name = TypeContext {
4157                    handle: ty,
4158                    gctx: module.to_ctx(),
4159                    names: &self.names,
4160                    access: crate::StorageAccess::empty(),
4161                    binding: None,
4162                    first_time: false,
4163                };
4164                let resolved = options.resolve_local_binding(binding, in_mode)?;
4165                let separator = if is_first_argument {
4166                    is_first_argument = false;
4167                    ' '
4168                } else {
4169                    ','
4170                };
4171                write!(self.out, "{separator} {ty_name} {name}")?;
4172                resolved.try_fmt(&mut self.out)?;
4173                writeln!(self.out)?;
4174            }
4175
4176            let need_workgroup_variables_initialization =
4177                self.need_workgroup_variables_initialization(options, ep, module, fun_info);
4178
4179            if need_workgroup_variables_initialization && local_invocation_id.is_none() {
4180                let separator = if is_first_argument {
4181                    is_first_argument = false;
4182                    ' '
4183                } else {
4184                    ','
4185                };
4186                writeln!(
4187                    self.out,
4188                    "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
4189                )?;
4190            }
4191
4192            // Those global variables used by this entry point and its callees
4193            // get passed as arguments. `Private` globals are an exception, they
4194            // don't outlive this invocation, so we declare them below as locals
4195            // within the entry point.
4196            for (handle, var) in module.global_variables.iter() {
4197                let usage = fun_info[handle];
4198                if usage.is_empty() || var.space == crate::AddressSpace::Private {
4199                    continue;
4200                }
4201
4202                if options.lang_version < (1, 2) {
4203                    match var.space {
4204                        // This restriction is not documented in the MSL spec
4205                        // but validation will fail if it is not upheld.
4206                        //
4207                        // We infer the required version from the "Function
4208                        // Buffer Read-Writes" section of [what's new], where
4209                        // the feature sets listed correspond with the ones
4210                        // supporting MSL 1.2.
4211                        //
4212                        // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
4213                        crate::AddressSpace::Storage { access }
4214                            if access.contains(crate::StorageAccess::STORE)
4215                                && ep.stage == crate::ShaderStage::Fragment =>
4216                        {
4217                            return Err(Error::UnsupportedWriteableStorageBuffer)
4218                        }
4219                        crate::AddressSpace::Handle => {
4220                            match module.types[var.ty].inner {
4221                                crate::TypeInner::Image {
4222                                    class: crate::ImageClass::Storage { access, .. },
4223                                    ..
4224                                } => {
4225                                    // This restriction is not documented in the MSL spec
4226                                    // but validation will fail if it is not upheld.
4227                                    //
4228                                    // We infer the required version from the "Function
4229                                    // Texture Read-Writes" section of [what's new], where
4230                                    // the feature sets listed correspond with the ones
4231                                    // supporting MSL 1.2.
4232                                    //
4233                                    // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
4234                                    if access.contains(crate::StorageAccess::STORE)
4235                                        && (ep.stage == crate::ShaderStage::Vertex
4236                                            || ep.stage == crate::ShaderStage::Fragment)
4237                                    {
4238                                        return Err(Error::UnsupportedWriteableStorageTexture(
4239                                            ep.stage,
4240                                        ));
4241                                    }
4242
4243                                    if access.contains(
4244                                        crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
4245                                    ) {
4246                                        return Err(Error::UnsupportedRWStorageTexture);
4247                                    }
4248                                }
4249                                _ => {}
4250                            }
4251                        }
4252                        _ => {}
4253                    }
4254                }
4255
4256                // Check min MSL version for binding arrays
4257                match var.space {
4258                    crate::AddressSpace::Handle => match module.types[var.ty].inner {
4259                        crate::TypeInner::BindingArray { base, .. } => {
4260                            match module.types[base].inner {
4261                                crate::TypeInner::Sampler { .. } => {
4262                                    if options.lang_version < (2, 0) {
4263                                        return Err(Error::UnsupportedArrayOf(
4264                                            "samplers".to_string(),
4265                                        ));
4266                                    }
4267                                }
4268                                crate::TypeInner::Image { class, .. } => match class {
4269                                    crate::ImageClass::Sampled { .. }
4270                                    | crate::ImageClass::Depth { .. }
4271                                    | crate::ImageClass::Storage {
4272                                        access: crate::StorageAccess::LOAD,
4273                                        ..
4274                                    } => {
4275                                        // Array of textures since:
4276                                        // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
4277                                        // - macOS: Metal 2
4278
4279                                        if options.lang_version < (2, 0) {
4280                                            return Err(Error::UnsupportedArrayOf(
4281                                                "textures".to_string(),
4282                                            ));
4283                                        }
4284                                    }
4285                                    crate::ImageClass::Storage {
4286                                        access: crate::StorageAccess::STORE,
4287                                        ..
4288                                    } => {
4289                                        // Array of write-only textures since:
4290                                        // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
4291                                        // - macOS: Metal 2
4292
4293                                        if options.lang_version < (2, 0) {
4294                                            return Err(Error::UnsupportedArrayOf(
4295                                                "write-only textures".to_string(),
4296                                            ));
4297                                        }
4298                                    }
4299                                    crate::ImageClass::Storage { .. } => {
4300                                        return Err(Error::UnsupportedArrayOf(
4301                                            "read-write textures".to_string(),
4302                                        ));
4303                                    }
4304                                },
4305                                _ => {
4306                                    return Err(Error::UnsupportedArrayOfType(base));
4307                                }
4308                            }
4309                        }
4310                        _ => {}
4311                    },
4312                    _ => {}
4313                }
4314
4315                // the resolves have already been checked for `!fake_missing_bindings` case
4316                let resolved = match var.space {
4317                    crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
4318                    crate::AddressSpace::WorkGroup => None,
4319                    _ => options
4320                        .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
4321                        .ok(),
4322                };
4323                if let Some(ref resolved) = resolved {
4324                    // Inline samplers are be defined in the EP body
4325                    if resolved.as_inline_sampler(options).is_some() {
4326                        continue;
4327                    }
4328                }
4329
4330                let tyvar = TypedGlobalVariable {
4331                    module,
4332                    names: &self.names,
4333                    handle,
4334                    usage,
4335                    binding: resolved.as_ref(),
4336                    reference: true,
4337                };
4338                let separator = if is_first_argument {
4339                    is_first_argument = false;
4340                    ' '
4341                } else {
4342                    ','
4343                };
4344                write!(self.out, "{separator} ")?;
4345                tyvar.try_fmt(&mut self.out)?;
4346                if let Some(resolved) = resolved {
4347                    resolved.try_fmt(&mut self.out)?;
4348                }
4349                if let Some(value) = var.init {
4350                    write!(self.out, " = ")?;
4351                    self.put_const_expression(value, module, mod_info)?;
4352                }
4353                writeln!(self.out)?;
4354            }
4355
4356            // If this entry uses any variable-length arrays, their sizes are
4357            // passed as a final struct-typed argument.
4358            if supports_array_length {
4359                // this is checked earlier
4360                let resolved = options.resolve_sizes_buffer(ep).unwrap();
4361                let separator = if module.global_variables.is_empty() {
4362                    ' '
4363                } else {
4364                    ','
4365                };
4366                write!(
4367                    self.out,
4368                    "{separator} constant _mslBufferSizes& _buffer_sizes",
4369                )?;
4370                resolved.try_fmt(&mut self.out)?;
4371                writeln!(self.out)?;
4372            }
4373
4374            // end of the entry point argument list
4375            writeln!(self.out, ") {{")?;
4376
4377            if need_workgroup_variables_initialization {
4378                self.write_workgroup_variables_initialization(
4379                    module,
4380                    mod_info,
4381                    fun_info,
4382                    local_invocation_id,
4383                )?;
4384            }
4385
4386            // Metal doesn't support private mutable variables outside of functions,
4387            // so we put them here, just like the locals.
4388            for (handle, var) in module.global_variables.iter() {
4389                let usage = fun_info[handle];
4390                if usage.is_empty() {
4391                    continue;
4392                }
4393                if var.space == crate::AddressSpace::Private {
4394                    let tyvar = TypedGlobalVariable {
4395                        module,
4396                        names: &self.names,
4397                        handle,
4398                        usage,
4399                        binding: None,
4400                        reference: false,
4401                    };
4402                    write!(self.out, "{}", back::INDENT)?;
4403                    tyvar.try_fmt(&mut self.out)?;
4404                    match var.init {
4405                        Some(value) => {
4406                            write!(self.out, " = ")?;
4407                            self.put_const_expression(value, module, mod_info)?;
4408                            writeln!(self.out, ";")?;
4409                        }
4410                        None => {
4411                            writeln!(self.out, " = {{}};")?;
4412                        }
4413                    };
4414                } else if let Some(ref binding) = var.binding {
4415                    // write an inline sampler
4416                    let resolved = options.resolve_resource_binding(ep, binding).unwrap();
4417                    if let Some(sampler) = resolved.as_inline_sampler(options) {
4418                        let name = &self.names[&NameKey::GlobalVariable(handle)];
4419                        writeln!(
4420                            self.out,
4421                            "{}constexpr {}::sampler {}(",
4422                            back::INDENT,
4423                            NAMESPACE,
4424                            name
4425                        )?;
4426                        self.put_inline_sampler_properties(back::Level(2), sampler)?;
4427                        writeln!(self.out, "{});", back::INDENT)?;
4428                    }
4429                }
4430            }
4431
4432            // Now take the arguments that we gathered into structs, and the
4433            // structs that we flattened into arguments, and emit local
4434            // variables with initializers that put everything back the way the
4435            // body code expects.
4436            //
4437            // If we had to generate fresh names for struct members passed as
4438            // arguments, be sure to use those names when rebuilding the struct.
4439            //
4440            // "Each day, I change some zeros to ones, and some ones to zeros.
4441            // The rest, I leave alone."
4442            for (arg_index, arg) in fun.arguments.iter().enumerate() {
4443                let arg_name =
4444                    &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
4445                match module.types[arg.ty].inner {
4446                    crate::TypeInner::Struct { ref members, .. } => {
4447                        let struct_name = &self.names[&NameKey::Type(arg.ty)];
4448                        write!(
4449                            self.out,
4450                            "{}const {} {} = {{ ",
4451                            back::INDENT,
4452                            struct_name,
4453                            arg_name
4454                        )?;
4455                        for (member_index, member) in members.iter().enumerate() {
4456                            let key = NameKey::StructMember(arg.ty, member_index as u32);
4457                            let name = &flattened_member_names[&key];
4458                            if member_index != 0 {
4459                                write!(self.out, ", ")?;
4460                            }
4461                            // insert padding initialization, if needed
4462                            if self
4463                                .struct_member_pads
4464                                .contains(&(arg.ty, member_index as u32))
4465                            {
4466                                write!(self.out, "{{}}, ")?;
4467                            }
4468                            if let Some(crate::Binding::Location { .. }) = member.binding {
4469                                write!(self.out, "{varyings_member_name}.")?;
4470                            }
4471                            write!(self.out, "{name}")?;
4472                        }
4473                        writeln!(self.out, " }};")?;
4474                    }
4475                    _ => {
4476                        if let Some(crate::Binding::Location { .. }) = arg.binding {
4477                            writeln!(
4478                                self.out,
4479                                "{}const auto {} = {}.{};",
4480                                back::INDENT,
4481                                arg_name,
4482                                varyings_member_name,
4483                                arg_name
4484                            )?;
4485                        }
4486                    }
4487                }
4488            }
4489
4490            let guarded_indices =
4491                index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
4492
4493            let context = StatementContext {
4494                expression: ExpressionContext {
4495                    function: fun,
4496                    origin: FunctionOrigin::EntryPoint(ep_index as _),
4497                    info: fun_info,
4498                    lang_version: options.lang_version,
4499                    policies: options.bounds_check_policies,
4500                    guarded_indices,
4501                    module,
4502                    mod_info,
4503                    pipeline_options,
4504                },
4505                result_struct: Some(&stage_out_name),
4506            };
4507
4508            // Finally, declare all the local variables that we need
4509            //TODO: we can postpone this till the relevant expressions are emitted
4510            for (local_handle, local) in fun.local_variables.iter() {
4511                let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
4512                let ty_name = TypeContext {
4513                    handle: local.ty,
4514                    gctx: module.to_ctx(),
4515                    names: &self.names,
4516                    access: crate::StorageAccess::empty(),
4517                    binding: None,
4518                    first_time: false,
4519                };
4520                write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
4521                match local.init {
4522                    Some(value) => {
4523                        write!(self.out, " = ")?;
4524                        self.put_expression(value, &context.expression, true)?;
4525                    }
4526                    None => {
4527                        write!(self.out, " = {{}}")?;
4528                    }
4529                };
4530                writeln!(self.out, ";")?;
4531            }
4532
4533            self.update_expressions_to_bake(fun, fun_info, &context.expression);
4534            self.put_block(back::Level(1), &fun.body, &context)?;
4535            writeln!(self.out, "}}")?;
4536            if ep_index + 1 != module.entry_points.len() {
4537                writeln!(self.out)?;
4538            }
4539            self.named_expressions.clear();
4540        }
4541
4542        Ok(info)
4543    }
4544
4545    fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
4546        // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
4547        // so we try to avoid it here.
4548        if flags.is_empty() {
4549            writeln!(
4550                self.out,
4551                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
4552            )?;
4553        }
4554        if flags.contains(crate::Barrier::STORAGE) {
4555            writeln!(
4556                self.out,
4557                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
4558            )?;
4559        }
4560        if flags.contains(crate::Barrier::WORK_GROUP) {
4561            writeln!(
4562                self.out,
4563                "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
4564            )?;
4565        }
4566        if flags.contains(crate::Barrier::SUB_GROUP) {
4567            writeln!(
4568                self.out,
4569                "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
4570            )?;
4571        }
4572        Ok(())
4573    }
4574}
4575
4576/// Initializing workgroup variables is more tricky for Metal because we have to deal
4577/// with atomics at the type-level (which don't have a copy constructor).
4578mod workgroup_mem_init {
4579    use crate::EntryPoint;
4580
4581    use super::*;
4582
4583    enum Access {
4584        GlobalVariable(Handle<crate::GlobalVariable>),
4585        StructMember(Handle<crate::Type>, u32),
4586        Array(usize),
4587    }
4588
4589    impl Access {
4590        fn write<W: Write>(
4591            &self,
4592            writer: &mut W,
4593            names: &FastHashMap<NameKey, String>,
4594        ) -> Result<(), core::fmt::Error> {
4595            match *self {
4596                Access::GlobalVariable(handle) => {
4597                    write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
4598                }
4599                Access::StructMember(handle, index) => {
4600                    write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
4601                }
4602                Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
4603            }
4604        }
4605    }
4606
4607    struct AccessStack {
4608        stack: Vec<Access>,
4609        array_depth: usize,
4610    }
4611
4612    impl AccessStack {
4613        const fn new() -> Self {
4614            Self {
4615                stack: Vec::new(),
4616                array_depth: 0,
4617            }
4618        }
4619
4620        fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
4621            let array_depth = self.array_depth;
4622            self.stack.push(Access::Array(array_depth));
4623            self.array_depth += 1;
4624            let res = cb(self, array_depth);
4625            self.stack.pop();
4626            self.array_depth -= 1;
4627            res
4628        }
4629
4630        fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
4631            self.stack.push(new);
4632            let res = cb(self);
4633            self.stack.pop();
4634            res
4635        }
4636
4637        fn write<W: Write>(
4638            &self,
4639            writer: &mut W,
4640            names: &FastHashMap<NameKey, String>,
4641        ) -> Result<(), core::fmt::Error> {
4642            for next in self.stack.iter() {
4643                next.write(writer, names)?;
4644            }
4645            Ok(())
4646        }
4647    }
4648
4649    impl<W: Write> Writer<W> {
4650        pub(super) fn need_workgroup_variables_initialization(
4651            &mut self,
4652            options: &Options,
4653            ep: &EntryPoint,
4654            module: &crate::Module,
4655            fun_info: &valid::FunctionInfo,
4656        ) -> bool {
4657            options.zero_initialize_workgroup_memory
4658                && ep.stage == crate::ShaderStage::Compute
4659                && module.global_variables.iter().any(|(handle, var)| {
4660                    !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
4661                })
4662        }
4663
4664        pub(super) fn write_workgroup_variables_initialization(
4665            &mut self,
4666            module: &crate::Module,
4667            module_info: &valid::ModuleInfo,
4668            fun_info: &valid::FunctionInfo,
4669            local_invocation_id: Option<&NameKey>,
4670        ) -> BackendResult {
4671            let level = back::Level(1);
4672
4673            writeln!(
4674                self.out,
4675                "{}if ({}::all({} == {}::uint3(0u))) {{",
4676                level,
4677                NAMESPACE,
4678                local_invocation_id
4679                    .map(|name_key| self.names[name_key].as_str())
4680                    .unwrap_or("__local_invocation_id"),
4681                NAMESPACE,
4682            )?;
4683
4684            let mut access_stack = AccessStack::new();
4685
4686            let vars = module.global_variables.iter().filter(|&(handle, var)| {
4687                !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
4688            });
4689
4690            for (handle, var) in vars {
4691                access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
4692                    self.write_workgroup_variable_initialization(
4693                        module,
4694                        module_info,
4695                        var.ty,
4696                        access_stack,
4697                        level.next(),
4698                    )
4699                })?;
4700            }
4701
4702            writeln!(self.out, "{level}}}")?;
4703            self.write_barrier(crate::Barrier::WORK_GROUP, level)
4704        }
4705
4706        fn write_workgroup_variable_initialization(
4707            &mut self,
4708            module: &crate::Module,
4709            module_info: &valid::ModuleInfo,
4710            ty: Handle<crate::Type>,
4711            access_stack: &mut AccessStack,
4712            level: back::Level,
4713        ) -> BackendResult {
4714            if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
4715                write!(self.out, "{level}")?;
4716                access_stack.write(&mut self.out, &self.names)?;
4717                writeln!(self.out, " = {{}};")?;
4718            } else {
4719                match module.types[ty].inner {
4720                    crate::TypeInner::Atomic { .. } => {
4721                        write!(
4722                            self.out,
4723                            "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4724                        )?;
4725                        access_stack.write(&mut self.out, &self.names)?;
4726                        writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
4727                    }
4728                    crate::TypeInner::Array { base, size, .. } => {
4729                        let count = match size.to_indexable_length(module).expect("Bad array size")
4730                        {
4731                            proc::IndexableLength::Known(count) => count,
4732                            proc::IndexableLength::Dynamic => unreachable!(),
4733                        };
4734
4735                        access_stack.enter_array(|access_stack, array_depth| {
4736                            writeln!(
4737                                self.out,
4738                                "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
4739                            )?;
4740                            self.write_workgroup_variable_initialization(
4741                                module,
4742                                module_info,
4743                                base,
4744                                access_stack,
4745                                level.next(),
4746                            )?;
4747                            writeln!(self.out, "{level}}}")?;
4748                            BackendResult::Ok(())
4749                        })?;
4750                    }
4751                    crate::TypeInner::Struct { ref members, .. } => {
4752                        for (index, member) in members.iter().enumerate() {
4753                            access_stack.enter(
4754                                Access::StructMember(ty, index as u32),
4755                                |access_stack| {
4756                                    self.write_workgroup_variable_initialization(
4757                                        module,
4758                                        module_info,
4759                                        member.ty,
4760                                        access_stack,
4761                                        level,
4762                                    )
4763                                },
4764                            )?;
4765                        }
4766                    }
4767                    _ => unreachable!(),
4768                }
4769            }
4770
4771            Ok(())
4772        }
4773    }
4774}
4775
/// Guards against unexpected changes in the stack usage recorded in the
/// writer's `put_expression_stack_pointers` and `put_block_stack_pointers`
/// while it writes a small module.
#[test]
fn test_stack_size() {
    use crate::valid::{Capabilities, ValidationFlags};

    // Create a module with at least one nested expression: the function body
    // emits the negation of a literal and then branches on it.
    let mut module = crate::Module::default();
    let mut fun = crate::Function::default();
    let literal = fun.expressions.append(
        crate::Expression::Literal(crate::Literal::F32(1.0)),
        Default::default(),
    );
    let negated = fun.expressions.append(
        crate::Expression::Unary {
            op: crate::UnaryOperator::Negate,
            expr: literal,
        },
        Default::default(),
    );
    let emit = crate::Statement::Emit(fun.expressions.range_from(1));
    fun.body.push(emit, Default::default());
    let branch = crate::Statement::If {
        condition: negated,
        accept: crate::Block::new(),
        reject: crate::Block::new(),
    };
    fun.body.push(branch, Default::default());
    let _ = module.functions.append(fun, Default::default());

    // Analyse the module.
    let module_info =
        crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty())
            .validate(&module)
            .unwrap();

    // Process the module.
    let mut writer = Writer::new(String::new());
    writer
        .write(
            &module,
            &module_info,
            &Default::default(),
            &Default::default(),
        )
        .unwrap();

    // Check the expression stack: the distance between the lowest and the
    // highest recorded address is taken as the stack size.
    let (lowest, highest) = writer
        .put_expression_stack_pointers
        .into_iter()
        .map(|pointer| pointer as usize)
        .fold((usize::MAX, 0usize), |(lowest, highest), address| {
            (lowest.min(address), highest.max(address))
        });
    let stack_size = highest - lowest;
    // check the size (in debug only)
    // last observed macOS value: 20528 (CI)
    assert!(
        (11000..=25000).contains(&stack_size),
        "`put_expression` stack size {stack_size} has changed!"
    );

    // Check the block stack the same way.
    let (lowest, highest) = writer
        .put_block_stack_pointers
        .into_iter()
        .map(|pointer| pointer as usize)
        .fold((usize::MAX, 0usize), |(lowest, highest), address| {
            (lowest.min(address), highest.max(address))
        });
    let stack_size = highest - lowest;
    // check the size (in debug only)
    // last observed macOS value: 22256 (CI)
    assert!(
        (15000..=25000).contains(&stack_size),
        "`put_block` stack size {stack_size} has changed!"
    );
}