1use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
2use crate::{
3 arena::Handle,
4 back,
5 proc::index,
6 proc::{self, NameKey, TypeResolution},
7 valid, FastHashMap, FastHashSet,
8};
9use bit_set::BitSet;
10use std::{
11 fmt::{Display, Error as FmtError, Formatter, Write},
12 iter,
13};
14
/// Result type returned by the MSL writer's emitting methods.
type BackendResult = Result<(), Error>;

/// Namespace prefix for Metal standard-library types and functions, as in
/// `metal::sampler` or `metal::min`.
const NAMESPACE: &str = "metal";
/// Name of the `inner` field of struct-wrapped arrays.
/// NOTE(review): not referenced in this part of the file; presumably used where
/// array wrapper structs are declared and accessed — confirm.
const WRAPPED_ARRAY_FIELD: &str = "inner";
/// Written before the pointer operand of `metal::atomic_*_explicit` calls, so
/// the call receives the address of the atomic.
const ATOMIC_REFERENCE: &str = "&";

/// Namespace of Metal's ray-tracing types, e.g. `instance_acceleration_structure`.
const RT_NAMESPACE: &str = "metal::raytracing";
/// Name of the generated type used to represent a ray query.
const RAY_QUERY_TYPE: &str = "_RayQuery";
// NOTE(review): the following field/function names are not referenced in this
// part of the file; presumably they name members of the generated `_RayQuery`
// type and a generated helper function — confirm.
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

/// Names of generated helper functions, visible to the rest of the crate.
/// NOTE(review): presumably the helpers implementing `modf`/`frexp` — confirm.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
37
38fn put_numeric_type(
47 out: &mut impl Write,
48 scalar: crate::Scalar,
49 sizes: &[crate::VectorSize],
50) -> Result<(), FmtError> {
51 match (scalar, sizes) {
52 (scalar, &[]) => {
53 write!(out, "{}", scalar.to_msl_name())
54 }
55 (scalar, &[rows]) => {
56 write!(
57 out,
58 "{}::{}{}",
59 NAMESPACE,
60 scalar.to_msl_name(),
61 back::vector_size_str(rows)
62 )
63 }
64 (scalar, &[rows, columns]) => {
65 write!(
66 out,
67 "{}::{}{}x{}",
68 NAMESPACE,
69 scalar.to_msl_name(),
70 back::vector_size_str(columns),
71 back::vector_size_str(rows)
72 )
73 }
74 (_, _) => Ok(()), }
76}
77
/// Prefix of the name that stands for a clamped level of detail. The full name
/// is this prefix followed by the index of the image-load expression; it is what
/// `LevelOfDetail::Restricted` is written as.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
80
/// Formatting helper that renders the MSL spelling of a Naga type through its
/// `Display` implementation.
struct TypeContext<'a> {
    /// The type to render.
    handle: Handle<crate::Type>,
    /// Module context used to look up `handle` and any nested types.
    gctx: proc::GlobalCtx<'a>,
    /// Assigned names; types that need an alias are written by their name here.
    names: &'a FastHashMap<NameKey, String>,
    /// Access used to pick the `metal::access::*` qualifier of storage images.
    access: crate::StorageAccess,
    /// Resolved binding of the global being declared, if any. Its
    /// `binding_array_size` overrides the declared size of binding arrays.
    binding: Option<&'a super::ResolvedBinding>,
    /// When `false`, types that need an alias are written by name. When `true`,
    /// they are spelled out in full.
    /// NOTE(review): presumably set only when emitting the alias definition itself.
    first_time: bool,
}
89
90impl<'a> Display for TypeContext<'a> {
91 fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
92 let ty = &self.gctx.types[self.handle];
93 if ty.needs_alias() && !self.first_time {
94 let name = &self.names[&NameKey::Type(self.handle)];
95 return write!(out, "{name}");
96 }
97
98 match ty.inner {
99 crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
100 crate::TypeInner::Atomic(scalar) => {
101 write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
102 }
103 crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
104 crate::TypeInner::Matrix { columns, rows, .. } => {
105 put_numeric_type(out, crate::Scalar::F32, &[rows, columns])
106 }
107 crate::TypeInner::Pointer { base, space } => {
108 let sub = Self {
109 handle: base,
110 first_time: false,
111 ..*self
112 };
113 let space_name = match space.to_msl_name() {
114 Some(name) => name,
115 None => return Ok(()),
116 };
117 write!(out, "{space_name} {sub}&")
118 }
119 crate::TypeInner::ValuePointer {
120 size,
121 scalar,
122 space,
123 } => {
124 match space.to_msl_name() {
125 Some(name) => write!(out, "{name} ")?,
126 None => return Ok(()),
127 };
128 match size {
129 Some(rows) => put_numeric_type(out, scalar, &[rows])?,
130 None => put_numeric_type(out, scalar, &[])?,
131 };
132
133 write!(out, "&")
134 }
135 crate::TypeInner::Array { base, .. } => {
136 let sub = Self {
137 handle: base,
138 first_time: false,
139 ..*self
140 };
141 write!(out, "{sub}")
144 }
145 crate::TypeInner::Struct { .. } => unreachable!(),
146 crate::TypeInner::Image {
147 dim,
148 arrayed,
149 class,
150 } => {
151 let dim_str = match dim {
152 crate::ImageDimension::D1 => "1d",
153 crate::ImageDimension::D2 => "2d",
154 crate::ImageDimension::D3 => "3d",
155 crate::ImageDimension::Cube => "cube",
156 };
157 let (texture_str, msaa_str, kind, access) = match class {
158 crate::ImageClass::Sampled { kind, multi } => {
159 let (msaa_str, access) = if multi {
160 ("_ms", "read")
161 } else {
162 ("", "sample")
163 };
164 ("texture", msaa_str, kind, access)
165 }
166 crate::ImageClass::Depth { multi } => {
167 let (msaa_str, access) = if multi {
168 ("_ms", "read")
169 } else {
170 ("", "sample")
171 };
172 ("depth", msaa_str, crate::ScalarKind::Float, access)
173 }
174 crate::ImageClass::Storage { format, .. } => {
175 let access = if self
176 .access
177 .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
178 {
179 "read_write"
180 } else if self.access.contains(crate::StorageAccess::STORE) {
181 "write"
182 } else if self.access.contains(crate::StorageAccess::LOAD) {
183 "read"
184 } else {
185 log::warn!(
186 "Storage access for {:?} (name '{}'): {:?}",
187 self.handle,
188 ty.name.as_deref().unwrap_or_default(),
189 self.access
190 );
191 unreachable!("module is not valid");
192 };
193 ("texture", "", format.into(), access)
194 }
195 };
196 let base_name = crate::Scalar { kind, width: 4 }.to_msl_name();
197 let array_str = if arrayed { "_array" } else { "" };
198 write!(
199 out,
200 "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
201 )
202 }
203 crate::TypeInner::Sampler { comparison: _ } => {
204 write!(out, "{NAMESPACE}::sampler")
205 }
206 crate::TypeInner::AccelerationStructure => {
207 write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
208 }
209 crate::TypeInner::RayQuery => {
210 write!(out, "{RAY_QUERY_TYPE}")
211 }
212 crate::TypeInner::BindingArray { base, size } => {
213 let base_tyname = Self {
214 handle: base,
215 first_time: false,
216 ..*self
217 };
218
219 if let Some(&super::ResolvedBinding::Resource(super::BindTarget {
220 binding_array_size: Some(override_size),
221 ..
222 })) = self.binding
223 {
224 write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>")
225 } else if let crate::ArraySize::Constant(size) = size {
226 write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>")
227 } else {
228 unreachable!("metal requires all arrays be constant sized");
229 }
230 }
231 }
232 }
233}
234
/// Formatting helper that writes the declaration of a global variable as
/// `[space ]type[ const][&] name`.
/// NOTE(review): presumably used for function parameters and global declarations — confirm.
struct TypedGlobalVariable<'a> {
    /// Module that owns the global.
    module: &'a crate::Module,
    /// Assigned names of globals and types.
    names: &'a FastHashMap<NameKey, String>,
    /// The global to declare.
    handle: Handle<crate::GlobalVariable>,
    /// Usage of the global; without `WRITE`, references in spaces that take an
    /// access qualifier are declared `const`.
    usage: valid::GlobalUse,
    /// Resolved binding of the global, forwarded to the type formatter.
    binding: Option<&'a super::ResolvedBinding>,
    /// Whether to declare a reference (`space type&`) when the address space has
    /// an MSL name.
    reference: bool,
}
243
244impl<'a> TypedGlobalVariable<'a> {
245 fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
246 let var = &self.module.global_variables[self.handle];
247 let name = &self.names[&NameKey::GlobalVariable(self.handle)];
248
249 let storage_access = match var.space {
250 crate::AddressSpace::Storage { access } => access,
251 _ => match self.module.types[var.ty].inner {
252 crate::TypeInner::Image {
253 class: crate::ImageClass::Storage { access, .. },
254 ..
255 } => access,
256 crate::TypeInner::BindingArray { base, .. } => {
257 match self.module.types[base].inner {
258 crate::TypeInner::Image {
259 class: crate::ImageClass::Storage { access, .. },
260 ..
261 } => access,
262 _ => crate::StorageAccess::default(),
263 }
264 }
265 _ => crate::StorageAccess::default(),
266 },
267 };
268 let ty_name = TypeContext {
269 handle: var.ty,
270 gctx: self.module.to_ctx(),
271 names: self.names,
272 access: storage_access,
273 binding: self.binding,
274 first_time: false,
275 };
276
277 let (space, access, reference) = match var.space.to_msl_name() {
278 Some(space) if self.reference => {
279 let access = if var.space.needs_access_qualifier()
280 && !self.usage.contains(valid::GlobalUse::WRITE)
281 {
282 "const"
283 } else {
284 ""
285 };
286 (space, access, "&")
287 }
288 _ => ("", "", ""),
289 };
290
291 Ok(write!(
292 out,
293 "{}{}{}{}{}{} {}",
294 space,
295 if space.is_empty() { "" } else { " " },
296 ty_name,
297 if access.is_empty() { "" } else { " " },
298 access,
299 reference,
300 name,
301 )?)
302 }
303}
304
/// Writer that emits Metal Shading Language source into `out`.
pub struct Writer<W> {
    /// Output sink.
    out: W,
    /// Names assigned to module items, keyed by [`NameKey`].
    names: FastHashMap<NameKey, String>,
    /// NOTE(review): presumably names given to expressions of the function
    /// being written — confirm.
    named_expressions: crate::NamedExpressions,
    /// NOTE(review): presumably expressions that must be stored in temporaries
    /// before use — confirm.
    need_bake_expressions: back::NeedBakeExpressions,
    /// Generator of unique identifiers.
    namer: proc::Namer,
    /// Test-only bookkeeping.
    /// NOTE(review): presumably records stack addresses to measure recursion
    /// depth of `put_expression` / `put_block` — confirm.
    #[cfg(test)]
    put_expression_stack_pointers: FastHashSet<*const ()>,
    #[cfg(test)]
    put_block_stack_pointers: FastHashSet<*const ()>,
    /// `(struct type, member index)` pairs before which an extra `{}, `
    /// initializer is written in composite constructors (presumably for a
    /// padding member inserted before that member).
    struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
}
320
321impl crate::Scalar {
322 fn to_msl_name(self) -> &'static str {
323 use crate::ScalarKind as Sk;
324 match self {
325 Self {
326 kind: Sk::Float,
327 width: _,
328 } => "float",
329 Self {
330 kind: Sk::Sint,
331 width: 4,
332 } => "int",
333 Self {
334 kind: Sk::Uint,
335 width: 4,
336 } => "uint",
337 Self {
338 kind: Sk::Sint,
339 width: 8,
340 } => "long",
341 Self {
342 kind: Sk::Uint,
343 width: 8,
344 } => "ulong",
345 Self {
346 kind: Sk::Bool,
347 width: _,
348 } => "bool",
349 Self {
350 kind: Sk::AbstractInt | Sk::AbstractFloat,
351 width: _,
352 } => unreachable!("Found Abstract scalar kind"),
353 _ => unreachable!("Unsupported scalar kind: {:?}", self),
354 }
355 }
356}
357
/// Returns the separator to put after a list item: `","` when
/// `need_separator` is true, otherwise the empty string.
const fn separate(need_separator: bool) -> &'static str {
    match need_separator {
        true => ",",
        false => "",
    }
}
365
366fn should_pack_struct_member(
367 members: &[crate::StructMember],
368 span: u32,
369 index: usize,
370 module: &crate::Module,
371) -> Option<crate::Scalar> {
372 let member = &members[index];
373
374 let ty_inner = &module.types[member.ty].inner;
375 let last_offset = member.offset + ty_inner.size(module.to_ctx());
376 let next_offset = match members.get(index + 1) {
377 Some(next) => next.offset,
378 None => span,
379 };
380 let is_tight = next_offset == last_offset;
381
382 match *ty_inner {
383 crate::TypeInner::Vector {
384 size: crate::VectorSize::Tri,
385 scalar: scalar @ crate::Scalar { width: 4, .. },
386 } if is_tight => Some(scalar),
387 _ => None,
388 }
389}
390
391fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
392 match arena[ty].inner {
393 crate::TypeInner::Struct { ref members, .. } => {
394 if let Some(member) = members.last() {
395 if let crate::TypeInner::Array {
396 size: crate::ArraySize::Dynamic,
397 ..
398 } = arena[member.ty].inner
399 {
400 return true;
401 }
402 }
403 false
404 }
405 crate::TypeInner::Array {
406 size: crate::ArraySize::Dynamic,
407 ..
408 } => true,
409 _ => false,
410 }
411}
412
413impl crate::AddressSpace {
414 const fn needs_pass_through(&self) -> bool {
418 match *self {
419 Self::Uniform
420 | Self::Storage { .. }
421 | Self::Private
422 | Self::WorkGroup
423 | Self::PushConstant
424 | Self::Handle => true,
425 Self::Function => false,
426 }
427 }
428
429 const fn needs_access_qualifier(&self) -> bool {
431 match *self {
432 Self::Storage { .. } => true,
437 Self::Private | Self::WorkGroup => false,
439 Self::Uniform | Self::PushConstant => false,
441 Self::Handle | Self::Function => false,
443 }
444 }
445
446 const fn to_msl_name(self) -> Option<&'static str> {
447 match self {
448 Self::Handle => None,
449 Self::Uniform | Self::PushConstant => Some("constant"),
450 Self::Storage { .. } => Some("device"),
451 Self::Private | Self::Function => Some("thread"),
452 Self::WorkGroup => Some("threadgroup"),
453 }
454 }
455}
456
457impl crate::Type {
458 const fn needs_alias(&self) -> bool {
460 use crate::TypeInner as Ti;
461
462 match self.inner {
463 Ti::Scalar(_)
465 | Ti::Vector { .. }
466 | Ti::Matrix { .. }
467 | Ti::Atomic(_)
468 | Ti::Pointer { .. }
469 | Ti::ValuePointer { .. } => self.name.is_some(),
470 Ti::Struct { .. } | Ti::Array { .. } => true,
472 Ti::Image { .. }
474 | Ti::Sampler { .. }
475 | Ti::AccelerationStructure
476 | Ti::RayQuery
477 | Ti::BindingArray { .. } => false,
478 }
479 }
480}
481
/// Identifies the function an [`ExpressionContext`] belongs to.
enum FunctionOrigin {
    /// A regular function of the module.
    Handle(Handle<crate::Function>),
    /// An entry point, by index.
    EntryPoint(proc::EntryPointIndex),
}
486
/// Level-of-detail operand of an image access.
#[derive(Clone, Copy)]
enum LevelOfDetail {
    /// Write this expression as the level.
    Direct(Handle<crate::Expression>),
    /// Write the clamped-level name: `CLAMPED_LOD_LOAD_PREFIX` followed by the
    /// index of this (image-load) expression.
    /// NOTE(review): presumably that name is declared earlier by the writer — confirm.
    Restricted(Handle<crate::Expression>),
}
501
/// Operands that select a texel for an image load or store.
struct TexelAddress {
    /// Texel coordinate (a scalar or a vector).
    coordinate: Handle<crate::Expression>,
    /// Array layer, checked against `get_array_size()`.
    array_index: Option<Handle<crate::Expression>>,
    /// Sample index, checked against `get_num_samples()`.
    sample: Option<Handle<crate::Expression>>,
    /// Mip level, checked against `get_num_mip_levels()`.
    level: Option<LevelOfDetail>,
}
517
/// Per-function state needed while writing expressions.
struct ExpressionContext<'a> {
    /// The function whose expressions are being written.
    function: &'a crate::Function,
    /// Whether that function is a regular function or an entry point.
    origin: FunctionOrigin,
    /// Validation info for `function`; used to resolve expression types.
    info: &'a valid::FunctionInfo,
    /// The module being translated.
    module: &'a crate::Module,
    /// Validation info for the whole module.
    mod_info: &'a valid::ModuleInfo,
    /// Pipeline-specific options.
    pipeline_options: &'a PipelineOptions,
    /// NOTE(review): presumably the target MSL version as `(major, minor)` — confirm.
    lang_version: (u8, u8),
    /// Bounds-check policies, including those for image loads and stores.
    policies: index::BoundsCheckPolicies,

    /// NOTE(review): presumably the indices of expressions already covered by an
    /// emitted bounds check — confirm.
    guarded_indices: BitSet,
}
534
535impl<'a> ExpressionContext<'a> {
536 fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
537 self.info[handle].ty.inner_with(&self.module.types)
538 }
539
540 fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
547 let image_ty = self.resolve_type(image);
548 if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
549 class.is_mipmapped() && dim != crate::ImageDimension::D1
550 } else {
551 false
552 }
553 }
554
555 fn choose_bounds_check_policy(
556 &self,
557 pointer: Handle<crate::Expression>,
558 ) -> index::BoundsCheckPolicy {
559 self.policies
560 .choose_policy(pointer, &self.module.types, self.info)
561 }
562
563 fn access_needs_check(
564 &self,
565 base: Handle<crate::Expression>,
566 index: index::GuardedIndex,
567 ) -> Option<index::IndexableLength> {
568 index::access_needs_check(base, index, self.module, self.function, self.info)
569 }
570
571 fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
572 match self.function.expressions[expr_handle] {
573 crate::Expression::AccessIndex { base, index } => {
574 let ty = match *self.resolve_type(base) {
575 crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
576 ref ty => ty,
577 };
578 match *ty {
579 crate::TypeInner::Struct {
580 ref members, span, ..
581 } => should_pack_struct_member(members, span, index as usize, self.module),
582 _ => None,
583 }
584 }
585 _ => None,
586 }
587 }
588}
589
/// State needed while writing statements.
struct StatementContext<'a> {
    /// Context for the expressions the statements contain.
    expression: ExpressionContext<'a>,
    /// NOTE(review): presumably the name of the struct an entry point returns its
    /// results in, when there is one — confirm.
    result_struct: Option<&'a str>,
}
594
595impl<W: Write> Writer<W> {
596 pub fn new(out: W) -> Self {
598 Writer {
599 out,
600 names: FastHashMap::default(),
601 named_expressions: Default::default(),
602 need_bake_expressions: Default::default(),
603 namer: proc::Namer::default(),
604 #[cfg(test)]
605 put_expression_stack_pointers: Default::default(),
606 #[cfg(test)]
607 put_block_stack_pointers: Default::default(),
608 struct_member_pads: FastHashSet::default(),
609 }
610 }
611
    /// Consumes the writer and returns its output sink.
    // The lint is allowed because this cannot actually be a `const fn`: it drops
    // the writer's other (heap-owning) fields.
    // NOTE(review): presumably the clippy false positive rust-clippy#4979 — confirm.
    #[allow(clippy::missing_const_for_fn)]
    pub fn finish(self) -> W {
        self.out
    }
618
619 fn put_call_parameters(
620 &mut self,
621 parameters: impl Iterator<Item = Handle<crate::Expression>>,
622 context: &ExpressionContext,
623 ) -> BackendResult {
624 self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
625 writer.put_expression(expr, context, true)
626 })
627 }
628
629 fn put_call_parameters_impl<C, E>(
630 &mut self,
631 parameters: impl Iterator<Item = Handle<crate::Expression>>,
632 ctx: &C,
633 put_expression: E,
634 ) -> BackendResult
635 where
636 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
637 {
638 write!(self.out, "(")?;
639 for (i, handle) in parameters.enumerate() {
640 if i != 0 {
641 write!(self.out, ", ")?;
642 }
643 put_expression(self, ctx, handle)?;
644 }
645 write!(self.out, ")")?;
646 Ok(())
647 }
648
649 fn put_level_of_detail(
650 &mut self,
651 level: LevelOfDetail,
652 context: &ExpressionContext,
653 ) -> BackendResult {
654 match level {
655 LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
656 LevelOfDetail::Restricted(load) => {
657 write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())?
658 }
659 }
660 Ok(())
661 }
662
663 fn put_image_query(
664 &mut self,
665 image: Handle<crate::Expression>,
666 query: &str,
667 level: Option<LevelOfDetail>,
668 context: &ExpressionContext,
669 ) -> BackendResult {
670 self.put_expression(image, context, false)?;
671 write!(self.out, ".get_{query}(")?;
672 if let Some(level) = level {
673 self.put_level_of_detail(level, context)?;
674 }
675 write!(self.out, ")")?;
676 Ok(())
677 }
678
679 fn put_image_size_query(
680 &mut self,
681 image: Handle<crate::Expression>,
682 level: Option<LevelOfDetail>,
683 kind: crate::ScalarKind,
684 context: &ExpressionContext,
685 ) -> BackendResult {
686 let dim = match *context.resolve_type(image) {
689 crate::TypeInner::Image { dim, .. } => dim,
690 ref other => unreachable!("Unexpected type {:?}", other),
691 };
692 let scalar = crate::Scalar { kind, width: 4 };
693 let coordinate_type = scalar.to_msl_name();
694 match dim {
695 crate::ImageDimension::D1 => {
696 if kind == crate::ScalarKind::Uint {
700 self.put_image_query(image, "width", None, context)?;
702 } else {
703 write!(self.out, "int(")?;
705 self.put_image_query(image, "width", None, context)?;
706 write!(self.out, ")")?;
707 }
708 }
709 crate::ImageDimension::D2 => {
710 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
711 self.put_image_query(image, "width", level, context)?;
712 write!(self.out, ", ")?;
713 self.put_image_query(image, "height", level, context)?;
714 write!(self.out, ")")?;
715 }
716 crate::ImageDimension::D3 => {
717 write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
718 self.put_image_query(image, "width", level, context)?;
719 write!(self.out, ", ")?;
720 self.put_image_query(image, "height", level, context)?;
721 write!(self.out, ", ")?;
722 self.put_image_query(image, "depth", level, context)?;
723 write!(self.out, ")")?;
724 }
725 crate::ImageDimension::Cube => {
726 write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
727 self.put_image_query(image, "width", level, context)?;
728 write!(self.out, ")")?;
729 }
730 }
731 Ok(())
732 }
733
734 fn put_cast_to_uint_scalar_or_vector(
735 &mut self,
736 expr: Handle<crate::Expression>,
737 context: &ExpressionContext,
738 ) -> BackendResult {
739 match *context.resolve_type(expr) {
741 crate::TypeInner::Scalar(_) => {
742 put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
743 }
744 crate::TypeInner::Vector { size, .. } => {
745 put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
746 }
747 _ => {
748 return Err(Error::GenericValidation(
749 "Invalid type for image coordinate".into(),
750 ))
751 }
752 };
753
754 write!(self.out, "(")?;
755 self.put_expression(expr, context, true)?;
756 write!(self.out, ")")?;
757 Ok(())
758 }
759
    /// Writes the trailing level-of-detail option of a sampling call, including
    /// its leading `, `.
    ///
    /// * `Auto` and `Zero` write nothing.
    /// * When [`ExpressionContext::image_needs_lod`] is false (1D or
    ///   non-mipmapped images), `Exact`, `Bias` and `Gradient` write nothing and
    ///   log a warning instead.
    /// * Otherwise they write `metal::level(..)`, `metal::bias(..)` or
    ///   `metal::gradient2d(x, y)`, respectively.
    ///
    /// NOTE(review): `gradient2d` is written for every image dimension. MSL uses
    /// `gradient3d` / `gradientcube` for 3D and cube textures; confirm whether
    /// those cases can reach this path.
    fn put_image_sample_level(
        &mut self,
        image: Handle<crate::Expression>,
        level: crate::SampleLevel,
        context: &ExpressionContext,
    ) -> BackendResult {
        let has_levels = context.image_needs_lod(image);
        match level {
            crate::SampleLevel::Auto => {}
            crate::SampleLevel::Zero => {
                // Nothing is written for an explicit zero level.
                // NOTE(review): confirm this is correct for all image classes.
            }
            // Must stay after `Auto`/`Zero`: the remaining variants need mip
            // levels. The warning text says "1D", but non-mipmapped images also
            // take this branch.
            _ if !has_levels => {
                log::warn!("1D image can't be sampled with level {:?}", level);
            }
            crate::SampleLevel::Exact(h) => {
                write!(self.out, ", {NAMESPACE}::level(")?;
                self.put_expression(h, context, true)?;
                write!(self.out, ")")?;
            }
            crate::SampleLevel::Bias(h) => {
                write!(self.out, ", {NAMESPACE}::bias(")?;
                self.put_expression(h, context, true)?;
                write!(self.out, ")")?;
            }
            crate::SampleLevel::Gradient { x, y } => {
                write!(self.out, ", {NAMESPACE}::gradient2d(")?;
                self.put_expression(x, context, true)?;
                write!(self.out, ", ")?;
                self.put_expression(y, context, true)?;
                write!(self.out, ")")?;
            }
        }
        Ok(())
    }
795
796 fn put_image_coordinate_limits(
797 &mut self,
798 image: Handle<crate::Expression>,
799 level: Option<LevelOfDetail>,
800 context: &ExpressionContext,
801 ) -> BackendResult {
802 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
803 write!(self.out, " - 1")?;
804 Ok(())
805 }
806
807 fn put_restricted_scalar_image_index(
825 &mut self,
826 image: Handle<crate::Expression>,
827 index: Handle<crate::Expression>,
828 limit_method: &str,
829 context: &ExpressionContext,
830 ) -> BackendResult {
831 write!(self.out, "{NAMESPACE}::min(uint(")?;
832 self.put_expression(index, context, true)?;
833 write!(self.out, "), ")?;
834 self.put_expression(image, context, false)?;
835 write!(self.out, ".{limit_method}() - 1)")?;
836 Ok(())
837 }
838
839 fn put_restricted_texel_address(
840 &mut self,
841 image: Handle<crate::Expression>,
842 address: &TexelAddress,
843 context: &ExpressionContext,
844 ) -> BackendResult {
845 write!(self.out, "{NAMESPACE}::min(")?;
847 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
848 write!(self.out, ", ")?;
849 self.put_image_coordinate_limits(image, address.level, context)?;
850 write!(self.out, ")")?;
851
852 if let Some(array_index) = address.array_index {
854 write!(self.out, ", ")?;
855 self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
856 }
857
858 if let Some(sample) = address.sample {
860 write!(self.out, ", ")?;
861 self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
862 }
863
864 if let Some(level) = address.level {
867 write!(self.out, ", ")?;
868 self.put_level_of_detail(level, context)?;
869 }
870
871 Ok(())
872 }
873
874 fn put_image_access_bounds_check(
876 &mut self,
877 image: Handle<crate::Expression>,
878 address: &TexelAddress,
879 context: &ExpressionContext,
880 ) -> BackendResult {
881 let mut conjunction = "";
882
883 let level = if let Some(level) = address.level {
886 write!(self.out, "uint(")?;
887 self.put_level_of_detail(level, context)?;
888 write!(self.out, ") < ")?;
889 self.put_expression(image, context, true)?;
890 write!(self.out, ".get_num_mip_levels()")?;
891 conjunction = " && ";
892 Some(level)
893 } else {
894 None
895 };
896
897 if let Some(sample) = address.sample {
899 write!(self.out, "uint(")?;
900 self.put_expression(sample, context, true)?;
901 write!(self.out, ") < ")?;
902 self.put_expression(image, context, true)?;
903 write!(self.out, ".get_num_samples()")?;
904 conjunction = " && ";
905 }
906
907 if let Some(array_index) = address.array_index {
909 write!(self.out, "{conjunction}uint(")?;
910 self.put_expression(array_index, context, true)?;
911 write!(self.out, ") < ")?;
912 self.put_expression(image, context, true)?;
913 write!(self.out, ".get_array_size()")?;
914 conjunction = " && ";
915 }
916
917 let coord_is_vector = match *context.resolve_type(address.coordinate) {
919 crate::TypeInner::Vector { .. } => true,
920 _ => false,
921 };
922 write!(self.out, "{conjunction}")?;
923 if coord_is_vector {
924 write!(self.out, "{NAMESPACE}::all(")?;
925 }
926 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
927 write!(self.out, " < ")?;
928 self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
929 if coord_is_vector {
930 write!(self.out, ")")?;
931 }
932
933 Ok(())
934 }
935
936 fn put_image_load(
937 &mut self,
938 load: Handle<crate::Expression>,
939 image: Handle<crate::Expression>,
940 mut address: TexelAddress,
941 context: &ExpressionContext,
942 ) -> BackendResult {
943 match context.policies.image_load {
944 proc::BoundsCheckPolicy::Restrict => {
945 if address.level.is_some() {
948 address.level = if context.image_needs_lod(image) {
949 Some(LevelOfDetail::Restricted(load))
950 } else {
951 None
952 }
953 }
954
955 self.put_expression(image, context, false)?;
956 write!(self.out, ".read(")?;
957 self.put_restricted_texel_address(image, &address, context)?;
958 write!(self.out, ")")?;
959 }
960 proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
961 write!(self.out, "(")?;
962 self.put_image_access_bounds_check(image, &address, context)?;
963 write!(self.out, " ? ")?;
964 self.put_unchecked_image_load(image, &address, context)?;
965 write!(self.out, ": DefaultConstructible())")?;
966 }
967 proc::BoundsCheckPolicy::Unchecked => {
968 self.put_unchecked_image_load(image, &address, context)?;
969 }
970 }
971
972 Ok(())
973 }
974
975 fn put_unchecked_image_load(
976 &mut self,
977 image: Handle<crate::Expression>,
978 address: &TexelAddress,
979 context: &ExpressionContext,
980 ) -> BackendResult {
981 self.put_expression(image, context, false)?;
982 write!(self.out, ".read(")?;
983 self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
985 if let Some(expr) = address.array_index {
986 write!(self.out, ", ")?;
987 self.put_expression(expr, context, true)?;
988 }
989 if let Some(sample) = address.sample {
990 write!(self.out, ", ")?;
991 self.put_expression(sample, context, true)?;
992 }
993 if let Some(level) = address.level {
994 if context.image_needs_lod(image) {
995 write!(self.out, ", ")?;
996 self.put_level_of_detail(level, context)?;
997 }
998 }
999 write!(self.out, ")")?;
1000
1001 Ok(())
1002 }
1003
    /// Writes a statement, indented by `level`, that stores `value` into `image`
    /// at `address`, following the image-store bounds-check policy.
    ///
    /// * `Restrict`: `image.write(value, <clamped address>);`.
    /// * `ReadZeroSkipWrite`: an unchecked store wrapped in
    ///   `if (<bounds check>) { ... }`, with the store indented one level deeper.
    /// * `Unchecked`: the store is written directly.
    fn put_image_store(
        &mut self,
        level: back::Level,
        image: Handle<crate::Expression>,
        address: &TexelAddress,
        value: Handle<crate::Expression>,
        context: &StatementContext,
    ) -> BackendResult {
        match context.expression.policies.image_store {
            proc::BoundsCheckPolicy::Restrict => {
                // This path expects no level of detail (asserted in debug
                // builds), so there is no restricted level to substitute.
                debug_assert!(address.level.is_none());

                write!(self.out, "{level}")?;
                self.put_expression(image, &context.expression, false)?;
                write!(self.out, ".write(")?;
                self.put_expression(value, &context.expression, true)?;
                write!(self.out, ", ")?;
                self.put_restricted_texel_address(image, address, &context.expression)?;
                writeln!(self.out, ");")?;
            }
            proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
                write!(self.out, "{level}if (")?;
                self.put_image_access_bounds_check(image, address, &context.expression)?;
                writeln!(self.out, ") {{")?;
                self.put_unchecked_image_store(level.next(), image, address, value, context)?;
                writeln!(self.out, "{level}}}")?;
            }
            proc::BoundsCheckPolicy::Unchecked => {
                self.put_unchecked_image_store(level, image, address, value, context)?;
            }
        }

        Ok(())
    }
1040
1041 fn put_unchecked_image_store(
1042 &mut self,
1043 level: back::Level,
1044 image: Handle<crate::Expression>,
1045 address: &TexelAddress,
1046 value: Handle<crate::Expression>,
1047 context: &StatementContext,
1048 ) -> BackendResult {
1049 write!(self.out, "{level}")?;
1050 self.put_expression(image, &context.expression, false)?;
1051 write!(self.out, ".write(")?;
1052 self.put_expression(value, &context.expression, true)?;
1053 write!(self.out, ", ")?;
1054 self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
1056 if let Some(expr) = address.array_index {
1057 write!(self.out, ", ")?;
1058 self.put_expression(expr, &context.expression, true)?;
1059 }
1060 writeln!(self.out, ");")?;
1061
1062 Ok(())
1063 }
1064
    /// Writes an MSL expression for the largest valid index of the runtime-sized
    /// array stored in global `handle`.
    ///
    /// The global's type is either the runtime-sized array itself (at offset 0)
    /// or a struct whose last member holds it (at that member's offset). The
    /// emitted expression is
    /// `(_buffer_sizes.size{N} - offset - element_size) / stride`, where `N` is
    /// the global's handle index. Assuming `size{N}` is the buffer's byte size,
    /// subtracting one element's size before dividing by the stride gives the
    /// index of the last element that fits entirely in the buffer, even when the
    /// buffer does not include the final element's stride padding.
    ///
    /// NOTE(review): the subtraction underflows if the buffer is smaller than
    /// `offset + element_size`. Presumably validation guarantees at least one
    /// element — confirm.
    ///
    /// # Errors
    /// Returns [`Error::GenericValidation`] if the type is a struct with no
    /// members, is neither a struct nor a runtime-sized array, or if the selected
    /// member is not an array.
    fn put_dynamic_array_max_index(
        &mut self,
        handle: Handle<crate::GlobalVariable>,
        context: &ExpressionContext,
    ) -> BackendResult {
        let global = &context.module.global_variables[handle];
        let (offset, array_ty) = match context.module.types[global.ty].inner {
            crate::TypeInner::Struct { ref members, .. } => match members.last() {
                Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
                None => return Err(Error::GenericValidation("Struct has no members".into())),
            },
            crate::TypeInner::Array {
                size: crate::ArraySize::Dynamic,
                ..
            } => (0, global.ty),
            ref ty => {
                return Err(Error::GenericValidation(format!(
                    "Expected type with dynamic array, got {ty:?}"
                )))
            }
        };

        // `size` is the byte size of one element; `stride` the distance between
        // consecutive elements.
        let (size, stride) = match context.module.types[array_ty].inner {
            crate::TypeInner::Array { base, stride, .. } => (
                context.module.types[base]
                    .inner
                    .size(context.module.to_ctx()),
                stride,
            ),
            ref ty => {
                return Err(Error::GenericValidation(format!(
                    "Expected array type, got {ty:?}"
                )))
            }
        };

        write!(
            self.out,
            "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}",
            idx = handle.index(),
            offset = offset,
            size = size,
            stride = stride,
        )?;
        Ok(())
    }
1133
    /// Writes `metal::atomic_{key}_explicit(&<pointer>, value, metal::memory_order_relaxed)`.
    ///
    /// When the bounds-check policy for `pointer` is `ReadZeroSkipWrite` and
    /// `put_bounds_checks` returns `true` (presumably meaning it wrote a
    /// condition), the call is wrapped as
    /// `<condition> ? <call> : DefaultConstructible()`.
    fn put_atomic_operation(
        &mut self,
        pointer: Handle<crate::Expression>,
        key: &str,
        value: Handle<crate::Expression>,
        context: &ExpressionContext,
    ) -> BackendResult {
        let policy = context.choose_bounds_check_policy(pointer);
        // `&&` ensures bounds checks are only attempted (and written) under the
        // `ReadZeroSkipWrite` policy.
        let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
            && self.put_bounds_checks(pointer, context, back::Level(0), "")?;

        // Continue the ternary expression after the written condition.
        if checked {
            write!(self.out, " ? ")?;
        }

        write!(
            self.out,
            "{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
        )?;
        self.put_access_chain(pointer, policy, context)?;
        write!(self.out, ", ")?;
        self.put_expression(value, context, true)?;
        write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;

        // Finish the ternary expression with a default value.
        if checked {
            write!(self.out, " : DefaultConstructible()")?;
        }

        Ok(())
    }
1169
1170 fn put_dot_product(
1173 &mut self,
1174 arg: Handle<crate::Expression>,
1175 arg1: Handle<crate::Expression>,
1176 size: usize,
1177 context: &ExpressionContext,
1178 ) -> BackendResult {
1179 write!(self.out, "(")?;
1182
1183 for index in 0..size {
1185 let component = back::COMPONENTS[index];
1186 write!(self.out, " + ")?;
1189 self.put_expression(arg, context, true)?;
1193 write!(self.out, ".{component} * ")?;
1195 self.put_expression(arg1, context, true)?;
1199 write!(self.out, ".{component}")?;
1201 }
1202
1203 write!(self.out, ")")?;
1204 Ok(())
1205 }
1206
1207 fn put_isign(
1210 &mut self,
1211 arg: Handle<crate::Expression>,
1212 context: &ExpressionContext,
1213 ) -> BackendResult {
1214 write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
1215 match context.resolve_type(arg) {
1216 &crate::TypeInner::Vector { size, .. } => {
1217 let size = back::vector_size_str(size);
1218 write!(self.out, "int{size}(-1), int{size}(1)")?;
1219 }
1220 _ => {
1221 write!(self.out, "-1, 1")?;
1222 }
1223 }
1224 write!(self.out, ", (")?;
1225 self.put_expression(arg, context, true)?;
1226 write!(self.out, " > 0)), 0, (")?;
1227 self.put_expression(arg, context, true)?;
1228 write!(self.out, " == 0))")?;
1229 Ok(())
1230 }
1231
1232 fn put_const_expression(
1233 &mut self,
1234 expr_handle: Handle<crate::Expression>,
1235 module: &crate::Module,
1236 mod_info: &valid::ModuleInfo,
1237 ) -> BackendResult {
1238 self.put_possibly_const_expression(
1239 expr_handle,
1240 &module.global_expressions,
1241 module,
1242 mod_info,
1243 &(module, mod_info),
1244 |&(_, mod_info), expr| &mod_info[expr],
1245 |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info),
1246 )
1247 }
1248
1249 #[allow(clippy::too_many_arguments)]
1250 fn put_possibly_const_expression<C, I, E>(
1251 &mut self,
1252 expr_handle: Handle<crate::Expression>,
1253 expressions: &crate::Arena<crate::Expression>,
1254 module: &crate::Module,
1255 mod_info: &valid::ModuleInfo,
1256 ctx: &C,
1257 get_expr_ty: I,
1258 put_expression: E,
1259 ) -> BackendResult
1260 where
1261 I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
1262 E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
1263 {
1264 match expressions[expr_handle] {
1265 crate::Expression::Literal(literal) => match literal {
1266 crate::Literal::F64(_) => {
1267 return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
1268 }
1269 crate::Literal::F32(value) => {
1270 if value.is_infinite() {
1271 let sign = if value.is_sign_negative() { "-" } else { "" };
1272 write!(self.out, "{sign}INFINITY")?;
1273 } else if value.is_nan() {
1274 write!(self.out, "NAN")?;
1275 } else {
1276 let suffix = if value.fract() == 0.0 { ".0" } else { "" };
1277 write!(self.out, "{value}{suffix}")?;
1278 }
1279 }
1280 crate::Literal::U32(value) => {
1281 write!(self.out, "{value}u")?;
1282 }
1283 crate::Literal::I32(value) => {
1284 write!(self.out, "{value}")?;
1285 }
1286 crate::Literal::U64(value) => {
1287 write!(self.out, "{value}uL")?;
1288 }
1289 crate::Literal::I64(value) => {
1290 write!(self.out, "{value}L")?;
1291 }
1292 crate::Literal::Bool(value) => {
1293 write!(self.out, "{value}")?;
1294 }
1295 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1296 return Err(Error::GenericValidation(
1297 "Unsupported abstract literal".into(),
1298 ));
1299 }
1300 },
1301 crate::Expression::Constant(handle) => {
1302 let constant = &module.constants[handle];
1303 if constant.name.is_some() {
1304 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
1305 } else {
1306 self.put_const_expression(constant.init, module, mod_info)?;
1307 }
1308 }
1309 crate::Expression::ZeroValue(ty) => {
1310 let ty_name = TypeContext {
1311 handle: ty,
1312 gctx: module.to_ctx(),
1313 names: &self.names,
1314 access: crate::StorageAccess::empty(),
1315 binding: None,
1316 first_time: false,
1317 };
1318 write!(self.out, "{ty_name} {{}}")?;
1319 }
1320 crate::Expression::Compose { ty, ref components } => {
1321 let ty_name = TypeContext {
1322 handle: ty,
1323 gctx: module.to_ctx(),
1324 names: &self.names,
1325 access: crate::StorageAccess::empty(),
1326 binding: None,
1327 first_time: false,
1328 };
1329 write!(self.out, "{ty_name}")?;
1330 match module.types[ty].inner {
1331 crate::TypeInner::Scalar(_)
1332 | crate::TypeInner::Vector { .. }
1333 | crate::TypeInner::Matrix { .. } => {
1334 self.put_call_parameters_impl(
1335 components.iter().copied(),
1336 ctx,
1337 put_expression,
1338 )?;
1339 }
1340 crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
1341 write!(self.out, " {{")?;
1342 for (index, &component) in components.iter().enumerate() {
1343 if index != 0 {
1344 write!(self.out, ", ")?;
1345 }
1346 if self.struct_member_pads.contains(&(ty, index as u32)) {
1348 write!(self.out, "{{}}, ")?;
1349 }
1350 put_expression(self, ctx, component)?;
1351 }
1352 write!(self.out, "}}")?;
1353 }
1354 _ => return Err(Error::UnsupportedCompose(ty)),
1355 }
1356 }
1357 crate::Expression::Splat { size, value } => {
1358 let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
1359 crate::TypeInner::Scalar(scalar) => scalar,
1360 ref ty => {
1361 return Err(Error::GenericValidation(format!(
1362 "Expected splat value type must be a scalar, got {ty:?}",
1363 )))
1364 }
1365 };
1366 put_numeric_type(&mut self.out, scalar, &[size])?;
1367 write!(self.out, "(")?;
1368 put_expression(self, ctx, value)?;
1369 write!(self.out, ")")?;
1370 }
1371 _ => unreachable!(),
1372 }
1373
1374 Ok(())
1375 }
1376
1377 fn put_expression(
1389 &mut self,
1390 expr_handle: Handle<crate::Expression>,
1391 context: &ExpressionContext,
1392 is_scoped: bool,
1393 ) -> BackendResult {
1394 #[cfg(test)]
1396 #[allow(trivial_casts)]
1397 self.put_expression_stack_pointers
1398 .insert(&expr_handle as *const _ as *const ());
1399
1400 if let Some(name) = self.named_expressions.get(&expr_handle) {
1401 write!(self.out, "{name}")?;
1402 return Ok(());
1403 }
1404
1405 let expression = &context.function.expressions[expr_handle];
1406 log::trace!("expression {:?} = {:?}", expr_handle, expression);
1407 match *expression {
1408 crate::Expression::Literal(_)
1409 | crate::Expression::Constant(_)
1410 | crate::Expression::ZeroValue(_)
1411 | crate::Expression::Compose { .. }
1412 | crate::Expression::Splat { .. } => {
1413 self.put_possibly_const_expression(
1414 expr_handle,
1415 &context.function.expressions,
1416 context.module,
1417 context.mod_info,
1418 context,
1419 |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
1420 |writer, context, expr| writer.put_expression(expr, context, true),
1421 )?;
1422 }
1423 crate::Expression::Override(_) => return Err(Error::Override),
1424 crate::Expression::Access { base, .. }
1425 | crate::Expression::AccessIndex { base, .. } => {
1426 let policy = context.choose_bounds_check_policy(base);
1432 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
1433 && self.put_bounds_checks(
1434 expr_handle,
1435 context,
1436 back::Level(0),
1437 if is_scoped { "" } else { "(" },
1438 )?
1439 {
1440 write!(self.out, " ? ")?;
1441 self.put_access_chain(expr_handle, policy, context)?;
1442 write!(self.out, " : DefaultConstructible()")?;
1443
1444 if !is_scoped {
1445 write!(self.out, ")")?;
1446 }
1447 } else {
1448 self.put_access_chain(expr_handle, policy, context)?;
1449 }
1450 }
1451 crate::Expression::Swizzle {
1452 size,
1453 vector,
1454 pattern,
1455 } => {
1456 self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?;
1457 write!(self.out, ".")?;
1458 for &sc in pattern[..size as usize].iter() {
1459 write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
1460 }
1461 }
1462 crate::Expression::FunctionArgument(index) => {
1463 let name_key = match context.origin {
1464 FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
1465 FunctionOrigin::EntryPoint(ep_index) => {
1466 NameKey::EntryPointArgument(ep_index, index)
1467 }
1468 };
1469 let name = &self.names[&name_key];
1470 write!(self.out, "{name}")?;
1471 }
1472 crate::Expression::GlobalVariable(handle) => {
1473 let name = &self.names[&NameKey::GlobalVariable(handle)];
1474 write!(self.out, "{name}")?;
1475 }
1476 crate::Expression::LocalVariable(handle) => {
1477 let name_key = match context.origin {
1478 FunctionOrigin::Handle(fun_handle) => {
1479 NameKey::FunctionLocal(fun_handle, handle)
1480 }
1481 FunctionOrigin::EntryPoint(ep_index) => {
1482 NameKey::EntryPointLocal(ep_index, handle)
1483 }
1484 };
1485 let name = &self.names[&name_key];
1486 write!(self.out, "{name}")?;
1487 }
1488 crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
1489 crate::Expression::ImageSample {
1490 image,
1491 sampler,
1492 gather,
1493 coordinate,
1494 array_index,
1495 offset,
1496 level,
1497 depth_ref,
1498 } => {
1499 let main_op = match gather {
1500 Some(_) => "gather",
1501 None => "sample",
1502 };
1503 let comparison_op = match depth_ref {
1504 Some(_) => "_compare",
1505 None => "",
1506 };
1507 self.put_expression(image, context, false)?;
1508 write!(self.out, ".{main_op}{comparison_op}(")?;
1509 self.put_expression(sampler, context, true)?;
1510 write!(self.out, ", ")?;
1511 self.put_expression(coordinate, context, true)?;
1512 if let Some(expr) = array_index {
1513 write!(self.out, ", ")?;
1514 self.put_expression(expr, context, true)?;
1515 }
1516 if let Some(dref) = depth_ref {
1517 write!(self.out, ", ")?;
1518 self.put_expression(dref, context, true)?;
1519 }
1520
1521 self.put_image_sample_level(image, level, context)?;
1522
1523 if let Some(offset) = offset {
1524 write!(self.out, ", ")?;
1525 self.put_const_expression(offset, context.module, context.mod_info)?;
1526 }
1527
1528 match gather {
1529 None | Some(crate::SwizzleComponent::X) => {}
1530 Some(component) => {
1531 let is_cube_map = match *context.resolve_type(image) {
1532 crate::TypeInner::Image {
1533 dim: crate::ImageDimension::Cube,
1534 ..
1535 } => true,
1536 _ => false,
1537 };
1538 if offset.is_none() && !is_cube_map {
1541 write!(self.out, ", {NAMESPACE}::int2(0)")?;
1542 }
1543 let letter = back::COMPONENTS[component as usize];
1544 write!(self.out, ", {NAMESPACE}::component::{letter}")?;
1545 }
1546 }
1547 write!(self.out, ")")?;
1548 }
1549 crate::Expression::ImageLoad {
1550 image,
1551 coordinate,
1552 array_index,
1553 sample,
1554 level,
1555 } => {
1556 let address = TexelAddress {
1557 coordinate,
1558 array_index,
1559 sample,
1560 level: level.map(LevelOfDetail::Direct),
1561 };
1562 self.put_image_load(expr_handle, image, address, context)?;
1563 }
1564 crate::Expression::ImageQuery { image, query } => match query {
1567 crate::ImageQuery::Size { level } => {
1568 self.put_image_size_query(
1569 image,
1570 level.map(LevelOfDetail::Direct),
1571 crate::ScalarKind::Uint,
1572 context,
1573 )?;
1574 }
1575 crate::ImageQuery::NumLevels => {
1576 self.put_expression(image, context, false)?;
1577 write!(self.out, ".get_num_mip_levels()")?;
1578 }
1579 crate::ImageQuery::NumLayers => {
1580 self.put_expression(image, context, false)?;
1581 write!(self.out, ".get_array_size()")?;
1582 }
1583 crate::ImageQuery::NumSamples => {
1584 self.put_expression(image, context, false)?;
1585 write!(self.out, ".get_num_samples()")?;
1586 }
1587 },
1588 crate::Expression::Unary { op, expr } => {
1589 let op_str = match op {
1590 crate::UnaryOperator::Negate => "-",
1591 crate::UnaryOperator::LogicalNot => "!",
1592 crate::UnaryOperator::BitwiseNot => "~",
1593 };
1594 write!(self.out, "{op_str}(")?;
1595 self.put_expression(expr, context, false)?;
1596 write!(self.out, ")")?;
1597 }
1598 crate::Expression::Binary { op, left, right } => {
1599 let op_str = crate::back::binary_operation_str(op);
1600 let kind = context
1601 .resolve_type(left)
1602 .scalar_kind()
1603 .ok_or(Error::UnsupportedBinaryOp(op))?;
1604
1605 if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
1619 write!(self.out, "{NAMESPACE}::fmod(")?;
1620 self.put_expression(left, context, true)?;
1621 write!(self.out, ", ")?;
1622 self.put_expression(right, context, true)?;
1623 write!(self.out, ")")?;
1624 } else {
1625 if !is_scoped {
1626 write!(self.out, "(")?;
1627 }
1628
1629 if op == crate::BinaryOperator::Multiply
1632 && matches!(
1633 context.resolve_type(right),
1634 &crate::TypeInner::Matrix { .. }
1635 )
1636 {
1637 self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?;
1638 } else {
1639 self.put_expression(left, context, false)?;
1640 }
1641
1642 write!(self.out, " {op_str} ")?;
1643
1644 if op == crate::BinaryOperator::Multiply
1646 && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
1647 {
1648 self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?;
1649 } else {
1650 self.put_expression(right, context, false)?;
1651 }
1652
1653 if !is_scoped {
1654 write!(self.out, ")")?;
1655 }
1656 }
1657 }
1658 crate::Expression::Select {
1659 condition,
1660 accept,
1661 reject,
1662 } => match *context.resolve_type(condition) {
1663 crate::TypeInner::Scalar(crate::Scalar {
1664 kind: crate::ScalarKind::Bool,
1665 ..
1666 }) => {
1667 if !is_scoped {
1668 write!(self.out, "(")?;
1669 }
1670 self.put_expression(condition, context, false)?;
1671 write!(self.out, " ? ")?;
1672 self.put_expression(accept, context, false)?;
1673 write!(self.out, " : ")?;
1674 self.put_expression(reject, context, false)?;
1675 if !is_scoped {
1676 write!(self.out, ")")?;
1677 }
1678 }
1679 crate::TypeInner::Vector {
1680 scalar:
1681 crate::Scalar {
1682 kind: crate::ScalarKind::Bool,
1683 ..
1684 },
1685 ..
1686 } => {
1687 write!(self.out, "{NAMESPACE}::select(")?;
1688 self.put_expression(reject, context, true)?;
1689 write!(self.out, ", ")?;
1690 self.put_expression(accept, context, true)?;
1691 write!(self.out, ", ")?;
1692 self.put_expression(condition, context, true)?;
1693 write!(self.out, ")")?;
1694 }
1695 ref ty => {
1696 return Err(Error::GenericValidation(format!(
1697 "Expected select condition to be a non-bool type, got {ty:?}",
1698 )))
1699 }
1700 },
1701 crate::Expression::Derivative { axis, expr, .. } => {
1702 use crate::DerivativeAxis as Axis;
1703 let op = match axis {
1704 Axis::X => "dfdx",
1705 Axis::Y => "dfdy",
1706 Axis::Width => "fwidth",
1707 };
1708 write!(self.out, "{NAMESPACE}::{op}")?;
1709 self.put_call_parameters(iter::once(expr), context)?;
1710 }
1711 crate::Expression::Relational { fun, argument } => {
1712 let op = match fun {
1713 crate::RelationalFunction::Any => "any",
1714 crate::RelationalFunction::All => "all",
1715 crate::RelationalFunction::IsNan => "isnan",
1716 crate::RelationalFunction::IsInf => "isinf",
1717 };
1718 write!(self.out, "{NAMESPACE}::{op}")?;
1719 self.put_call_parameters(iter::once(argument), context)?;
1720 }
1721 crate::Expression::Math {
1722 fun,
1723 arg,
1724 arg1,
1725 arg2,
1726 arg3,
1727 } => {
1728 use crate::MathFunction as Mf;
1729
1730 let arg_type = context.resolve_type(arg);
1731 let scalar_argument = match arg_type {
1732 &crate::TypeInner::Scalar(_) => true,
1733 _ => false,
1734 };
1735
1736 let fun_name = match fun {
1737 Mf::Abs => "abs",
1739 Mf::Min => "min",
1740 Mf::Max => "max",
1741 Mf::Clamp => "clamp",
1742 Mf::Saturate => "saturate",
1743 Mf::Cos => "cos",
1745 Mf::Cosh => "cosh",
1746 Mf::Sin => "sin",
1747 Mf::Sinh => "sinh",
1748 Mf::Tan => "tan",
1749 Mf::Tanh => "tanh",
1750 Mf::Acos => "acos",
1751 Mf::Asin => "asin",
1752 Mf::Atan => "atan",
1753 Mf::Atan2 => "atan2",
1754 Mf::Asinh => "asinh",
1755 Mf::Acosh => "acosh",
1756 Mf::Atanh => "atanh",
1757 Mf::Radians => "",
1758 Mf::Degrees => "",
1759 Mf::Ceil => "ceil",
1761 Mf::Floor => "floor",
1762 Mf::Round => "rint",
1763 Mf::Fract => "fract",
1764 Mf::Trunc => "trunc",
1765 Mf::Modf => MODF_FUNCTION,
1766 Mf::Frexp => FREXP_FUNCTION,
1767 Mf::Ldexp => "ldexp",
1768 Mf::Exp => "exp",
1770 Mf::Exp2 => "exp2",
1771 Mf::Log => "log",
1772 Mf::Log2 => "log2",
1773 Mf::Pow => "pow",
1774 Mf::Dot => match *context.resolve_type(arg) {
1776 crate::TypeInner::Vector {
1777 scalar:
1778 crate::Scalar {
1779 kind: crate::ScalarKind::Float,
1780 ..
1781 },
1782 ..
1783 } => "dot",
1784 crate::TypeInner::Vector { size, .. } => {
1785 return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
1786 }
1787 _ => unreachable!(
1788 "Correct TypeInner for dot product should be already validated"
1789 ),
1790 },
1791 Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
1792 Mf::Cross => "cross",
1793 Mf::Distance => "distance",
1794 Mf::Length if scalar_argument => "abs",
1795 Mf::Length => "length",
1796 Mf::Normalize => "normalize",
1797 Mf::FaceForward => "faceforward",
1798 Mf::Reflect => "reflect",
1799 Mf::Refract => "refract",
1800 Mf::Sign => match arg_type.scalar_kind() {
1802 Some(crate::ScalarKind::Sint) => {
1803 return self.put_isign(arg, context);
1804 }
1805 _ => "sign",
1806 },
1807 Mf::Fma => "fma",
1808 Mf::Mix => "mix",
1809 Mf::Step => "step",
1810 Mf::SmoothStep => "smoothstep",
1811 Mf::Sqrt => "sqrt",
1812 Mf::InverseSqrt => "rsqrt",
1813 Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
1814 Mf::Transpose => "transpose",
1815 Mf::Determinant => "determinant",
1816 Mf::CountTrailingZeros => "ctz",
1818 Mf::CountLeadingZeros => "clz",
1819 Mf::CountOneBits => "popcount",
1820 Mf::ReverseBits => "reverse_bits",
1821 Mf::ExtractBits => "",
1822 Mf::InsertBits => "",
1823 Mf::FindLsb => "",
1824 Mf::FindMsb => "",
1825 Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
1827 Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
1828 Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
1829 Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
1830 Mf::Pack2x16float => "",
1831 Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
1833 Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
1834 Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
1835 Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
1836 Mf::Unpack2x16float => "",
1837 };
1838
1839 match fun {
1840 Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
1841 if context.lang_version < (1, 2) {
1850 return Err(Error::UnsupportedFunction(fun_name.to_string()));
1851 }
1852 }
1853 _ => {}
1854 }
1855
1856 if fun == Mf::Distance && scalar_argument {
1857 write!(self.out, "{NAMESPACE}::abs(")?;
1858 self.put_expression(arg, context, false)?;
1859 write!(self.out, " - ")?;
1860 self.put_expression(arg1.unwrap(), context, false)?;
1861 write!(self.out, ")")?;
1862 } else if fun == Mf::FindLsb {
1863 let scalar = context.resolve_type(arg).scalar().unwrap();
1864 let constant = scalar.width * 8 + 1;
1865
1866 write!(self.out, "((({NAMESPACE}::ctz(")?;
1867 self.put_expression(arg, context, true)?;
1868 write!(self.out, ") + 1) % {constant}) - 1)")?;
1869 } else if fun == Mf::FindMsb {
1870 let inner = context.resolve_type(arg);
1871 let scalar = inner.scalar().unwrap();
1872 let constant = scalar.width * 8 - 1;
1873
1874 write!(
1875 self.out,
1876 "{NAMESPACE}::select({constant} - {NAMESPACE}::clz("
1877 )?;
1878
1879 if scalar.kind == crate::ScalarKind::Sint {
1880 write!(self.out, "{NAMESPACE}::select(")?;
1881 self.put_expression(arg, context, true)?;
1882 write!(self.out, ", ~")?;
1883 self.put_expression(arg, context, true)?;
1884 write!(self.out, ", ")?;
1885 self.put_expression(arg, context, true)?;
1886 write!(self.out, " < 0)")?;
1887 } else {
1888 self.put_expression(arg, context, true)?;
1889 }
1890
1891 write!(self.out, "), ")?;
1892
1893 match *inner {
1895 crate::TypeInner::Vector { size, scalar } => {
1896 let size = back::vector_size_str(size);
1897 let name = scalar.to_msl_name();
1898 write!(self.out, "{name}{size}")?;
1899 }
1900 crate::TypeInner::Scalar(scalar) => {
1901 let name = scalar.to_msl_name();
1902 write!(self.out, "{name}")?;
1903 }
1904 _ => (),
1905 }
1906
1907 write!(self.out, "(-1), ")?;
1908 self.put_expression(arg, context, true)?;
1909 write!(self.out, " == 0 || ")?;
1910 self.put_expression(arg, context, true)?;
1911 write!(self.out, " == -1)")?;
1912 } else if fun == Mf::Unpack2x16float {
1913 write!(self.out, "float2(as_type<half2>(")?;
1914 self.put_expression(arg, context, false)?;
1915 write!(self.out, "))")?;
1916 } else if fun == Mf::Pack2x16float {
1917 write!(self.out, "as_type<uint>(half2(")?;
1918 self.put_expression(arg, context, false)?;
1919 write!(self.out, "))")?;
1920 } else if fun == Mf::ExtractBits {
1921 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
1938
1939 write!(self.out, "{NAMESPACE}::extract_bits(")?;
1940 self.put_expression(arg, context, true)?;
1941 write!(self.out, ", {NAMESPACE}::min(")?;
1942 self.put_expression(arg1.unwrap(), context, true)?;
1943 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
1944 self.put_expression(arg2.unwrap(), context, true)?;
1945 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
1946 self.put_expression(arg1.unwrap(), context, true)?;
1947 write!(self.out, ", {scalar_bits}u)))")?;
1948 } else if fun == Mf::InsertBits {
1949 let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
1954
1955 write!(self.out, "{NAMESPACE}::insert_bits(")?;
1956 self.put_expression(arg, context, true)?;
1957 write!(self.out, ", ")?;
1958 self.put_expression(arg1.unwrap(), context, true)?;
1959 write!(self.out, ", {NAMESPACE}::min(")?;
1960 self.put_expression(arg2.unwrap(), context, true)?;
1961 write!(self.out, ", {scalar_bits}u), {NAMESPACE}::min(")?;
1962 self.put_expression(arg3.unwrap(), context, true)?;
1963 write!(self.out, ", {scalar_bits}u - {NAMESPACE}::min(")?;
1964 self.put_expression(arg2.unwrap(), context, true)?;
1965 write!(self.out, ", {scalar_bits}u)))")?;
1966 } else if fun == Mf::Radians {
1967 write!(self.out, "((")?;
1968 self.put_expression(arg, context, false)?;
1969 write!(self.out, ") * 0.017453292519943295474)")?;
1970 } else if fun == Mf::Degrees {
1971 write!(self.out, "((")?;
1972 self.put_expression(arg, context, false)?;
1973 write!(self.out, ") * 57.295779513082322865)")?;
1974 } else if fun == Mf::Modf || fun == Mf::Frexp {
1975 write!(self.out, "{fun_name}")?;
1976 self.put_call_parameters(iter::once(arg), context)?;
1977 } else {
1978 write!(self.out, "{NAMESPACE}::{fun_name}")?;
1979 self.put_call_parameters(
1980 iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
1981 context,
1982 )?;
1983 }
1984 }
1985 crate::Expression::As {
1986 expr,
1987 kind,
1988 convert,
1989 } => match *context.resolve_type(expr) {
1990 crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
1991 let target_scalar = crate::Scalar {
1992 kind,
1993 width: convert.unwrap_or(src.width),
1994 };
1995 let op = match convert {
1996 Some(_) => "static_cast",
1997 None => "as_type",
1998 };
1999 write!(self.out, "{op}<")?;
2000 match *context.resolve_type(expr) {
2001 crate::TypeInner::Vector { size, .. } => {
2002 put_numeric_type(&mut self.out, target_scalar, &[size])?
2003 }
2004 _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
2005 };
2006 write!(self.out, ">(")?;
2007 self.put_expression(expr, context, true)?;
2008 write!(self.out, ")")?;
2009 }
2010 crate::TypeInner::Matrix {
2011 columns,
2012 rows,
2013 scalar,
2014 } => {
2015 let target_scalar = crate::Scalar {
2016 kind,
2017 width: convert.unwrap_or(scalar.width),
2018 };
2019 put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
2020 write!(self.out, "(")?;
2021 self.put_expression(expr, context, true)?;
2022 write!(self.out, ")")?;
2023 }
2024 ref ty => {
2025 return Err(Error::GenericValidation(format!(
2026 "Unsupported type for As: {ty:?}"
2027 )))
2028 }
2029 },
2030 crate::Expression::CallResult(_)
2032 | crate::Expression::AtomicResult { .. }
2033 | crate::Expression::WorkGroupUniformLoadResult { .. }
2034 | crate::Expression::SubgroupBallotResult
2035 | crate::Expression::SubgroupOperationResult { .. }
2036 | crate::Expression::RayQueryProceedResult => {
2037 unreachable!()
2038 }
2039 crate::Expression::ArrayLength(expr) => {
2040 let global = match context.function.expressions[expr] {
2042 crate::Expression::AccessIndex { base, .. } => {
2043 match context.function.expressions[base] {
2044 crate::Expression::GlobalVariable(handle) => handle,
2045 ref ex => {
2046 return Err(Error::GenericValidation(format!(
2047 "Expected global variable in AccessIndex, got {ex:?}"
2048 )))
2049 }
2050 }
2051 }
2052 crate::Expression::GlobalVariable(handle) => handle,
2053 ref ex => {
2054 return Err(Error::GenericValidation(format!(
2055 "Unexpected expression in ArrayLength, got {ex:?}"
2056 )))
2057 }
2058 };
2059
2060 if !is_scoped {
2061 write!(self.out, "(")?;
2062 }
2063 write!(self.out, "1 + ")?;
2064 self.put_dynamic_array_max_index(global, context)?;
2065 if !is_scoped {
2066 write!(self.out, ")")?;
2067 }
2068 }
2069 crate::Expression::RayQueryGetIntersection { query, committed } => {
2070 if context.lang_version < (2, 4) {
2071 return Err(Error::UnsupportedRayTracing);
2072 }
2073
2074 if !committed {
2075 unimplemented!()
2076 }
2077 let ty = context.module.special_types.ray_intersection.unwrap();
2078 let type_name = &self.names[&NameKey::Type(ty)];
2079 write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
2080 self.put_expression(query, context, true)?;
2081 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
2082 let fields = [
2083 "distance",
2084 "user_instance_id", "instance_id",
2086 "", "geometry_id",
2088 "primitive_id",
2089 "triangle_barycentric_coord",
2090 "triangle_front_facing",
2091 "", "object_to_world_transform", "world_to_object_transform", ];
2095 for field in fields {
2096 write!(self.out, ", ")?;
2097 if field.is_empty() {
2098 write!(self.out, "{{}}")?;
2099 } else {
2100 self.put_expression(query, context, true)?;
2101 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
2102 }
2103 }
2104 write!(self.out, "}}")?;
2105 }
2106 }
2107 Ok(())
2108 }
2109
2110 fn put_wrapped_expression_for_packed_vec3_access(
2112 &mut self,
2113 expr_handle: Handle<crate::Expression>,
2114 context: &ExpressionContext,
2115 is_scoped: bool,
2116 ) -> BackendResult {
2117 if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
2118 write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
2119 self.put_expression(expr_handle, context, is_scoped)?;
2120 write!(self.out, ")")?;
2121 } else {
2122 self.put_expression(expr_handle, context, is_scoped)?;
2123 }
2124 Ok(())
2125 }
2126
2127 fn put_index(
2129 &mut self,
2130 index: index::GuardedIndex,
2131 context: &ExpressionContext,
2132 is_scoped: bool,
2133 ) -> BackendResult {
2134 match index {
2135 index::GuardedIndex::Expression(expr) => {
2136 self.put_expression(expr, context, is_scoped)?
2137 }
2138 index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
2139 }
2140 Ok(())
2141 }
2142
2143 #[allow(unused_variables)]
2173 fn put_bounds_checks(
2174 &mut self,
2175 mut chain: Handle<crate::Expression>,
2176 context: &ExpressionContext,
2177 level: back::Level,
2178 prefix: &'static str,
2179 ) -> Result<bool, Error> {
2180 let mut check_written = false;
2181
2182 loop {
2184 let (base, guarded_index) = match context.function.expressions[chain] {
2187 crate::Expression::Access { base, index } => {
2188 (base, Some(index::GuardedIndex::Expression(index)))
2189 }
2190 crate::Expression::AccessIndex { base, index } => {
2191 let mut base_inner = context.resolve_type(base);
2194 if let crate::TypeInner::Pointer { base, .. } = *base_inner {
2195 base_inner = &context.module.types[base].inner;
2196 }
2197 match *base_inner {
2198 crate::TypeInner::Struct { .. } => (base, None),
2199 _ => (base, Some(index::GuardedIndex::Known(index))),
2200 }
2201 }
2202 _ => break,
2203 };
2204
2205 if let Some(index) = guarded_index {
2206 if let Some(length) = context.access_needs_check(base, index) {
2207 if check_written {
2208 write!(self.out, " && ")?;
2209 } else {
2210 write!(self.out, "{level}{prefix}")?;
2211 check_written = true;
2212 }
2213
2214 write!(self.out, "uint(")?;
2218 self.put_index(index, context, true)?;
2219 self.out.write_str(") < ")?;
2220 match length {
2221 index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
2222 index::IndexableLength::Dynamic => {
2223 let global =
2224 context.function.originating_global(base).ok_or_else(|| {
2225 Error::GenericValidation(
2226 "Could not find originating global".into(),
2227 )
2228 })?;
2229 write!(self.out, "1 + ")?;
2230 self.put_dynamic_array_max_index(global, context)?
2231 }
2232 }
2233 }
2234 }
2235
2236 chain = base
2237 }
2238
2239 Ok(check_written)
2240 }
2241
2242 fn put_access_chain(
2262 &mut self,
2263 chain: Handle<crate::Expression>,
2264 policy: index::BoundsCheckPolicy,
2265 context: &ExpressionContext,
2266 ) -> BackendResult {
2267 match context.function.expressions[chain] {
2268 crate::Expression::Access { base, index } => {
2269 let mut base_ty = context.resolve_type(base);
2270
2271 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
2273 base_ty = &context.module.types[base].inner;
2274 }
2275
2276 self.put_subscripted_access_chain(
2277 base,
2278 base_ty,
2279 index::GuardedIndex::Expression(index),
2280 policy,
2281 context,
2282 )?;
2283 }
2284 crate::Expression::AccessIndex { base, index } => {
2285 let base_resolution = &context.info[base].ty;
2286 let mut base_ty = base_resolution.inner_with(&context.module.types);
2287 let mut base_ty_handle = base_resolution.handle();
2288
2289 if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
2291 base_ty = &context.module.types[base].inner;
2292 base_ty_handle = Some(base);
2293 }
2294
2295 match *base_ty {
2299 crate::TypeInner::Struct { .. } => {
2300 let base_ty = base_ty_handle.unwrap();
2301 self.put_access_chain(base, policy, context)?;
2302 let name = &self.names[&NameKey::StructMember(base_ty, index)];
2303 write!(self.out, ".{name}")?;
2304 }
2305 crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
2306 self.put_access_chain(base, policy, context)?;
2307 if context.get_packed_vec_kind(base).is_some() {
2310 write!(self.out, "[{index}]")?;
2311 } else {
2312 write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
2313 }
2314 }
2315 _ => {
2316 self.put_subscripted_access_chain(
2317 base,
2318 base_ty,
2319 index::GuardedIndex::Known(index),
2320 policy,
2321 context,
2322 )?;
2323 }
2324 }
2325 }
2326 _ => self.put_expression(chain, context, false)?,
2327 }
2328
2329 Ok(())
2330 }
2331
2332 fn put_subscripted_access_chain(
2349 &mut self,
2350 base: Handle<crate::Expression>,
2351 base_ty: &crate::TypeInner,
2352 index: index::GuardedIndex,
2353 policy: index::BoundsCheckPolicy,
2354 context: &ExpressionContext,
2355 ) -> BackendResult {
2356 let accessing_wrapped_array = match *base_ty {
2357 crate::TypeInner::Array {
2358 size: crate::ArraySize::Constant(_),
2359 ..
2360 } => true,
2361 _ => false,
2362 };
2363
2364 self.put_access_chain(base, policy, context)?;
2365 if accessing_wrapped_array {
2366 write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
2367 }
2368 write!(self.out, "[")?;
2369
2370 let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
2372 context.access_needs_check(base, index)
2373 } else {
2374 None
2375 };
2376 if let Some(limit) = restriction_needed {
2377 write!(self.out, "{NAMESPACE}::min(unsigned(")?;
2378 self.put_index(index, context, true)?;
2379 write!(self.out, "), ")?;
2380 match limit {
2381 index::IndexableLength::Known(limit) => {
2382 write!(self.out, "{}u", limit - 1)?;
2383 }
2384 index::IndexableLength::Dynamic => {
2385 let global = context.function.originating_global(base).ok_or_else(|| {
2386 Error::GenericValidation("Could not find originating global".into())
2387 })?;
2388 self.put_dynamic_array_max_index(global, context)?;
2389 }
2390 }
2391 write!(self.out, ")")?;
2392 } else {
2393 self.put_index(index, context, true)?;
2394 }
2395
2396 write!(self.out, "]")?;
2397
2398 Ok(())
2399 }
2400
2401 fn put_load(
2402 &mut self,
2403 pointer: Handle<crate::Expression>,
2404 context: &ExpressionContext,
2405 is_scoped: bool,
2406 ) -> BackendResult {
2407 let policy = context.choose_bounds_check_policy(pointer);
2410 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
2411 && self.put_bounds_checks(
2412 pointer,
2413 context,
2414 back::Level(0),
2415 if is_scoped { "" } else { "(" },
2416 )?
2417 {
2418 write!(self.out, " ? ")?;
2419 self.put_unchecked_load(pointer, policy, context)?;
2420 write!(self.out, " : DefaultConstructible()")?;
2421
2422 if !is_scoped {
2423 write!(self.out, ")")?;
2424 }
2425 } else {
2426 self.put_unchecked_load(pointer, policy, context)?;
2427 }
2428
2429 Ok(())
2430 }
2431
2432 fn put_unchecked_load(
2433 &mut self,
2434 pointer: Handle<crate::Expression>,
2435 policy: index::BoundsCheckPolicy,
2436 context: &ExpressionContext,
2437 ) -> BackendResult {
2438 let is_atomic_pointer = context
2439 .resolve_type(pointer)
2440 .is_atomic_pointer(&context.module.types);
2441
2442 if is_atomic_pointer {
2443 write!(
2444 self.out,
2445 "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
2446 )?;
2447 self.put_access_chain(pointer, policy, context)?;
2448 write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
2449 } else {
2450 self.put_access_chain(pointer, policy, context)?;
2454 }
2455
2456 Ok(())
2457 }
2458
2459 fn put_return_value(
2460 &mut self,
2461 level: back::Level,
2462 expr_handle: Handle<crate::Expression>,
2463 result_struct: Option<&str>,
2464 context: &ExpressionContext,
2465 ) -> BackendResult {
2466 match result_struct {
2467 Some(struct_name) => {
2468 let mut has_point_size = false;
2469 let result_ty = context.function.result.as_ref().unwrap().ty;
2470 match context.module.types[result_ty].inner {
2471 crate::TypeInner::Struct { ref members, .. } => {
2472 let tmp = "_tmp";
2473 write!(self.out, "{level}const auto {tmp} = ")?;
2474 self.put_expression(expr_handle, context, true)?;
2475 writeln!(self.out, ";")?;
2476 write!(self.out, "{level}return {struct_name} {{")?;
2477
2478 let mut is_first = true;
2479
2480 for (index, member) in members.iter().enumerate() {
2481 if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
2482 member.binding
2483 {
2484 has_point_size = true;
2485 if !context.pipeline_options.allow_and_force_point_size {
2486 continue;
2487 }
2488 }
2489
2490 let comma = if is_first { "" } else { "," };
2491 is_first = false;
2492 let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
2493 if let crate::TypeInner::Array {
2497 size: crate::ArraySize::Constant(size),
2498 ..
2499 } = context.module.types[member.ty].inner
2500 {
2501 write!(self.out, "{comma} {{")?;
2502 for j in 0..size.get() {
2503 if j != 0 {
2504 write!(self.out, ",")?;
2505 }
2506 write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
2507 }
2508 write!(self.out, "}}")?;
2509 } else {
2510 write!(self.out, "{comma} {tmp}.{name}")?;
2511 }
2512 }
2513 }
2514 _ => {
2515 write!(self.out, "{level}return {struct_name} {{ ")?;
2516 self.put_expression(expr_handle, context, true)?;
2517 }
2518 }
2519
2520 if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
2521 let stage = context.module.entry_points[ep_index as usize].stage;
2522 if context.pipeline_options.allow_and_force_point_size
2523 && stage == crate::ShaderStage::Vertex
2524 && !has_point_size
2525 {
2526 write!(self.out, ", 1.0")?;
2528 }
2529 }
2530 write!(self.out, " }}")?;
2531 }
2532 None => {
2533 write!(self.out, "{level}return ")?;
2534 self.put_expression(expr_handle, context, true)?;
2535 }
2536 }
2537 writeln!(self.out, ";")?;
2538 Ok(())
2539 }
2540
2541 fn update_expressions_to_bake(
2546 &mut self,
2547 func: &crate::Function,
2548 info: &valid::FunctionInfo,
2549 context: &ExpressionContext,
2550 ) {
2551 use crate::Expression;
2552 self.need_bake_expressions.clear();
2553
2554 for (expr_handle, expr) in func.expressions.iter() {
2555 let expr_info = &info[expr_handle];
2558 let min_ref_count = func.expressions[expr_handle].bake_ref_count();
2559 if min_ref_count <= expr_info.ref_count {
2560 self.need_bake_expressions.insert(expr_handle);
2561 } else {
2562 match expr_info.ty {
2563 TypeResolution::Handle(h)
2565 if Some(h) == context.module.special_types.ray_desc =>
2566 {
2567 self.need_bake_expressions.insert(expr_handle);
2568 }
2569 _ => {}
2570 }
2571 }
2572
2573 if let Expression::Math {
2574 fun,
2575 arg,
2576 arg1,
2577 arg2,
2578 ..
2579 } = *expr
2580 {
2581 match fun {
2582 crate::MathFunction::Dot => {
2583 let inner = context.resolve_type(expr_handle);
2592 if let crate::TypeInner::Scalar(scalar) = *inner {
2593 match scalar.kind {
2594 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
2595 self.need_bake_expressions.insert(arg);
2596 self.need_bake_expressions.insert(arg1.unwrap());
2597 }
2598 _ => {}
2599 }
2600 }
2601 }
2602 crate::MathFunction::FindMsb => {
2603 self.need_bake_expressions.insert(arg);
2604 }
2605 crate::MathFunction::ExtractBits => {
2606 self.need_bake_expressions.insert(arg1.unwrap());
2608 }
2609 crate::MathFunction::InsertBits => {
2610 self.need_bake_expressions.insert(arg2.unwrap());
2612 }
2613 crate::MathFunction::Sign => {
2614 let inner = context.resolve_type(expr_handle);
2619 if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
2620 self.need_bake_expressions.insert(arg);
2621 }
2622 }
2623 _ => {}
2624 }
2625 }
2626 }
2627 }
2628
2629 fn start_baking_expression(
2630 &mut self,
2631 handle: Handle<crate::Expression>,
2632 context: &ExpressionContext,
2633 name: &str,
2634 ) -> BackendResult {
2635 match context.info[handle].ty {
2636 TypeResolution::Handle(ty_handle) => {
2637 let ty_name = TypeContext {
2638 handle: ty_handle,
2639 gctx: context.module.to_ctx(),
2640 names: &self.names,
2641 access: crate::StorageAccess::empty(),
2642 binding: None,
2643 first_time: false,
2644 };
2645 write!(self.out, "{ty_name}")?;
2646 }
2647 TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
2648 put_numeric_type(&mut self.out, scalar, &[])?;
2649 }
2650 TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
2651 put_numeric_type(&mut self.out, scalar, &[size])?;
2652 }
2653 TypeResolution::Value(crate::TypeInner::Matrix {
2654 columns,
2655 rows,
2656 scalar,
2657 }) => {
2658 put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
2659 }
2660 TypeResolution::Value(ref other) => {
2661 log::warn!("Type {:?} isn't a known local", other); return Err(Error::FeatureNotImplemented("weird local type".to_string()));
2663 }
2664 }
2665
2666 write!(self.out, " {name} = ")?;
2668
2669 Ok(())
2670 }
2671
2672 fn put_cache_restricted_level(
2685 &mut self,
2686 load: Handle<crate::Expression>,
2687 image: Handle<crate::Expression>,
2688 mip_level: Option<Handle<crate::Expression>>,
2689 indent: back::Level,
2690 context: &StatementContext,
2691 ) -> BackendResult {
2692 let level_of_detail = match mip_level {
2695 Some(level) => level,
2696 None => return Ok(()),
2697 };
2698
2699 if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
2700 || !context.expression.image_needs_lod(image)
2701 {
2702 return Ok(());
2703 }
2704
2705 write!(
2706 self.out,
2707 "{}uint {}{} = ",
2708 indent,
2709 CLAMPED_LOD_LOAD_PREFIX,
2710 load.index(),
2711 )?;
2712 self.put_restricted_scalar_image_index(
2713 image,
2714 level_of_detail,
2715 "get_num_mip_levels",
2716 &context.expression,
2717 )?;
2718 writeln!(self.out, ";")?;
2719
2720 Ok(())
2721 }
2722
2723 fn put_block(
2724 &mut self,
2725 level: back::Level,
2726 statements: &[crate::Statement],
2727 context: &StatementContext,
2728 ) -> BackendResult {
2729 #[cfg(test)]
2731 #[allow(trivial_casts)]
2732 self.put_block_stack_pointers
2733 .insert(&level as *const _ as *const ());
2734
2735 for statement in statements {
2736 log::trace!("statement[{}] {:?}", level.0, statement);
2737 match *statement {
2738 crate::Statement::Emit(ref range) => {
2739 for handle in range.clone() {
2740 if let crate::Expression::ImageLoad {
2743 image,
2744 level: mip_level,
2745 ..
2746 } = context.expression.function.expressions[handle]
2747 {
2748 self.put_cache_restricted_level(
2749 handle, image, mip_level, level, context,
2750 )?;
2751 }
2752
2753 let ptr_class = context.expression.resolve_type(handle).pointer_space();
2754 let expr_name = if ptr_class.is_some() {
2755 None } else if let Some(name) =
2757 context.expression.function.named_expressions.get(&handle)
2758 {
2759 Some(self.namer.call(name))
2769 } else {
2770 let bake =
2774 if context.expression.guarded_indices.contains(handle.index()) {
2775 true
2776 } else {
2777 self.need_bake_expressions.contains(&handle)
2778 };
2779
2780 if bake {
2781 Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
2782 } else {
2783 None
2784 }
2785 };
2786
2787 if let Some(name) = expr_name {
2788 write!(self.out, "{level}")?;
2789 self.start_baking_expression(handle, &context.expression, &name)?;
2790 self.put_expression(handle, &context.expression, true)?;
2791 self.named_expressions.insert(handle, name);
2792 writeln!(self.out, ";")?;
2793 }
2794 }
2795 }
2796 crate::Statement::Block(ref block) => {
2797 if !block.is_empty() {
2798 writeln!(self.out, "{level}{{")?;
2799 self.put_block(level.next(), block, context)?;
2800 writeln!(self.out, "{level}}}")?;
2801 }
2802 }
2803 crate::Statement::If {
2804 condition,
2805 ref accept,
2806 ref reject,
2807 } => {
2808 write!(self.out, "{level}if (")?;
2809 self.put_expression(condition, &context.expression, true)?;
2810 writeln!(self.out, ") {{")?;
2811 self.put_block(level.next(), accept, context)?;
2812 if !reject.is_empty() {
2813 writeln!(self.out, "{level}}} else {{")?;
2814 self.put_block(level.next(), reject, context)?;
2815 }
2816 writeln!(self.out, "{level}}}")?;
2817 }
2818 crate::Statement::Switch {
2819 selector,
2820 ref cases,
2821 } => {
2822 write!(self.out, "{level}switch(")?;
2823 self.put_expression(selector, &context.expression, true)?;
2824 writeln!(self.out, ") {{")?;
2825 let lcase = level.next();
2826 for case in cases.iter() {
2827 match case.value {
2828 crate::SwitchValue::I32(value) => {
2829 write!(self.out, "{lcase}case {value}:")?;
2830 }
2831 crate::SwitchValue::U32(value) => {
2832 write!(self.out, "{lcase}case {value}u:")?;
2833 }
2834 crate::SwitchValue::Default => {
2835 write!(self.out, "{lcase}default:")?;
2836 }
2837 }
2838
2839 let write_block_braces = !(case.fall_through && case.body.is_empty());
2840 if write_block_braces {
2841 writeln!(self.out, " {{")?;
2842 } else {
2843 writeln!(self.out)?;
2844 }
2845
2846 self.put_block(lcase.next(), &case.body, context)?;
2847 if !case.fall_through
2848 && case.body.last().map_or(true, |s| !s.is_terminator())
2849 {
2850 writeln!(self.out, "{}break;", lcase.next())?;
2851 }
2852
2853 if write_block_braces {
2854 writeln!(self.out, "{lcase}}}")?;
2855 }
2856 }
2857 writeln!(self.out, "{level}}}")?;
2858 }
2859 crate::Statement::Loop {
2860 ref body,
2861 ref continuing,
2862 break_if,
2863 } => {
2864 if !continuing.is_empty() || break_if.is_some() {
2865 let gate_name = self.namer.call("loop_init");
2866 writeln!(self.out, "{level}bool {gate_name} = true;")?;
2867 writeln!(self.out, "{level}while(true) {{")?;
2868 let lif = level.next();
2869 let lcontinuing = lif.next();
2870 writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
2871 self.put_block(lcontinuing, continuing, context)?;
2872 if let Some(condition) = break_if {
2873 write!(self.out, "{lcontinuing}if (")?;
2874 self.put_expression(condition, &context.expression, true)?;
2875 writeln!(self.out, ") {{")?;
2876 writeln!(self.out, "{}break;", lcontinuing.next())?;
2877 writeln!(self.out, "{lcontinuing}}}")?;
2878 }
2879 writeln!(self.out, "{lif}}}")?;
2880 writeln!(self.out, "{lif}{gate_name} = false;")?;
2881 } else {
2882 writeln!(self.out, "{level}while(true) {{")?;
2883 }
2884 self.put_block(level.next(), body, context)?;
2885 writeln!(self.out, "{level}}}")?;
2886 }
2887 crate::Statement::Break => {
2888 writeln!(self.out, "{level}break;")?;
2889 }
2890 crate::Statement::Continue => {
2891 writeln!(self.out, "{level}continue;")?;
2892 }
2893 crate::Statement::Return {
2894 value: Some(expr_handle),
2895 } => {
2896 self.put_return_value(
2897 level,
2898 expr_handle,
2899 context.result_struct,
2900 &context.expression,
2901 )?;
2902 }
2903 crate::Statement::Return { value: None } => {
2904 writeln!(self.out, "{level}return;")?;
2905 }
2906 crate::Statement::Kill => {
2907 writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
2908 }
2909 crate::Statement::Barrier(flags) => {
2910 self.write_barrier(flags, level)?;
2911 }
2912 crate::Statement::Store { pointer, value } => {
2913 self.put_store(pointer, value, level, context)?
2914 }
2915 crate::Statement::ImageStore {
2916 image,
2917 coordinate,
2918 array_index,
2919 value,
2920 } => {
2921 let address = TexelAddress {
2922 coordinate,
2923 array_index,
2924 sample: None,
2925 level: None,
2926 };
2927 self.put_image_store(level, image, &address, value, context)?
2928 }
2929 crate::Statement::Call {
2930 function,
2931 ref arguments,
2932 result,
2933 } => {
2934 write!(self.out, "{level}")?;
2935 if let Some(expr) = result {
2936 let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
2937 self.start_baking_expression(expr, &context.expression, &name)?;
2938 self.named_expressions.insert(expr, name);
2939 }
2940 let fun_name = &self.names[&NameKey::Function(function)];
2941 write!(self.out, "{fun_name}(")?;
2942 for (i, &handle) in arguments.iter().enumerate() {
2944 if i != 0 {
2945 write!(self.out, ", ")?;
2946 }
2947 self.put_expression(handle, &context.expression, true)?;
2948 }
2949 let mut separate = !arguments.is_empty();
2951 let fun_info = &context.expression.mod_info[function];
2952 let mut supports_array_length = false;
2953 for (handle, var) in context.expression.module.global_variables.iter() {
2954 if fun_info[handle].is_empty() {
2955 continue;
2956 }
2957 if var.space.needs_pass_through() {
2958 let name = &self.names[&NameKey::GlobalVariable(handle)];
2959 if separate {
2960 write!(self.out, ", ")?;
2961 } else {
2962 separate = true;
2963 }
2964 write!(self.out, "{name}")?;
2965 }
2966 supports_array_length |=
2967 needs_array_length(var.ty, &context.expression.module.types);
2968 }
2969 if supports_array_length {
2970 if separate {
2971 write!(self.out, ", ")?;
2972 }
2973 write!(self.out, "_buffer_sizes")?;
2974 }
2975
2976 writeln!(self.out, ");")?;
2978 }
2979 crate::Statement::Atomic {
2980 pointer,
2981 ref fun,
2982 value,
2983 result,
2984 } => {
2985 write!(self.out, "{level}")?;
2986 let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
2987 self.start_baking_expression(result, &context.expression, &res_name)?;
2988 self.named_expressions.insert(result, res_name);
2989 let fun_str = fun.to_msl()?;
2990 self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
2991 writeln!(self.out, ";")?;
2993 }
2994 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
2995 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
2996
2997 write!(self.out, "{level}")?;
2998 let name = self.namer.call("");
2999 self.start_baking_expression(result, &context.expression, &name)?;
3000 self.put_load(pointer, &context.expression, true)?;
3001 self.named_expressions.insert(result, name);
3002
3003 writeln!(self.out, ";")?;
3004 self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
3005 }
3006 crate::Statement::RayQuery { query, ref fun } => {
3007 if context.expression.lang_version < (2, 4) {
3008 return Err(Error::UnsupportedRayTracing);
3009 }
3010
3011 match *fun {
3012 crate::RayQueryFunction::Initialize {
3013 acceleration_structure,
3014 descriptor,
3015 } => {
3016 write!(self.out, "{level}")?;
3018 self.put_expression(query, &context.expression, true)?;
3019 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
3020 {
3021 let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
3022 let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
3023 write!(self.out, "{level}")?;
3024 self.put_expression(query, &context.expression, true)?;
3025 write!(
3026 self.out,
3027 ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
3028 )?;
3029 self.put_expression(descriptor, &context.expression, true)?;
3030 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
3031 self.put_expression(descriptor, &context.expression, true)?;
3032 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
3033 writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
3034 }
3035 {
3036 let f_opaque = back::RayFlag::OPAQUE.bits();
3037 let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
3038 write!(self.out, "{level}")?;
3039 self.put_expression(query, &context.expression, true)?;
3040 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
3041 self.put_expression(descriptor, &context.expression, true)?;
3042 write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
3043 self.put_expression(descriptor, &context.expression, true)?;
3044 write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
3045 writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
3046 }
3047 {
3048 let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
3049 write!(self.out, "{level}")?;
3050 self.put_expression(query, &context.expression, true)?;
3051 write!(
3052 self.out,
3053 ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
3054 )?;
3055 self.put_expression(descriptor, &context.expression, true)?;
3056 writeln!(self.out, ".flags & {flag}) != 0);")?;
3057 }
3058
3059 write!(self.out, "{level}")?;
3060 self.put_expression(query, &context.expression, true)?;
3061 write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
3062 self.put_expression(query, &context.expression, true)?;
3063 write!(
3064 self.out,
3065 ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
3066 )?;
3067 self.put_expression(descriptor, &context.expression, true)?;
3068 write!(self.out, ".origin, ")?;
3069 self.put_expression(descriptor, &context.expression, true)?;
3070 write!(self.out, ".dir, ")?;
3071 self.put_expression(descriptor, &context.expression, true)?;
3072 write!(self.out, ".tmin, ")?;
3073 self.put_expression(descriptor, &context.expression, true)?;
3074 write!(self.out, ".tmax), ")?;
3075 self.put_expression(acceleration_structure, &context.expression, true)?;
3076 write!(self.out, ", ")?;
3077 self.put_expression(descriptor, &context.expression, true)?;
3078 write!(self.out, ".cull_mask);")?;
3079
3080 write!(self.out, "{level}")?;
3081 self.put_expression(query, &context.expression, true)?;
3082 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
3083 }
3084 crate::RayQueryFunction::Proceed { result } => {
3085 write!(self.out, "{level}")?;
3086 let name = format!("{}{}", back::BAKE_PREFIX, result.index());
3087 self.start_baking_expression(result, &context.expression, &name)?;
3088 self.named_expressions.insert(result, name);
3089 self.put_expression(query, &context.expression, true)?;
3090 writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
3091 write!(self.out, "{level}")?;
3094 self.put_expression(query, &context.expression, true)?;
3095 writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
3096 }
3097 crate::RayQueryFunction::Terminate => {
3098 write!(self.out, "{level}")?;
3099 self.put_expression(query, &context.expression, true)?;
3100 writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?;
3101 }
3102 }
3103 }
3104 crate::Statement::SubgroupBallot { result, predicate } => {
3105 write!(self.out, "{level}")?;
3106 let name = self.namer.call("");
3107 self.start_baking_expression(result, &context.expression, &name)?;
3108 self.named_expressions.insert(result, name);
3109 write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?;
3110 if let Some(predicate) = predicate {
3111 self.put_expression(predicate, &context.expression, true)?;
3112 } else {
3113 write!(self.out, "true")?;
3114 }
3115 writeln!(self.out, "), 0, 0, 0);")?;
3116 }
3117 crate::Statement::SubgroupCollectiveOperation {
3118 op,
3119 collective_op,
3120 argument,
3121 result,
3122 } => {
3123 write!(self.out, "{level}")?;
3124 let name = self.namer.call("");
3125 self.start_baking_expression(result, &context.expression, &name)?;
3126 self.named_expressions.insert(result, name);
3127 match (collective_op, op) {
3128 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
3129 write!(self.out, "{NAMESPACE}::simd_all(")?
3130 }
3131 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
3132 write!(self.out, "{NAMESPACE}::simd_any(")?
3133 }
3134 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
3135 write!(self.out, "{NAMESPACE}::simd_sum(")?
3136 }
3137 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
3138 write!(self.out, "{NAMESPACE}::simd_product(")?
3139 }
3140 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
3141 write!(self.out, "{NAMESPACE}::simd_max(")?
3142 }
3143 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
3144 write!(self.out, "{NAMESPACE}::simd_min(")?
3145 }
3146 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
3147 write!(self.out, "{NAMESPACE}::simd_and(")?
3148 }
3149 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
3150 write!(self.out, "{NAMESPACE}::simd_or(")?
3151 }
3152 (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
3153 write!(self.out, "{NAMESPACE}::simd_xor(")?
3154 }
3155 (
3156 crate::CollectiveOperation::ExclusiveScan,
3157 crate::SubgroupOperation::Add,
3158 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
3159 (
3160 crate::CollectiveOperation::ExclusiveScan,
3161 crate::SubgroupOperation::Mul,
3162 ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
3163 (
3164 crate::CollectiveOperation::InclusiveScan,
3165 crate::SubgroupOperation::Add,
3166 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
3167 (
3168 crate::CollectiveOperation::InclusiveScan,
3169 crate::SubgroupOperation::Mul,
3170 ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
3171 _ => unimplemented!(),
3172 }
3173 self.put_expression(argument, &context.expression, true)?;
3174 writeln!(self.out, ");")?;
3175 }
3176 crate::Statement::SubgroupGather {
3177 mode,
3178 argument,
3179 result,
3180 } => {
3181 write!(self.out, "{level}")?;
3182 let name = self.namer.call("");
3183 self.start_baking_expression(result, &context.expression, &name)?;
3184 self.named_expressions.insert(result, name);
3185 match mode {
3186 crate::GatherMode::BroadcastFirst => {
3187 write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
3188 }
3189 crate::GatherMode::Broadcast(_) => {
3190 write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
3191 }
3192 crate::GatherMode::Shuffle(_) => {
3193 write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
3194 }
3195 crate::GatherMode::ShuffleDown(_) => {
3196 write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
3197 }
3198 crate::GatherMode::ShuffleUp(_) => {
3199 write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
3200 }
3201 crate::GatherMode::ShuffleXor(_) => {
3202 write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
3203 }
3204 }
3205 self.put_expression(argument, &context.expression, true)?;
3206 match mode {
3207 crate::GatherMode::BroadcastFirst => {}
3208 crate::GatherMode::Broadcast(index)
3209 | crate::GatherMode::Shuffle(index)
3210 | crate::GatherMode::ShuffleDown(index)
3211 | crate::GatherMode::ShuffleUp(index)
3212 | crate::GatherMode::ShuffleXor(index) => {
3213 write!(self.out, ", ")?;
3214 self.put_expression(index, &context.expression, true)?;
3215 }
3216 }
3217 writeln!(self.out, ");")?;
3218 }
3219 }
3220 }
3221
3222 for statement in statements {
3225 if let crate::Statement::Emit(ref range) = *statement {
3226 for handle in range.clone() {
3227 self.named_expressions.shift_remove(&handle);
3228 }
3229 }
3230 }
3231 Ok(())
3232 }
3233
3234 fn put_store(
3235 &mut self,
3236 pointer: Handle<crate::Expression>,
3237 value: Handle<crate::Expression>,
3238 level: back::Level,
3239 context: &StatementContext,
3240 ) -> BackendResult {
3241 let policy = context.expression.choose_bounds_check_policy(pointer);
3242 if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
3243 && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
3244 {
3245 writeln!(self.out, ") {{")?;
3246 self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
3247 writeln!(self.out, "{level}}}")?;
3248 } else {
3249 self.put_unchecked_store(pointer, value, policy, level, context)?;
3250 }
3251
3252 Ok(())
3253 }
3254
3255 fn put_unchecked_store(
3256 &mut self,
3257 pointer: Handle<crate::Expression>,
3258 value: Handle<crate::Expression>,
3259 policy: index::BoundsCheckPolicy,
3260 level: back::Level,
3261 context: &StatementContext,
3262 ) -> BackendResult {
3263 let is_atomic_pointer = context
3264 .expression
3265 .resolve_type(pointer)
3266 .is_atomic_pointer(&context.expression.module.types);
3267
3268 if is_atomic_pointer {
3269 write!(
3270 self.out,
3271 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
3272 )?;
3273 self.put_access_chain(pointer, policy, &context.expression)?;
3274 write!(self.out, ", ")?;
3275 self.put_expression(value, &context.expression, true)?;
3276 writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
3277 } else {
3278 write!(self.out, "{level}")?;
3279 self.put_access_chain(pointer, policy, &context.expression)?;
3280 write!(self.out, " = ")?;
3281 self.put_expression(value, &context.expression, true)?;
3282 writeln!(self.out, ";")?;
3283 }
3284
3285 Ok(())
3286 }
3287
3288 pub fn write(
3289 &mut self,
3290 module: &crate::Module,
3291 info: &valid::ModuleInfo,
3292 options: &Options,
3293 pipeline_options: &PipelineOptions,
3294 ) -> Result<TranslationInfo, Error> {
3295 if !module.overrides.is_empty() {
3296 return Err(Error::Override);
3297 }
3298
3299 self.names.clear();
3300 self.namer.reset(
3301 module,
3302 super::keywords::RESERVED,
3303 &[],
3304 &[],
3305 &[CLAMPED_LOD_LOAD_PREFIX],
3306 &mut self.names,
3307 );
3308 self.struct_member_pads.clear();
3309
3310 writeln!(
3311 self.out,
3312 "// language: metal{}.{}",
3313 options.lang_version.0, options.lang_version.1
3314 )?;
3315 writeln!(self.out, "#include <metal_stdlib>")?;
3316 writeln!(self.out, "#include <simd/simd.h>")?;
3317 writeln!(self.out)?;
3318 writeln!(self.out, "using {NAMESPACE}::uint;")?;
3320
3321 let mut uses_ray_query = false;
3322 for (_, ty) in module.types.iter() {
3323 match ty.inner {
3324 crate::TypeInner::AccelerationStructure => {
3325 if options.lang_version < (2, 4) {
3326 return Err(Error::UnsupportedRayTracing);
3327 }
3328 }
3329 crate::TypeInner::RayQuery => {
3330 if options.lang_version < (2, 4) {
3331 return Err(Error::UnsupportedRayTracing);
3332 }
3333 uses_ray_query = true;
3334 }
3335 _ => (),
3336 }
3337 }
3338
3339 if module.special_types.ray_desc.is_some()
3340 || module.special_types.ray_intersection.is_some()
3341 {
3342 if options.lang_version < (2, 4) {
3343 return Err(Error::UnsupportedRayTracing);
3344 }
3345 }
3346
3347 if uses_ray_query {
3348 self.put_ray_query_type()?;
3349 }
3350
3351 if options
3352 .bounds_check_policies
3353 .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
3354 {
3355 self.put_default_constructible()?;
3356 }
3357 writeln!(self.out)?;
3358
3359 {
3360 let mut indices = vec![];
3361 for (handle, var) in module.global_variables.iter() {
3362 if needs_array_length(var.ty, &module.types) {
3363 let idx = handle.index();
3364 indices.push(idx);
3365 }
3366 }
3367
3368 if !indices.is_empty() {
3369 writeln!(self.out, "struct _mslBufferSizes {{")?;
3370
3371 for idx in indices {
3372 writeln!(self.out, "{}uint size{};", back::INDENT, idx)?;
3373 }
3374
3375 writeln!(self.out, "}};")?;
3376 writeln!(self.out)?;
3377 }
3378 };
3379
3380 self.write_type_defs(module)?;
3381 self.write_global_constants(module, info)?;
3382 self.write_functions(module, info, options, pipeline_options)
3383 }
3384
3385 fn put_default_constructible(&mut self) -> BackendResult {
3398 let tab = back::INDENT;
3399 writeln!(self.out, "struct DefaultConstructible {{")?;
3400 writeln!(self.out, "{tab}template<typename T>")?;
3401 writeln!(self.out, "{tab}operator T() && {{")?;
3402 writeln!(self.out, "{tab}{tab}return T {{}};")?;
3403 writeln!(self.out, "{tab}}}")?;
3404 writeln!(self.out, "}};")?;
3405 Ok(())
3406 }
3407
3408 fn put_ray_query_type(&mut self) -> BackendResult {
3409 let tab = back::INDENT;
3410 writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
3411 let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
3412 writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
3413 writeln!(
3414 self.out,
3415 "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
3416 )?;
3417 writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
3418 writeln!(self.out, "}};")?;
3419 writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
3420 let v_triangle = back::RayIntersectionType::Triangle as u32;
3421 let v_bbox = back::RayIntersectionType::BoundingBox as u32;
3422 writeln!(
3423 self.out,
3424 "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
3425 )?;
3426 writeln!(
3427 self.out,
3428 "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
3429 )?;
3430 writeln!(self.out, "}}")?;
3431 Ok(())
3432 }
3433
    /// Writes MSL definitions for every module type that needs an alias,
    /// followed by helper functions for predeclared result types.
    ///
    /// - Constant-size arrays become a struct with a single
    ///   `WRAPPED_ARRAY_FIELD` member: `struct T { base inner[N]; };`.
    /// - Dynamically sized arrays become `typedef base T[1];`.
    /// - Structs are written member by member. Explicit `char _padN[...]`
    ///   padding is inserted whenever a member's offset lies beyond the end
    ///   of the previous member; such members are recorded in
    ///   `self.struct_member_pads`.
    /// - All other aliased types become `typedef <type> T;`.
    /// - `ModfResult` / `FrexpResult` predeclared types get a wrapper
    ///   function (`MODF_FUNCTION` / `FREXP_FUNCTION`). It calls `metal::modf`
    ///   or `metal::frexp` and returns the result struct.
    fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
        for (handle, ty) in module.types.iter() {
            if !ty.needs_alias() {
                continue;
            }
            let name = &self.names[&NameKey::Type(handle)];
            match ty.inner {
                crate::TypeInner::Array {
                    base,
                    size,
                    stride: _,
                } => {
                    let base_name = TypeContext {
                        handle: base,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        binding: None,
                        first_time: false,
                    };

                    match size {
                        // Fixed-size arrays are wrapped in a struct; element access
                        // elsewhere has to go through `WRAPPED_ARRAY_FIELD`.
                        crate::ArraySize::Constant(size) => {
                            writeln!(self.out, "struct {name} {{")?;
                            writeln!(
                                self.out,
                                "{}{} {}[{}];",
                                back::INDENT,
                                base_name,
                                WRAPPED_ARRAY_FIELD,
                                size
                            )?;
                            writeln!(self.out, "}};")?;
                        }
                        // Runtime-sized arrays are declared with a placeholder length of 1.
                        crate::ArraySize::Dynamic => {
                            writeln!(self.out, "typedef {base_name} {name}[1];")?;
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members, span, ..
                } => {
                    writeln!(self.out, "struct {name} {{")?;
                    // Byte offset just past the previously written member, as
                    // tracked by this loop.
                    let mut last_offset = 0;
                    for (index, member) in members.iter().enumerate() {
                        // Fill any gap before this member with explicit padding and
                        // remember that this member is preceded by a pad.
                        if member.offset > last_offset {
                            self.struct_member_pads.insert((handle, index as u32));
                            let pad = member.offset - last_offset;
                            writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
                        }
                        let ty_inner = &module.types[member.ty].inner;
                        last_offset = member.offset + ty_inner.size(module.to_ctx());

                        let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];

                        // `should_pack_struct_member` decides whether this member is
                        // emitted as `metal::packed_<scalar>3`.
                        match should_pack_struct_member(members, span, index, module) {
                            Some(scalar) => {
                                writeln!(
                                    self.out,
                                    "{}{}::packed_{}3 {};",
                                    back::INDENT,
                                    NAMESPACE,
                                    scalar.to_msl_name(),
                                    member_name
                                )?;
                            }
                            None => {
                                let base_name = TypeContext {
                                    handle: member.ty,
                                    gctx: module.to_ctx(),
                                    names: &self.names,
                                    access: crate::StorageAccess::empty(),
                                    binding: None,
                                    first_time: false,
                                };
                                writeln!(
                                    self.out,
                                    "{}{} {};",
                                    back::INDENT,
                                    base_name,
                                    member_name
                                )?;

                                // An unpacked 3-component vector is counted as one
                                // scalar wider than its IR size. Presumably this is
                                // because MSL's non-packed vec3 occupies the storage
                                // of a vec4.
                                if let crate::TypeInner::Vector {
                                    size: crate::VectorSize::Tri,
                                    scalar,
                                } = *ty_inner
                                {
                                    last_offset += scalar.width as u32;
                                }
                            }
                        }
                    }
                    writeln!(self.out, "}};")?;
                }
                _ => {
                    let ty_name = TypeContext {
                        handle,
                        gctx: module.to_ctx(),
                        names: &self.names,
                        access: crate::StorageAccess::empty(),
                        binding: None,
                        first_time: true,
                    };
                    writeln!(self.out, "typedef {ty_name} {name};")?;
                }
            }
        }

        // Helper functions that build the predeclared modf/frexp result structs.
        for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
            match type_key {
                &crate::PredeclaredType::ModfResult { size, width }
                | &crate::PredeclaredType::FrexpResult { size, width } => {
                    // Argument type: `double`/`float` (width 8 selects double),
                    // vectorized as `metal::<scalar><N>` when `size` is present.
                    let arg_type_name_owner;
                    let arg_type_name = if let Some(size) = size {
                        arg_type_name_owner = format!(
                            "{NAMESPACE}::{}{}",
                            if width == 8 { "double" } else { "float" },
                            size as u8
                        );
                        &arg_type_name_owner
                    } else if width == 8 {
                        "double"
                    } else {
                        "float"
                    };

                    // modf's second output has the argument type; frexp's is `int`/`intN`.
                    let other_type_name_owner;
                    let (defined_func_name, called_func_name, other_type_name) =
                        if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
                            (MODF_FUNCTION, "modf", arg_type_name)
                        } else {
                            let other_type_name = if let Some(size) = size {
                                other_type_name_owner = format!("int{}", size as u8);
                                &other_type_name_owner
                            } else {
                                "int"
                            };
                            (FREXP_FUNCTION, "frexp", other_type_name)
                        };

                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];

                    writeln!(self.out)?;
                    writeln!(
                        self.out,
                        "{} {defined_func_name}({arg_type_name} arg) {{
    {other_type_name} other;
    {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
    return {}{{ fract, other }};
}}",
                        struct_name, struct_name
                    )?;
                }
                // No helper function is written for this predeclared type here.
                &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
            }
        }

        Ok(())
    }
3610
3611 fn write_global_constants(
3613 &mut self,
3614 module: &crate::Module,
3615 mod_info: &valid::ModuleInfo,
3616 ) -> BackendResult {
3617 let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
3618
3619 for (handle, constant) in constants {
3620 let ty_name = TypeContext {
3621 handle: constant.ty,
3622 gctx: module.to_ctx(),
3623 names: &self.names,
3624 access: crate::StorageAccess::empty(),
3625 binding: None,
3626 first_time: false,
3627 };
3628 let name = &self.names[&NameKey::Constant(handle)];
3629 write!(self.out, "constant {ty_name} {name} = ")?;
3630 self.put_const_expression(constant.init, module, mod_info)?;
3631 writeln!(self.out, ";")?;
3632 }
3633
3634 Ok(())
3635 }
3636
3637 fn put_inline_sampler_properties(
3638 &mut self,
3639 level: back::Level,
3640 sampler: &sm::InlineSampler,
3641 ) -> BackendResult {
3642 for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
3643 writeln!(
3644 self.out,
3645 "{}{}::{}_address::{},",
3646 level,
3647 NAMESPACE,
3648 letter,
3649 address.as_str(),
3650 )?;
3651 }
3652 writeln!(
3653 self.out,
3654 "{}{}::mag_filter::{},",
3655 level,
3656 NAMESPACE,
3657 sampler.mag_filter.as_str(),
3658 )?;
3659 writeln!(
3660 self.out,
3661 "{}{}::min_filter::{},",
3662 level,
3663 NAMESPACE,
3664 sampler.min_filter.as_str(),
3665 )?;
3666 if let Some(filter) = sampler.mip_filter {
3667 writeln!(
3668 self.out,
3669 "{}{}::mip_filter::{},",
3670 level,
3671 NAMESPACE,
3672 filter.as_str(),
3673 )?;
3674 }
3675 if sampler.border_color != sm::BorderColor::TransparentBlack {
3677 writeln!(
3678 self.out,
3679 "{}{}::border_color::{},",
3680 level,
3681 NAMESPACE,
3682 sampler.border_color.as_str(),
3683 )?;
3684 }
3685 if false {
3689 if let Some(ref lod) = sampler.lod_clamp {
3690 writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
3691 }
3692 if let Some(aniso) = sampler.max_anisotropy {
3693 writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
3694 }
3695 }
3696 if sampler.compare_func != sm::CompareFunc::Never {
3697 writeln!(
3698 self.out,
3699 "{}{}::compare_func::{},",
3700 level,
3701 NAMESPACE,
3702 sampler.compare_func.as_str(),
3703 )?;
3704 }
3705 writeln!(
3706 self.out,
3707 "{}{}::coord::{}",
3708 level,
3709 NAMESPACE,
3710 sampler.coord.as_str()
3711 )?;
3712 Ok(())
3713 }
3714
3715 fn write_functions(
3717 &mut self,
3718 module: &crate::Module,
3719 mod_info: &valid::ModuleInfo,
3720 options: &Options,
3721 pipeline_options: &PipelineOptions,
3722 ) -> Result<TranslationInfo, Error> {
3723 let mut pass_through_globals = Vec::new();
3724 for (fun_handle, fun) in module.functions.iter() {
3725 log::trace!(
3726 "function {:?}, handle {:?}",
3727 fun.name.as_deref().unwrap_or("(anonymous)"),
3728 fun_handle
3729 );
3730
3731 let fun_info = &mod_info[fun_handle];
3732 pass_through_globals.clear();
3733 let mut supports_array_length = false;
3734 for (handle, var) in module.global_variables.iter() {
3735 if !fun_info[handle].is_empty() {
3736 if var.space.needs_pass_through() {
3737 pass_through_globals.push(handle);
3738 }
3739 supports_array_length |= needs_array_length(var.ty, &module.types);
3740 }
3741 }
3742
3743 writeln!(self.out)?;
3744 let fun_name = &self.names[&NameKey::Function(fun_handle)];
3745 match fun.result {
3746 Some(ref result) => {
3747 let ty_name = TypeContext {
3748 handle: result.ty,
3749 gctx: module.to_ctx(),
3750 names: &self.names,
3751 access: crate::StorageAccess::empty(),
3752 binding: None,
3753 first_time: false,
3754 };
3755 write!(self.out, "{ty_name}")?;
3756 }
3757 None => {
3758 write!(self.out, "void")?;
3759 }
3760 }
3761 writeln!(self.out, " {fun_name}(")?;
3762
3763 for (index, arg) in fun.arguments.iter().enumerate() {
3764 let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
3765 let param_type_name = TypeContext {
3766 handle: arg.ty,
3767 gctx: module.to_ctx(),
3768 names: &self.names,
3769 access: crate::StorageAccess::empty(),
3770 binding: None,
3771 first_time: false,
3772 };
3773 let separator = separate(
3774 !pass_through_globals.is_empty()
3775 || index + 1 != fun.arguments.len()
3776 || supports_array_length,
3777 );
3778 writeln!(
3779 self.out,
3780 "{}{} {}{}",
3781 back::INDENT,
3782 param_type_name,
3783 name,
3784 separator
3785 )?;
3786 }
3787 for (index, &handle) in pass_through_globals.iter().enumerate() {
3788 let tyvar = TypedGlobalVariable {
3789 module,
3790 names: &self.names,
3791 handle,
3792 usage: fun_info[handle],
3793 binding: None,
3794 reference: true,
3795 };
3796 let separator =
3797 separate(index + 1 != pass_through_globals.len() || supports_array_length);
3798 write!(self.out, "{}", back::INDENT)?;
3799 tyvar.try_fmt(&mut self.out)?;
3800 writeln!(self.out, "{separator}")?;
3801 }
3802
3803 if supports_array_length {
3804 writeln!(
3805 self.out,
3806 "{}constant _mslBufferSizes& _buffer_sizes",
3807 back::INDENT
3808 )?;
3809 }
3810
3811 writeln!(self.out, ") {{")?;
3812
3813 let guarded_indices =
3814 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
3815
3816 let context = StatementContext {
3817 expression: ExpressionContext {
3818 function: fun,
3819 origin: FunctionOrigin::Handle(fun_handle),
3820 info: fun_info,
3821 lang_version: options.lang_version,
3822 policies: options.bounds_check_policies,
3823 guarded_indices,
3824 module,
3825 mod_info,
3826 pipeline_options,
3827 },
3828 result_struct: None,
3829 };
3830
3831 for (local_handle, local) in fun.local_variables.iter() {
3832 let ty_name = TypeContext {
3833 handle: local.ty,
3834 gctx: module.to_ctx(),
3835 names: &self.names,
3836 access: crate::StorageAccess::empty(),
3837 binding: None,
3838 first_time: false,
3839 };
3840 let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
3841 write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?;
3842 match local.init {
3843 Some(value) => {
3844 write!(self.out, " = ")?;
3845 self.put_expression(value, &context.expression, true)?;
3846 }
3847 None => {
3848 write!(self.out, " = {{}}")?;
3849 }
3850 };
3851 writeln!(self.out, ";")?;
3852 }
3853
3854 self.update_expressions_to_bake(fun, fun_info, &context.expression);
3855 self.put_block(back::Level(1), &fun.body, &context)?;
3856 writeln!(self.out, "}}")?;
3857 self.named_expressions.clear();
3858 }
3859
3860 let mut info = TranslationInfo {
3861 entry_point_names: Vec::with_capacity(module.entry_points.len()),
3862 };
3863 for (ep_index, ep) in module.entry_points.iter().enumerate() {
3864 let fun = &ep.function;
3865 let fun_info = mod_info.get_entry_point(ep_index);
3866 let mut ep_error = None;
3867
3868 log::trace!(
3869 "entry point {:?}, index {:?}",
3870 fun.name.as_deref().unwrap_or("(anonymous)"),
3871 ep_index
3872 );
3873
3874 let supports_array_length = module
3876 .global_variables
3877 .iter()
3878 .filter(|&(handle, _)| !fun_info[handle].is_empty())
3879 .any(|(_, var)| needs_array_length(var.ty, &module.types));
3880
3881 if !options.fake_missing_bindings {
3884 for (var_handle, var) in module.global_variables.iter() {
3885 if fun_info[var_handle].is_empty() {
3886 continue;
3887 }
3888 match var.space {
3889 crate::AddressSpace::Uniform
3890 | crate::AddressSpace::Storage { .. }
3891 | crate::AddressSpace::Handle => {
3892 let br = match var.binding {
3893 Some(ref br) => br,
3894 None => {
3895 let var_name = var.name.clone().unwrap_or_default();
3896 ep_error =
3897 Some(super::EntryPointError::MissingBinding(var_name));
3898 break;
3899 }
3900 };
3901 let target = options.get_resource_binding_target(ep, br);
3902 let good = match target {
3903 Some(target) => {
3904 let binding_ty = match module.types[var.ty].inner {
3905 crate::TypeInner::BindingArray { base, .. } => {
3906 &module.types[base].inner
3907 }
3908 ref ty => ty,
3909 };
3910 match *binding_ty {
3911 crate::TypeInner::Image { .. } => target.texture.is_some(),
3912 crate::TypeInner::Sampler { .. } => {
3913 target.sampler.is_some()
3914 }
3915 _ => target.buffer.is_some(),
3916 }
3917 }
3918 None => false,
3919 };
3920 if !good {
3921 ep_error =
3922 Some(super::EntryPointError::MissingBindTarget(br.clone()));
3923 break;
3924 }
3925 }
3926 crate::AddressSpace::PushConstant => {
3927 if let Err(e) = options.resolve_push_constants(ep) {
3928 ep_error = Some(e);
3929 break;
3930 }
3931 }
3932 crate::AddressSpace::Function
3933 | crate::AddressSpace::Private
3934 | crate::AddressSpace::WorkGroup => {}
3935 }
3936 }
3937 if supports_array_length {
3938 if let Err(err) = options.resolve_sizes_buffer(ep) {
3939 ep_error = Some(err);
3940 }
3941 }
3942 }
3943
3944 if let Some(err) = ep_error {
3945 info.entry_point_names.push(Err(err));
3946 continue;
3947 }
3948 let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
3949 info.entry_point_names.push(Ok(fun_name.clone()));
3950
3951 writeln!(self.out)?;
3952
3953 let (em_str, in_mode, out_mode) = match ep.stage {
3954 crate::ShaderStage::Vertex => (
3955 "vertex",
3956 LocationMode::VertexInput,
3957 LocationMode::VertexOutput,
3958 ),
3959 crate::ShaderStage::Fragment { .. } => (
3960 "fragment",
3961 LocationMode::FragmentInput,
3962 LocationMode::FragmentOutput,
3963 ),
3964 crate::ShaderStage::Compute { .. } => {
3965 ("kernel", LocationMode::Uniform, LocationMode::Uniform)
3966 }
3967 };
3968
3969 let mut flattened_member_names = FastHashMap::default();
3975 let mut varyings_namer = crate::proc::Namer::default();
3977
3978 let mut flattened_arguments = Vec::new();
3983 for (arg_index, arg) in fun.arguments.iter().enumerate() {
3984 match module.types[arg.ty].inner {
3985 crate::TypeInner::Struct { ref members, .. } => {
3986 for (member_index, member) in members.iter().enumerate() {
3987 let member_index = member_index as u32;
3988 flattened_arguments.push((
3989 NameKey::StructMember(arg.ty, member_index),
3990 member.ty,
3991 member.binding.as_ref(),
3992 ));
3993 let name_key = NameKey::StructMember(arg.ty, member_index);
3994 let name = match member.binding {
3995 Some(crate::Binding::Location { .. }) => {
3996 varyings_namer.call(&self.names[&name_key])
3997 }
3998 _ => self.namer.call(&self.names[&name_key]),
3999 };
4000 flattened_member_names.insert(name_key, name);
4001 }
4002 }
4003 _ => flattened_arguments.push((
4004 NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
4005 arg.ty,
4006 arg.binding.as_ref(),
4007 )),
4008 }
4009 }
4010
4011 let stage_in_name = format!("{fun_name}Input");
4014 let varyings_member_name = self.namer.call("varyings");
4015 let mut has_varyings = false;
4016 if !flattened_arguments.is_empty() {
4017 writeln!(self.out, "struct {stage_in_name} {{")?;
4018 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
4019 let binding = match binding {
4020 Some(ref binding @ &crate::Binding::Location { .. }) => binding,
4021 _ => continue,
4022 };
4023 has_varyings = true;
4024 let name = match *name_key {
4025 NameKey::StructMember(..) => &flattened_member_names[name_key],
4026 _ => &self.names[name_key],
4027 };
4028 let ty_name = TypeContext {
4029 handle: ty,
4030 gctx: module.to_ctx(),
4031 names: &self.names,
4032 access: crate::StorageAccess::empty(),
4033 binding: None,
4034 first_time: false,
4035 };
4036 let resolved = options.resolve_local_binding(binding, in_mode)?;
4037 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
4038 resolved.try_fmt(&mut self.out)?;
4039 writeln!(self.out, ";")?;
4040 }
4041 writeln!(self.out, "}};")?;
4042 }
4043
4044 let stage_out_name = format!("{fun_name}Output");
4047 let result_member_name = self.namer.call("member");
4048 let result_type_name = match fun.result {
4049 Some(ref result) => {
4050 let mut result_members = Vec::new();
4051 if let crate::TypeInner::Struct { ref members, .. } =
4052 module.types[result.ty].inner
4053 {
4054 for (member_index, member) in members.iter().enumerate() {
4055 result_members.push((
4056 &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
4057 member.ty,
4058 member.binding.as_ref(),
4059 ));
4060 }
4061 } else {
4062 result_members.push((
4063 &result_member_name,
4064 result.ty,
4065 result.binding.as_ref(),
4066 ));
4067 }
4068
4069 writeln!(self.out, "struct {stage_out_name} {{")?;
4070 let mut has_point_size = false;
4071 for (name, ty, binding) in result_members {
4072 let ty_name = TypeContext {
4073 handle: ty,
4074 gctx: module.to_ctx(),
4075 names: &self.names,
4076 access: crate::StorageAccess::empty(),
4077 binding: None,
4078 first_time: true,
4079 };
4080 let binding = binding.ok_or_else(|| {
4081 Error::GenericValidation("Expected binding, got None".into())
4082 })?;
4083
4084 if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
4085 has_point_size = true;
4086 if !pipeline_options.allow_and_force_point_size {
4087 continue;
4088 }
4089 }
4090
4091 let array_len = match module.types[ty].inner {
4092 crate::TypeInner::Array {
4093 size: crate::ArraySize::Constant(size),
4094 ..
4095 } => Some(size),
4096 _ => None,
4097 };
4098 let resolved = options.resolve_local_binding(binding, out_mode)?;
4099 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
4100 if let Some(array_len) = array_len {
4101 write!(self.out, " [{array_len}]")?;
4102 }
4103 resolved.try_fmt(&mut self.out)?;
4104 writeln!(self.out, ";")?;
4105 }
4106
4107 if pipeline_options.allow_and_force_point_size
4108 && ep.stage == crate::ShaderStage::Vertex
4109 && !has_point_size
4110 {
4111 writeln!(
4113 self.out,
4114 "{}float _point_size [[point_size]];",
4115 back::INDENT
4116 )?;
4117 }
4118 writeln!(self.out, "}};")?;
4119 &stage_out_name
4120 }
4121 None => "void",
4122 };
4123
4124 writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
4126 let mut is_first_argument = true;
4127
4128 if has_varyings {
4131 writeln!(
4132 self.out,
4133 " {stage_in_name} {varyings_member_name} [[stage_in]]"
4134 )?;
4135 is_first_argument = false;
4136 }
4137
4138 let mut local_invocation_id = None;
4139
4140 for &(ref name_key, ty, binding) in flattened_arguments.iter() {
4143 let binding = match binding {
4144 Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
4145 _ => continue,
4146 };
4147 let name = match *name_key {
4148 NameKey::StructMember(..) => &flattened_member_names[name_key],
4149 _ => &self.names[name_key],
4150 };
4151
4152 if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
4153 local_invocation_id = Some(name_key);
4154 }
4155
4156 let ty_name = TypeContext {
4157 handle: ty,
4158 gctx: module.to_ctx(),
4159 names: &self.names,
4160 access: crate::StorageAccess::empty(),
4161 binding: None,
4162 first_time: false,
4163 };
4164 let resolved = options.resolve_local_binding(binding, in_mode)?;
4165 let separator = if is_first_argument {
4166 is_first_argument = false;
4167 ' '
4168 } else {
4169 ','
4170 };
4171 write!(self.out, "{separator} {ty_name} {name}")?;
4172 resolved.try_fmt(&mut self.out)?;
4173 writeln!(self.out)?;
4174 }
4175
4176 let need_workgroup_variables_initialization =
4177 self.need_workgroup_variables_initialization(options, ep, module, fun_info);
4178
4179 if need_workgroup_variables_initialization && local_invocation_id.is_none() {
4180 let separator = if is_first_argument {
4181 is_first_argument = false;
4182 ' '
4183 } else {
4184 ','
4185 };
4186 writeln!(
4187 self.out,
4188 "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
4189 )?;
4190 }
4191
4192 for (handle, var) in module.global_variables.iter() {
4197 let usage = fun_info[handle];
4198 if usage.is_empty() || var.space == crate::AddressSpace::Private {
4199 continue;
4200 }
4201
4202 if options.lang_version < (1, 2) {
4203 match var.space {
4204 crate::AddressSpace::Storage { access }
4214 if access.contains(crate::StorageAccess::STORE)
4215 && ep.stage == crate::ShaderStage::Fragment =>
4216 {
4217 return Err(Error::UnsupportedWriteableStorageBuffer)
4218 }
4219 crate::AddressSpace::Handle => {
4220 match module.types[var.ty].inner {
4221 crate::TypeInner::Image {
4222 class: crate::ImageClass::Storage { access, .. },
4223 ..
4224 } => {
4225 if access.contains(crate::StorageAccess::STORE)
4235 && (ep.stage == crate::ShaderStage::Vertex
4236 || ep.stage == crate::ShaderStage::Fragment)
4237 {
4238 return Err(Error::UnsupportedWriteableStorageTexture(
4239 ep.stage,
4240 ));
4241 }
4242
4243 if access.contains(
4244 crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
4245 ) {
4246 return Err(Error::UnsupportedRWStorageTexture);
4247 }
4248 }
4249 _ => {}
4250 }
4251 }
4252 _ => {}
4253 }
4254 }
4255
4256 match var.space {
4258 crate::AddressSpace::Handle => match module.types[var.ty].inner {
4259 crate::TypeInner::BindingArray { base, .. } => {
4260 match module.types[base].inner {
4261 crate::TypeInner::Sampler { .. } => {
4262 if options.lang_version < (2, 0) {
4263 return Err(Error::UnsupportedArrayOf(
4264 "samplers".to_string(),
4265 ));
4266 }
4267 }
4268 crate::TypeInner::Image { class, .. } => match class {
4269 crate::ImageClass::Sampled { .. }
4270 | crate::ImageClass::Depth { .. }
4271 | crate::ImageClass::Storage {
4272 access: crate::StorageAccess::LOAD,
4273 ..
4274 } => {
4275 if options.lang_version < (2, 0) {
4280 return Err(Error::UnsupportedArrayOf(
4281 "textures".to_string(),
4282 ));
4283 }
4284 }
4285 crate::ImageClass::Storage {
4286 access: crate::StorageAccess::STORE,
4287 ..
4288 } => {
4289 if options.lang_version < (2, 0) {
4294 return Err(Error::UnsupportedArrayOf(
4295 "write-only textures".to_string(),
4296 ));
4297 }
4298 }
4299 crate::ImageClass::Storage { .. } => {
4300 return Err(Error::UnsupportedArrayOf(
4301 "read-write textures".to_string(),
4302 ));
4303 }
4304 },
4305 _ => {
4306 return Err(Error::UnsupportedArrayOfType(base));
4307 }
4308 }
4309 }
4310 _ => {}
4311 },
4312 _ => {}
4313 }
4314
4315 let resolved = match var.space {
4317 crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
4318 crate::AddressSpace::WorkGroup => None,
4319 _ => options
4320 .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
4321 .ok(),
4322 };
4323 if let Some(ref resolved) = resolved {
4324 if resolved.as_inline_sampler(options).is_some() {
4326 continue;
4327 }
4328 }
4329
4330 let tyvar = TypedGlobalVariable {
4331 module,
4332 names: &self.names,
4333 handle,
4334 usage,
4335 binding: resolved.as_ref(),
4336 reference: true,
4337 };
4338 let separator = if is_first_argument {
4339 is_first_argument = false;
4340 ' '
4341 } else {
4342 ','
4343 };
4344 write!(self.out, "{separator} ")?;
4345 tyvar.try_fmt(&mut self.out)?;
4346 if let Some(resolved) = resolved {
4347 resolved.try_fmt(&mut self.out)?;
4348 }
4349 if let Some(value) = var.init {
4350 write!(self.out, " = ")?;
4351 self.put_const_expression(value, module, mod_info)?;
4352 }
4353 writeln!(self.out)?;
4354 }
4355
4356 if supports_array_length {
4359 let resolved = options.resolve_sizes_buffer(ep).unwrap();
4361 let separator = if module.global_variables.is_empty() {
4362 ' '
4363 } else {
4364 ','
4365 };
4366 write!(
4367 self.out,
4368 "{separator} constant _mslBufferSizes& _buffer_sizes",
4369 )?;
4370 resolved.try_fmt(&mut self.out)?;
4371 writeln!(self.out)?;
4372 }
4373
4374 writeln!(self.out, ") {{")?;
4376
4377 if need_workgroup_variables_initialization {
4378 self.write_workgroup_variables_initialization(
4379 module,
4380 mod_info,
4381 fun_info,
4382 local_invocation_id,
4383 )?;
4384 }
4385
4386 for (handle, var) in module.global_variables.iter() {
4389 let usage = fun_info[handle];
4390 if usage.is_empty() {
4391 continue;
4392 }
4393 if var.space == crate::AddressSpace::Private {
4394 let tyvar = TypedGlobalVariable {
4395 module,
4396 names: &self.names,
4397 handle,
4398 usage,
4399 binding: None,
4400 reference: false,
4401 };
4402 write!(self.out, "{}", back::INDENT)?;
4403 tyvar.try_fmt(&mut self.out)?;
4404 match var.init {
4405 Some(value) => {
4406 write!(self.out, " = ")?;
4407 self.put_const_expression(value, module, mod_info)?;
4408 writeln!(self.out, ";")?;
4409 }
4410 None => {
4411 writeln!(self.out, " = {{}};")?;
4412 }
4413 };
4414 } else if let Some(ref binding) = var.binding {
4415 let resolved = options.resolve_resource_binding(ep, binding).unwrap();
4417 if let Some(sampler) = resolved.as_inline_sampler(options) {
4418 let name = &self.names[&NameKey::GlobalVariable(handle)];
4419 writeln!(
4420 self.out,
4421 "{}constexpr {}::sampler {}(",
4422 back::INDENT,
4423 NAMESPACE,
4424 name
4425 )?;
4426 self.put_inline_sampler_properties(back::Level(2), sampler)?;
4427 writeln!(self.out, "{});", back::INDENT)?;
4428 }
4429 }
4430 }
4431
4432 for (arg_index, arg) in fun.arguments.iter().enumerate() {
4443 let arg_name =
4444 &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
4445 match module.types[arg.ty].inner {
4446 crate::TypeInner::Struct { ref members, .. } => {
4447 let struct_name = &self.names[&NameKey::Type(arg.ty)];
4448 write!(
4449 self.out,
4450 "{}const {} {} = {{ ",
4451 back::INDENT,
4452 struct_name,
4453 arg_name
4454 )?;
4455 for (member_index, member) in members.iter().enumerate() {
4456 let key = NameKey::StructMember(arg.ty, member_index as u32);
4457 let name = &flattened_member_names[&key];
4458 if member_index != 0 {
4459 write!(self.out, ", ")?;
4460 }
4461 if self
4463 .struct_member_pads
4464 .contains(&(arg.ty, member_index as u32))
4465 {
4466 write!(self.out, "{{}}, ")?;
4467 }
4468 if let Some(crate::Binding::Location { .. }) = member.binding {
4469 write!(self.out, "{varyings_member_name}.")?;
4470 }
4471 write!(self.out, "{name}")?;
4472 }
4473 writeln!(self.out, " }};")?;
4474 }
4475 _ => {
4476 if let Some(crate::Binding::Location { .. }) = arg.binding {
4477 writeln!(
4478 self.out,
4479 "{}const auto {} = {}.{};",
4480 back::INDENT,
4481 arg_name,
4482 varyings_member_name,
4483 arg_name
4484 )?;
4485 }
4486 }
4487 }
4488 }
4489
4490 let guarded_indices =
4491 index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
4492
4493 let context = StatementContext {
4494 expression: ExpressionContext {
4495 function: fun,
4496 origin: FunctionOrigin::EntryPoint(ep_index as _),
4497 info: fun_info,
4498 lang_version: options.lang_version,
4499 policies: options.bounds_check_policies,
4500 guarded_indices,
4501 module,
4502 mod_info,
4503 pipeline_options,
4504 },
4505 result_struct: Some(&stage_out_name),
4506 };
4507
4508 for (local_handle, local) in fun.local_variables.iter() {
4511 let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
4512 let ty_name = TypeContext {
4513 handle: local.ty,
4514 gctx: module.to_ctx(),
4515 names: &self.names,
4516 access: crate::StorageAccess::empty(),
4517 binding: None,
4518 first_time: false,
4519 };
4520 write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
4521 match local.init {
4522 Some(value) => {
4523 write!(self.out, " = ")?;
4524 self.put_expression(value, &context.expression, true)?;
4525 }
4526 None => {
4527 write!(self.out, " = {{}}")?;
4528 }
4529 };
4530 writeln!(self.out, ";")?;
4531 }
4532
4533 self.update_expressions_to_bake(fun, fun_info, &context.expression);
4534 self.put_block(back::Level(1), &fun.body, &context)?;
4535 writeln!(self.out, "}}")?;
4536 if ep_index + 1 != module.entry_points.len() {
4537 writeln!(self.out)?;
4538 }
4539 self.named_expressions.clear();
4540 }
4541
4542 Ok(info)
4543 }
4544
4545 fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
4546 if flags.is_empty() {
4549 writeln!(
4550 self.out,
4551 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
4552 )?;
4553 }
4554 if flags.contains(crate::Barrier::STORAGE) {
4555 writeln!(
4556 self.out,
4557 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
4558 )?;
4559 }
4560 if flags.contains(crate::Barrier::WORK_GROUP) {
4561 writeln!(
4562 self.out,
4563 "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
4564 )?;
4565 }
4566 if flags.contains(crate::Barrier::SUB_GROUP) {
4567 writeln!(
4568 self.out,
4569 "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
4570 )?;
4571 }
4572 Ok(())
4573 }
4574}
4575
4576mod workgroup_mem_init {
4579 use crate::EntryPoint;
4580
4581 use super::*;
4582
4583 enum Access {
4584 GlobalVariable(Handle<crate::GlobalVariable>),
4585 StructMember(Handle<crate::Type>, u32),
4586 Array(usize),
4587 }
4588
4589 impl Access {
4590 fn write<W: Write>(
4591 &self,
4592 writer: &mut W,
4593 names: &FastHashMap<NameKey, String>,
4594 ) -> Result<(), core::fmt::Error> {
4595 match *self {
4596 Access::GlobalVariable(handle) => {
4597 write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
4598 }
4599 Access::StructMember(handle, index) => {
4600 write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
4601 }
4602 Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
4603 }
4604 }
4605 }
4606
4607 struct AccessStack {
4608 stack: Vec<Access>,
4609 array_depth: usize,
4610 }
4611
4612 impl AccessStack {
4613 const fn new() -> Self {
4614 Self {
4615 stack: Vec::new(),
4616 array_depth: 0,
4617 }
4618 }
4619
4620 fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
4621 let array_depth = self.array_depth;
4622 self.stack.push(Access::Array(array_depth));
4623 self.array_depth += 1;
4624 let res = cb(self, array_depth);
4625 self.stack.pop();
4626 self.array_depth -= 1;
4627 res
4628 }
4629
4630 fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
4631 self.stack.push(new);
4632 let res = cb(self);
4633 self.stack.pop();
4634 res
4635 }
4636
4637 fn write<W: Write>(
4638 &self,
4639 writer: &mut W,
4640 names: &FastHashMap<NameKey, String>,
4641 ) -> Result<(), core::fmt::Error> {
4642 for next in self.stack.iter() {
4643 next.write(writer, names)?;
4644 }
4645 Ok(())
4646 }
4647 }
4648
4649 impl<W: Write> Writer<W> {
4650 pub(super) fn need_workgroup_variables_initialization(
4651 &mut self,
4652 options: &Options,
4653 ep: &EntryPoint,
4654 module: &crate::Module,
4655 fun_info: &valid::FunctionInfo,
4656 ) -> bool {
4657 options.zero_initialize_workgroup_memory
4658 && ep.stage == crate::ShaderStage::Compute
4659 && module.global_variables.iter().any(|(handle, var)| {
4660 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
4661 })
4662 }
4663
4664 pub(super) fn write_workgroup_variables_initialization(
4665 &mut self,
4666 module: &crate::Module,
4667 module_info: &valid::ModuleInfo,
4668 fun_info: &valid::FunctionInfo,
4669 local_invocation_id: Option<&NameKey>,
4670 ) -> BackendResult {
4671 let level = back::Level(1);
4672
4673 writeln!(
4674 self.out,
4675 "{}if ({}::all({} == {}::uint3(0u))) {{",
4676 level,
4677 NAMESPACE,
4678 local_invocation_id
4679 .map(|name_key| self.names[name_key].as_str())
4680 .unwrap_or("__local_invocation_id"),
4681 NAMESPACE,
4682 )?;
4683
4684 let mut access_stack = AccessStack::new();
4685
4686 let vars = module.global_variables.iter().filter(|&(handle, var)| {
4687 !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
4688 });
4689
4690 for (handle, var) in vars {
4691 access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
4692 self.write_workgroup_variable_initialization(
4693 module,
4694 module_info,
4695 var.ty,
4696 access_stack,
4697 level.next(),
4698 )
4699 })?;
4700 }
4701
4702 writeln!(self.out, "{level}}}")?;
4703 self.write_barrier(crate::Barrier::WORK_GROUP, level)
4704 }
4705
4706 fn write_workgroup_variable_initialization(
4707 &mut self,
4708 module: &crate::Module,
4709 module_info: &valid::ModuleInfo,
4710 ty: Handle<crate::Type>,
4711 access_stack: &mut AccessStack,
4712 level: back::Level,
4713 ) -> BackendResult {
4714 if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
4715 write!(self.out, "{level}")?;
4716 access_stack.write(&mut self.out, &self.names)?;
4717 writeln!(self.out, " = {{}};")?;
4718 } else {
4719 match module.types[ty].inner {
4720 crate::TypeInner::Atomic { .. } => {
4721 write!(
4722 self.out,
4723 "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
4724 )?;
4725 access_stack.write(&mut self.out, &self.names)?;
4726 writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
4727 }
4728 crate::TypeInner::Array { base, size, .. } => {
4729 let count = match size.to_indexable_length(module).expect("Bad array size")
4730 {
4731 proc::IndexableLength::Known(count) => count,
4732 proc::IndexableLength::Dynamic => unreachable!(),
4733 };
4734
4735 access_stack.enter_array(|access_stack, array_depth| {
4736 writeln!(
4737 self.out,
4738 "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
4739 )?;
4740 self.write_workgroup_variable_initialization(
4741 module,
4742 module_info,
4743 base,
4744 access_stack,
4745 level.next(),
4746 )?;
4747 writeln!(self.out, "{level}}}")?;
4748 BackendResult::Ok(())
4749 })?;
4750 }
4751 crate::TypeInner::Struct { ref members, .. } => {
4752 for (index, member) in members.iter().enumerate() {
4753 access_stack.enter(
4754 Access::StructMember(ty, index as u32),
4755 |access_stack| {
4756 self.write_workgroup_variable_initialization(
4757 module,
4758 module_info,
4759 member.ty,
4760 access_stack,
4761 level,
4762 )
4763 },
4764 )?;
4765 }
4766 }
4767 _ => unreachable!(),
4768 }
4769 }
4770
4771 Ok(())
4772 }
4773 }
4774}
4775
/// Guards against unnoticed growth of the recursive writers' stack frames.
///
/// Writes a tiny module (a negated literal used as an `If` condition). The stack
/// addresses recorded by `put_expression` and by `put_block` must each span a range
/// within fixed bounds.
#[test]
fn test_stack_size() {
    use crate::valid::{Capabilities, ValidationFlags};

    /// Distance between the lowest and the highest of `addresses`.
    fn address_span(addresses: impl Iterator<Item = usize>) -> usize {
        let (lowest, highest) = addresses.fold((usize::MAX, 0usize), |(lo, hi), addr| {
            (lo.min(addr), hi.max(addr))
        });
        highest - lowest
    }

    let mut fun = crate::Function::default();
    let literal = fun.expressions.append(
        crate::Expression::Literal(crate::Literal::F32(1.0)),
        Default::default(),
    );
    let negated = fun.expressions.append(
        crate::Expression::Unary {
            op: crate::UnaryOperator::Negate,
            expr: literal,
        },
        Default::default(),
    );
    fun.body.push(
        crate::Statement::Emit(fun.expressions.range_from(1)),
        Default::default(),
    );
    fun.body.push(
        crate::Statement::If {
            condition: negated,
            accept: crate::Block::new(),
            reject: crate::Block::new(),
        },
        Default::default(),
    );

    let mut module = crate::Module::default();
    let _ = module.functions.append(fun, Default::default());

    let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty())
        .validate(&module)
        .unwrap();
    let mut writer = Writer::new(String::new());
    writer
        .write(&module, &info, &Default::default(), &Default::default())
        .unwrap();

    let expression_span = address_span(
        writer
            .put_expression_stack_pointers
            .iter()
            .map(|&pointer| pointer as usize),
    );
    assert!(
        (11000..=25000).contains(&expression_span),
        "`put_expression` stack size {expression_span} has changed!"
    );

    let block_span = address_span(
        writer
            .put_block_stack_pointers
            .iter()
            .map(|&pointer| pointer as usize),
    );
    assert!(
        (15000..=25000).contains(&block_span),
        "`put_block` stack size {block_span} has changed!"
    );
}
4847}