1use super::{
2 help::{
3 WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
4 WrappedZeroValue,
5 },
6 storage::StoreValue,
7 BackendResult, Error, Options,
8};
9use crate::{
10 back,
11 proc::{self, NameKey},
12 valid, Handle, Module, ScalarKind, ShaderStage, TypeInner,
13};
14use std::{fmt, mem};
15
/// Prefix of the HLSL semantic used for `@location(N)` bindings.
///
/// `write_semantic` emits `LOC{N}` for every location binding that is not a
/// fragment-stage output (those use `SV_Target{N}` instead).
const LOCATION_SEMANTIC: &str = "LOC";
/// Name of the HLSL struct declared by `write` when
/// `Options::special_constants_binding` is set.
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
/// Name of the `ConstantBuffer<NagaConstants>` variable bound at
/// `Options::special_constants_binding`.
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
/// `int` field of the special constants struct.
/// NOTE(review): presumably carries the base vertex index supplied by the
/// runtime; its readers are not visible in this part of the file.
const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
/// `int` field of the special constants struct.
/// NOTE(review): presumably the base instance index; confirm against readers.
const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
/// `uint` field of the special constants struct. Its meaning is not
/// established in this part of the file — TODO confirm against its readers.
const SPECIAL_OTHER: &str = "other";

/// Name of the generated helper function used for `modf`.
/// NOTE(review): the helper body is presumably emitted elsewhere in this
/// backend (e.g. by `write_special_functions`) — not visible here.
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
/// Name of the generated helper function used for `frexp`.
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
/// Name of the generated helper function used for `extractBits`.
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
/// Name of the generated helper function used for `insertBits`.
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
27
/// One member of a synthesized entry-point interface struct.
///
/// Entry-point arguments (and the fields of struct-typed arguments) are
/// flattened into a list of these before `write_interface_struct` writes them.
struct EpStructMember {
    /// HLSL name of the member, allocated through the namer.
    name: String,
    /// Type of the member.
    ty: Handle<crate::Type>,
    /// Location or built-in binding, if any. Used both to sort the members
    /// and to choose the HLSL semantic.
    binding: Option<crate::Binding>,
    /// Position of this member in the flattened argument list. Used to
    /// restore the original order after input structs are sorted by binding.
    index: u32,
}
35
/// A synthesized HLSL struct that stands in for an entry point's inputs or
/// outputs.
struct EntryPointBinding {
    /// Name of the entry-point parameter of this struct type, derived from the
    /// lowercased struct name via the namer.
    arg_name: String,
    /// Name of the generated HLSL struct type.
    ty_name: String,
    /// Members of the struct. For inputs they are in original argument order;
    /// for outputs they stay sorted by `InterfaceKey`.
    members: Vec<EpStructMember>,
}
47
/// The synthesized interface structs of one entry point, as produced by
/// `write_ep_interface`.
pub(super) struct EntryPointInterface {
    /// Input struct. Present when the entry point has arguments and is either
    /// a fragment shader or uses a subgroup built-in argument.
    /// `write_ep_arguments_initialization` takes it (leaving `None`) once the
    /// arguments have been unpacked.
    input: Option<EntryPointBinding>,
    /// Output struct. Present only for vertex shaders whose result has no
    /// binding of its own.
    output: Option<EntryPointBinding>,
}
59
/// Sort key for the members of an interface struct.
///
/// The derived `Ord` compares variants in declaration order. Sorting by this
/// key therefore puts all `Location` members first (ascending by location),
/// then built-ins, then members without a binding.
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
enum InterfaceKey {
    Location(u32),
    BuiltIn(crate::BuiltIn),
    Other,
}
66
67impl InterfaceKey {
68 const fn new(binding: Option<&crate::Binding>) -> Self {
69 match binding {
70 Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
71 Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
72 None => Self::Other,
73 }
74 }
75}
76
/// Direction of an entry-point interface: shader inputs or shader outputs.
#[derive(Copy, Clone, PartialEq)]
enum Io {
    Input,
    Output,
}
82
83const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
84 let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
85 return false;
86 };
87 matches!(
88 builtin,
89 crate::BuiltIn::SubgroupSize
90 | crate::BuiltIn::SubgroupInvocationId
91 | crate::BuiltIn::NumSubgroups
92 | crate::BuiltIn::SubgroupId
93 )
94}
95
96impl<'a, W: fmt::Write> super::Writer<'a, W> {
97 pub fn new(out: W, options: &'a Options) -> Self {
98 Self {
99 out,
100 names: crate::FastHashMap::default(),
101 namer: proc::Namer::default(),
102 options,
103 entry_point_io: Vec::new(),
104 named_expressions: crate::NamedExpressions::default(),
105 wrapped: super::Wrapped::default(),
106 temp_access_chain: Vec::new(),
107 need_bake_expressions: Default::default(),
108 }
109 }
110
111 fn reset(&mut self, module: &Module) {
112 self.names.clear();
113 self.namer.reset(
114 module,
115 super::keywords::RESERVED,
116 super::keywords::TYPES,
117 super::keywords::RESERVED_CASE_INSENSITIVE,
118 &[],
119 &mut self.names,
120 );
121 self.entry_point_io.clear();
122 self.named_expressions.clear();
123 self.wrapped.clear();
124 self.need_bake_expressions.clear();
125 }
126
127 fn update_expressions_to_bake(
132 &mut self,
133 module: &Module,
134 func: &crate::Function,
135 info: &valid::FunctionInfo,
136 ) {
137 use crate::Expression;
138 self.need_bake_expressions.clear();
139 for (fun_handle, expr) in func.expressions.iter() {
140 let expr_info = &info[fun_handle];
141 let min_ref_count = func.expressions[fun_handle].bake_ref_count();
142 if min_ref_count <= expr_info.ref_count {
143 self.need_bake_expressions.insert(fun_handle);
144 }
145
146 if let Expression::Math { fun, arg, .. } = *expr {
147 match fun {
148 crate::MathFunction::Asinh
149 | crate::MathFunction::Acosh
150 | crate::MathFunction::Atanh
151 | crate::MathFunction::Unpack2x16float
152 | crate::MathFunction::Unpack2x16snorm
153 | crate::MathFunction::Unpack2x16unorm
154 | crate::MathFunction::Unpack4x8snorm
155 | crate::MathFunction::Unpack4x8unorm
156 | crate::MathFunction::Pack2x16float
157 | crate::MathFunction::Pack2x16snorm
158 | crate::MathFunction::Pack2x16unorm
159 | crate::MathFunction::Pack4x8snorm
160 | crate::MathFunction::Pack4x8unorm => {
161 self.need_bake_expressions.insert(arg);
162 }
163 crate::MathFunction::CountLeadingZeros => {
164 let inner = info[fun_handle].ty.inner_with(&module.types);
165 if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
166 self.need_bake_expressions.insert(arg);
167 }
168 }
169 _ => {}
170 }
171 }
172
173 if let Expression::Derivative { axis, ctrl, expr } = *expr {
174 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
175 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
176 self.need_bake_expressions.insert(expr);
177 }
178 }
179 }
180 for statement in func.body.iter() {
181 match *statement {
182 crate::Statement::SubgroupCollectiveOperation {
183 op: _,
184 collective_op: crate::CollectiveOperation::InclusiveScan,
185 argument,
186 result: _,
187 } => {
188 self.need_bake_expressions.insert(argument);
189 }
190 _ => {}
191 }
192 }
193 }
194
/// Writes the whole `module` as HLSL.
///
/// Output order: the optional special-constants buffer, the matCx2 helper
/// typedefs/functions, structs, special and wrapped helper functions, named
/// constants, globals, entry-point interface structs, regular functions, and
/// finally the entry points themselves.
///
/// Returns reflection info holding, for each entry point in module order,
/// either its emitted name or the binding error that caused it to be skipped.
///
/// # Errors
///
/// Returns `Error::Override` if the module still contains pipeline overrides,
/// and propagates any error from the individual writers.
pub fn write(
    &mut self,
    module: &Module,
    module_info: &valid::ModuleInfo,
) -> Result<super::ReflectionInfo, Error> {
    // Modules that still contain pipeline overrides are rejected outright.
    if !module.overrides.is_empty() {
        return Err(Error::Override);
    }

    self.reset(module);

    // Declare the special constants buffer, if the caller configured a binding for it.
    if let Some(ref bt) = self.options.special_constants_binding {
        writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
        writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
        writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
        writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
        writeln!(self.out, "}};")?;
        write!(
            self.out,
            "ConstantBuffer<{}> {}: register(b{}",
            SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
        )?;
        // The register space is only spelled out when it is non-zero.
        if bt.space != 0 {
            write!(self.out, ", space{}", bt.space)?;
        }
        writeln!(self.out, ");")?;

        // Blank line for readability.
        writeln!(self.out)?;
    }

    // Remember each entry point's stage and result, so that structs used as
    // entry-point results get output semantics when written below.
    let ep_results = module
        .entry_points
        .iter()
        .map(|ep| (ep.stage, ep.function.result.clone()))
        .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();

    self.write_all_mat_cx2_typedefs_and_functions(module)?;

    // Write all structs.
    for (handle, ty) in module.types.iter() {
        if let TypeInner::Struct { ref members, span } = ty.inner {
            // Skip structs whose last member is dynamically sized. Storage
            // buffers are declared as `ByteAddressBuffer` (see `write_global`),
            // so presumably no HLSL struct is needed for them.
            // Assumes every struct has at least one member (`unwrap`).
            if module.types[members.last().unwrap().ty]
                .inner
                .is_dynamically_sized(&module.types)
            {
                continue;
            }

            // If this struct is the result type of some entry point, pass that
            // stage along so its members get output semantics.
            let ep_result = ep_results.iter().find(|e| {
                if let Some(ref result) = e.1 {
                    result.ty == handle
                } else {
                    false
                }
            });

            self.write_struct(
                module,
                handle,
                members,
                span,
                ep_result.map(|r| (r.0, Io::Output)),
            )?;
            writeln!(self.out)?;
        }
    }

    self.write_special_functions(module)?;

    self.write_wrapped_compose_functions(module, &module.global_expressions)?;
    self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

    // Write all named constants.
    let mut constants = module
        .constants
        .iter()
        .filter(|&(_, c)| c.name.is_some())
        .peekable();
    while let Some((handle, _)) = constants.next() {
        self.write_global_constant(module, handle)?;
        // Blank line after the last constant, for readability.
        if constants.peek().is_none() {
            writeln!(self.out)?;
        }
    }

    // Write all globals (those with unresolvable bindings are skipped inside).
    for (ty, _) in module.global_variables.iter() {
        self.write_global(module, ty)?;
    }

    if !module.global_variables.is_empty() {
        // Blank line for readability.
        writeln!(self.out)?;
    }

    // Write the synthesized input/output structs of every entry point; they
    // are stored in `entry_point_io` at the entry point's index.
    for (index, ep) in module.entry_points.iter().enumerate() {
        let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
        let ep_io = self.write_ep_interface(module, &ep.function, ep.stage, &ep_name)?;
        self.entry_point_io.push(ep_io);
    }

    // Write all regular functions.
    for (handle, function) in module.functions.iter() {
        let info = &module_info[handle];

        // Unless missing bindings are being faked, skip any function that uses
        // a global whose resource binding cannot be resolved.
        if !self.options.fake_missing_bindings {
            if let Some((var_handle, _)) =
                module
                    .global_variables
                    .iter()
                    .find(|&(var_handle, var)| match var.binding {
                        Some(ref binding) if !info[var_handle].is_empty() => {
                            self.options.resolve_resource_binding(binding).is_err()
                        }
                        _ => false,
                    })
            {
                log::info!(
                    "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
                    handle,
                    function.name,
                    var_handle
                );
                continue;
            }
        }

        let ctx = back::FunctionCtx {
            ty: back::FunctionType::Function(handle),
            info,
            expressions: &function.expressions,
            named_expressions: &function.named_expressions,
        };
        let name = self.names[&NameKey::Function(handle)].clone();

        // Helper functions required by this function are written before it.
        self.write_wrapped_functions(module, &ctx)?;

        self.write_function(module, name.as_str(), function, &ctx, info)?;

        writeln!(self.out)?;
    }

    let mut entry_point_names = Vec::with_capacity(module.entry_points.len());

    // Write all entry points.
    for (index, ep) in module.entry_points.iter().enumerate() {
        let info = module_info.get_entry_point(index);

        // If any global used by this entry point has an unresolvable binding,
        // record the error for it instead of writing it.
        if !self.options.fake_missing_bindings {
            let mut ep_error = None;
            for (var_handle, var) in module.global_variables.iter() {
                match var.binding {
                    Some(ref binding) if !info[var_handle].is_empty() => {
                        if let Err(err) = self.options.resolve_resource_binding(binding) {
                            ep_error = Some(err);
                            break;
                        }
                    }
                    _ => {}
                }
            }
            if let Some(err) = ep_error {
                entry_point_names.push(Err(err));
                continue;
            }
        }

        let ctx = back::FunctionCtx {
            ty: back::FunctionType::EntryPoint(index as u16),
            info,
            expressions: &ep.function.expressions,
            named_expressions: &ep.function.named_expressions,
        };

        self.write_wrapped_functions(module, &ctx)?;

        if ep.stage == ShaderStage::Compute {
            // HLSL expresses the workgroup size with the `numthreads` attribute.
            let num_threads = ep.workgroup_size;
            writeln!(
                self.out,
                "[numthreads({}, {}, {})]",
                num_threads[0], num_threads[1], num_threads[2]
            )?;
        }

        let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
        self.write_function(module, &name, &ep.function, &ctx, info)?;

        // Separate entry points by a blank line, but not after the last one.
        if index < module.entry_points.len() - 1 {
            writeln!(self.out)?;
        }

        entry_point_names.push(Ok(name));
    }

    Ok(super::ReflectionInfo { entry_point_names })
}
401
402 fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
403 match *binding {
404 crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
405 write!(self.out, "precise ")?;
406 }
407 crate::Binding::Location {
408 interpolation,
409 sampling,
410 ..
411 } => {
412 if let Some(interpolation) = interpolation {
413 if let Some(string) = interpolation.to_hlsl_str() {
414 write!(self.out, "{string} ")?
415 }
416 }
417
418 if let Some(sampling) = sampling {
419 if let Some(string) = sampling.to_hlsl_str() {
420 write!(self.out, "{string} ")?
421 }
422 }
423 }
424 crate::Binding::BuiltIn(_) => {}
425 }
426
427 Ok(())
428 }
429
430 fn write_semantic(
433 &mut self,
434 binding: &Option<crate::Binding>,
435 stage: Option<(ShaderStage, Io)>,
436 ) -> BackendResult {
437 match *binding {
438 Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
439 let builtin_str = builtin.to_hlsl_str()?;
440 write!(self.out, " : {builtin_str}")?;
441 }
442 Some(crate::Binding::Location {
443 second_blend_source: true,
444 ..
445 }) => {
446 write!(self.out, " : SV_Target1")?;
447 }
448 Some(crate::Binding::Location {
449 location,
450 second_blend_source: false,
451 ..
452 }) => {
453 if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
454 write!(self.out, " : SV_Target{location}")?;
455 } else {
456 write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
457 }
458 }
459 _ => {}
460 }
461
462 Ok(())
463 }
464
465 fn write_interface_struct(
466 &mut self,
467 module: &Module,
468 shader_stage: (ShaderStage, Io),
469 struct_name: String,
470 mut members: Vec<EpStructMember>,
471 ) -> Result<EntryPointBinding, Error> {
472 members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
476
477 write!(self.out, "struct {struct_name}")?;
478 writeln!(self.out, " {{")?;
479 for m in members.iter() {
480 if is_subgroup_builtin_binding(&m.binding) {
481 continue;
482 }
483 write!(self.out, "{}", back::INDENT)?;
484 if let Some(ref binding) = m.binding {
485 self.write_modifier(binding)?;
486 }
487 self.write_type(module, m.ty)?;
488 write!(self.out, " {}", &m.name)?;
489 self.write_semantic(&m.binding, Some(shader_stage))?;
490 writeln!(self.out, ";")?;
491 }
492 if members.iter().any(|arg| {
493 matches!(
494 arg.binding,
495 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
496 )
497 }) {
498 writeln!(
499 self.out,
500 "{}uint __local_invocation_index : SV_GroupIndex;",
501 back::INDENT
502 )?;
503 }
504 writeln!(self.out, "}};")?;
505 writeln!(self.out)?;
506
507 match shader_stage.1 {
508 Io::Input => {
509 members.sort_by_key(|m| m.index);
511 }
512 Io::Output => {
513 }
515 }
516
517 Ok(EntryPointBinding {
518 arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
519 ty_name: struct_name,
520 members,
521 })
522 }
523
524 fn write_ep_input_struct(
528 &mut self,
529 module: &Module,
530 func: &crate::Function,
531 stage: ShaderStage,
532 entry_point_name: &str,
533 ) -> Result<EntryPointBinding, Error> {
534 let struct_name = format!("{stage:?}Input_{entry_point_name}");
535
536 let mut fake_members = Vec::new();
537 for arg in func.arguments.iter() {
538 match module.types[arg.ty].inner {
539 TypeInner::Struct { ref members, .. } => {
540 for member in members.iter() {
541 let name = self.namer.call_or(&member.name, "member");
542 let index = fake_members.len() as u32;
543 fake_members.push(EpStructMember {
544 name,
545 ty: member.ty,
546 binding: member.binding.clone(),
547 index,
548 });
549 }
550 }
551 _ => {
552 let member_name = self.namer.call_or(&arg.name, "member");
553 let index = fake_members.len() as u32;
554 fake_members.push(EpStructMember {
555 name: member_name,
556 ty: arg.ty,
557 binding: arg.binding.clone(),
558 index,
559 });
560 }
561 }
562 }
563
564 self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
565 }
566
567 fn write_ep_output_struct(
571 &mut self,
572 module: &Module,
573 result: &crate::FunctionResult,
574 stage: ShaderStage,
575 entry_point_name: &str,
576 ) -> Result<EntryPointBinding, Error> {
577 let struct_name = format!("{stage:?}Output_{entry_point_name}");
578
579 let mut fake_members = Vec::new();
580 let empty = [];
581 let members = match module.types[result.ty].inner {
582 TypeInner::Struct { ref members, .. } => members,
583 ref other => {
584 log::error!("Unexpected {:?} output type without a binding", other);
585 &empty[..]
586 }
587 };
588
589 for member in members.iter() {
590 let member_name = self.namer.call_or(&member.name, "member");
591 let index = fake_members.len() as u32;
592 fake_members.push(EpStructMember {
593 name: member_name,
594 ty: member.ty,
595 binding: member.binding.clone(),
596 index,
597 });
598 }
599
600 self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
601 }
602
603 fn write_ep_interface(
607 &mut self,
608 module: &Module,
609 func: &crate::Function,
610 stage: ShaderStage,
611 ep_name: &str,
612 ) -> Result<EntryPointInterface, Error> {
613 Ok(EntryPointInterface {
614 input: if !func.arguments.is_empty()
615 && (stage == ShaderStage::Fragment
616 || func
617 .arguments
618 .iter()
619 .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
620 {
621 Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
622 } else {
623 None
624 },
625 output: match func.result {
626 Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
627 Some(self.write_ep_output_struct(module, fr, stage, ep_name)?)
628 }
629 _ => None,
630 },
631 })
632 }
633
634 fn write_ep_argument_initialization(
635 &mut self,
636 ep: &crate::EntryPoint,
637 ep_input: &EntryPointBinding,
638 fake_member: &EpStructMember,
639 ) -> BackendResult {
640 match fake_member.binding {
641 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
642 write!(self.out, "WaveGetLaneCount()")?
643 }
644 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
645 write!(self.out, "WaveGetLaneIndex()")?
646 }
647 Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
648 self.out,
649 "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
650 ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
651 )?,
652 Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
653 write!(
654 self.out,
655 "{}.__local_invocation_index / WaveGetLaneCount()",
656 ep_input.arg_name
657 )?;
658 }
659 _ => {
660 write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
661 }
662 }
663 Ok(())
664 }
665
666 fn write_ep_arguments_initialization(
668 &mut self,
669 module: &Module,
670 func: &crate::Function,
671 ep_index: u16,
672 ) -> BackendResult {
673 let ep = &module.entry_points[ep_index as usize];
674 let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
675 Some(ep_input) => ep_input,
676 None => return Ok(()),
677 };
678 let mut fake_iter = ep_input.members.iter();
679 for (arg_index, arg) in func.arguments.iter().enumerate() {
680 write!(self.out, "{}", back::INDENT)?;
681 self.write_type(module, arg.ty)?;
682 let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
683 write!(self.out, " {arg_name}")?;
684 match module.types[arg.ty].inner {
685 TypeInner::Array { base, size, .. } => {
686 self.write_array_size(module, base, size)?;
687 write!(self.out, " = ")?;
688 self.write_ep_argument_initialization(
689 ep,
690 &ep_input,
691 fake_iter.next().unwrap(),
692 )?;
693 writeln!(self.out, ";")?;
694 }
695 TypeInner::Struct { ref members, .. } => {
696 write!(self.out, " = {{ ")?;
697 for index in 0..members.len() {
698 if index != 0 {
699 write!(self.out, ", ")?;
700 }
701 self.write_ep_argument_initialization(
702 ep,
703 &ep_input,
704 fake_iter.next().unwrap(),
705 )?;
706 }
707 writeln!(self.out, " }};")?;
708 }
709 _ => {
710 write!(self.out, " = ")?;
711 self.write_ep_argument_initialization(
712 ep,
713 &ep_input,
714 fake_iter.next().unwrap(),
715 )?;
716 writeln!(self.out, ";")?;
717 }
718 }
719 }
720 assert!(fake_iter.next().is_none());
721 Ok(())
722 }
723
/// Declares one global variable in HLSL. Globals whose resource binding
/// cannot be resolved are skipped (only logged).
///
/// The declaration depends on the address space:
/// - `Private` → `static T name[...] = init;` (default-initialized when the IR
///   gives no initializer)
/// - `WorkGroup` → `groupshared T name[...];`
/// - `Uniform` → `cbuffer name : register(bN) { T name[...]; }`
/// - `Storage` → `ByteAddressBuffer` (`t` register) or `RWByteAddressBuffer`
///   (`u` register) when the access includes stores
/// - `Handle` → `T name[...] : register(sN | uN | tN)` for samplers, storage
///   images and other handles respectively
/// - `PushConstant` → `ConstantBuffer<T> name: register(bN)` at
///   `Options::push_constants_target`
///
/// A `, spaceN` suffix is added whenever the register space is non-zero.
///
/// # Panics
///
/// Panics on the `Function` address space, and for push constants when
/// `Options::push_constants_target` is not set.
fn write_global(
    &mut self,
    module: &Module,
    handle: Handle<crate::GlobalVariable>,
) -> BackendResult {
    let global = &module.global_variables[handle];
    let inner = &module.types[global.ty].inner;

    // Globals with an unresolvable binding are left out entirely; `write`
    // likewise skips functions and entry points that use them.
    if let Some(ref binding) = global.binding {
        if let Err(err) = self.options.resolve_resource_binding(binding) {
            log::info!(
                "Skipping global {:?} (name {:?}) for being inaccessible: {}",
                handle,
                global.name,
                err,
            );
            return Ok(());
        }
    }

    // Write the declaration prefix for the address space, and pick the
    // register class used further down ("" when no register is written).
    let register_ty = match global.space {
        crate::AddressSpace::Function => unreachable!("Function address space"),
        crate::AddressSpace::Private => {
            write!(self.out, "static ")?;
            self.write_type(module, global.ty)?;
            ""
        }
        crate::AddressSpace::WorkGroup => {
            write!(self.out, "groupshared ")?;
            self.write_type(module, global.ty)?;
            ""
        }
        crate::AddressSpace::Uniform => {
            // The contents are declared inside the `cbuffer` block at the end.
            write!(self.out, "cbuffer")?;
            "b"
        }
        crate::AddressSpace::Storage { access } => {
            // Writable storage uses the `u` register class, read-only `t`.
            let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
                ("RW", "u")
            } else {
                ("", "t")
            };
            write!(self.out, "{prefix}ByteAddressBuffer")?;
            register
        }
        crate::AddressSpace::Handle => {
            // For binding arrays, the register class follows the element type.
            let handle_ty = match *inner {
                TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
                _ => inner,
            };

            let register = match *handle_ty {
                TypeInner::Sampler { .. } => "s",
                TypeInner::Image {
                    class: crate::ImageClass::Storage { .. },
                    ..
                } => "u",
                _ => "t",
            };
            self.write_type(module, global.ty)?;
            register
        }
        crate::AddressSpace::PushConstant => {
            // The push constant type becomes the generic argument of
            // `ConstantBuffer<...>`, written just below.
            write!(self.out, "ConstantBuffer<")?;
            "b"
        }
    };

    if global.space == crate::AddressSpace::PushConstant {
        self.write_global_type(module, global.ty)?;

        // `write_global_type` writes only the element type of an array, so
        // the dimensions are appended here.
        if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
            self.write_array_size(module, base, size)?;
        }

        // Close the generic argument of `ConstantBuffer`.
        write!(self.out, ">")?;
    }

    let name = &self.names[&NameKey::GlobalVariable(handle)];
    write!(self.out, " {name}")?;

    // Push constants are placed at the configured `push_constants_target`.
    if global.space == crate::AddressSpace::PushConstant {
        let target = self
            .options
            .push_constants_target
            .as_ref()
            .expect("No bind target was defined for the push constants block");
        write!(self.out, ": register(b{}", target.register)?;
        if target.space != 0 {
            write!(self.out, ", space{}", target.space)?;
        }
        write!(self.out, ")")?;
    }

    if let Some(ref binding) = global.binding {
        // Cannot fail: unresolvable bindings returned early at the top.
        let bt = self.options.resolve_resource_binding(binding).unwrap();

        // Binding-array dimensions go after the name; a size configured in
        // the bind target takes precedence over the IR size.
        if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
            if let Some(overridden_size) = bt.binding_array_size {
                write!(self.out, "[{overridden_size}]")?;
            } else {
                self.write_array_size(module, base, size)?;
            }
        }

        write!(self.out, " : register({}{}", register_ty, bt.register)?;
        if bt.space != 0 {
            write!(self.out, ", space{}", bt.space)?;
        }
        write!(self.out, ")")?;
    } else {
        // Array dimensions go after the name (`write_type` omits them).
        if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
            self.write_array_size(module, base, size)?;
        }
        // Private globals are always initialized, falling back to a default
        // value when the IR has no initializer.
        if global.space == crate::AddressSpace::Private {
            write!(self.out, " = ")?;
            if let Some(init) = global.init {
                self.write_const_expression(module, init)?;
            } else {
                self.write_default_init(module, global.ty)?;
            }
        }
    }

    // A uniform closes its `cbuffer` with a block holding one member that
    // reuses the global's name; everything else just ends with `;`.
    if global.space == crate::AddressSpace::Uniform {
        write!(self.out, " {{ ")?;

        self.write_global_type(module, global.ty)?;

        write!(
            self.out,
            " {}",
            &self.names[&NameKey::GlobalVariable(handle)]
        )?;

        // Array dimensions go after the member name.
        if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
            self.write_array_size(module, base, size)?;
        }

        writeln!(self.out, "; }}")?;
    } else {
        writeln!(self.out, ";")?;
    }

    Ok(())
}
888
889 fn write_global_constant(
894 &mut self,
895 module: &Module,
896 handle: Handle<crate::Constant>,
897 ) -> BackendResult {
898 write!(self.out, "static const ")?;
899 let constant = &module.constants[handle];
900 self.write_type(module, constant.ty)?;
901 let name = &self.names[&NameKey::Constant(handle)];
902 write!(self.out, " {}", name)?;
903 if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
905 self.write_array_size(module, base, size)?;
906 }
907 write!(self.out, " = ")?;
908 self.write_const_expression(module, constant.init)?;
909 writeln!(self.out, ";")?;
910 Ok(())
911 }
912
913 pub(super) fn write_array_size(
914 &mut self,
915 module: &Module,
916 base: Handle<crate::Type>,
917 size: crate::ArraySize,
918 ) -> BackendResult {
919 write!(self.out, "[")?;
920
921 match size {
922 crate::ArraySize::Constant(size) => {
923 write!(self.out, "{size}")?;
924 }
925 crate::ArraySize::Dynamic => unreachable!(),
926 }
927
928 write!(self.out, "]")?;
929
930 if let TypeInner::Array {
931 base: next_base,
932 size: next_size,
933 ..
934 } = module.types[base].inner
935 {
936 self.write_array_size(module, next_base, next_size)?;
937 }
938
939 Ok(())
940 }
941
/// Writes the HLSL declaration of the struct type `handle`.
///
/// For members without a binding, `int` padding fields are inserted before
/// each member, and after the last one, so that HLSL member placement matches
/// Naga's `offset`s and the struct's `span`. Unbound `matCx2` members are
/// split into one two-component vector field per column. Other matrix
/// members are declared `row_major`. `shader_stage`, if given, decides which
/// semantics bound members receive.
///
/// Panics if `members` is empty (`members.last().unwrap()`).
fn write_struct(
    &mut self,
    module: &Module,
    handle: Handle<crate::Type>,
    members: &[crate::StructMember],
    span: u32,
    shader_stage: Option<(ShaderStage, Io)>,
) -> BackendResult {
    let struct_name = &self.names[&NameKey::Type(handle)];
    writeln!(self.out, "struct {struct_name} {{")?;

    // Byte offset at which the previously written member ends.
    let mut last_offset = 0;
    for (index, member) in members.iter().enumerate() {
        // Fill any gap before an unbound member with `int` padding fields.
        // The gap is counted in 4-byte units, which assumes that offsets and
        // sizes are multiples of 4.
        if member.binding.is_none() && member.offset > last_offset {
            let padding = (member.offset - last_offset) / 4;
            for i in 0..padding {
                writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
            }
        }
        // Track where this member ends, using its size as laid out in HLSL.
        let ty_inner = &module.types[member.ty].inner;
        last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx());

        // Indentation, for readability only.
        write!(self.out, "{}", back::INDENT)?;

        match module.types[member.ty].inner {
            TypeInner::Array { base, size, .. } => {
                // HLSL arrays are declared as `type name[size]`.
                self.write_global_type(module, member.ty)?;

                write!(
                    self.out,
                    " {}",
                    &self.names[&NameKey::StructMember(handle, index as u32)]
                )?;
                self.write_array_size(module, base, size)?;
            }
            // Unbound `matCx2` members are written as `C` separate vector
            // fields named `{member}_0`, `{member}_1`, ... on one line.
            TypeInner::Matrix {
                rows,
                columns,
                scalar,
            } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
                let vec_ty = crate::TypeInner::Vector { size: rows, scalar };
                let field_name_key = NameKey::StructMember(handle, index as u32);

                for i in 0..columns as u8 {
                    if i != 0 {
                        write!(self.out, "; ")?;
                    }
                    self.write_value_type(module, &vec_ty)?;
                    write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
                }
            }
            _ => {
                // Interpolation / invariance modifiers precede the type.
                if let Some(ref binding) = member.binding {
                    self.write_modifier(binding)?;
                }

                // Remaining matrices are declared `row_major`.
                // NOTE(review): presumably so HLSL reads Naga's column-major
                // data correctly given its transposed matrix view — confirm.
                if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
                    write!(self.out, "row_major ")?;
                }

                self.write_type(module, member.ty)?;
                write!(
                    self.out,
                    " {}",
                    &self.names[&NameKey::StructMember(handle, index as u32)]
                )?;
            }
        }

        self.write_semantic(&member.binding, shader_stage)?;
        writeln!(self.out, ";")?;
    }

    // Pad the tail so the HLSL struct covers the whole `span` of the IR struct.
    if members.last().unwrap().binding.is_none() && span > last_offset {
        let padding = (span - last_offset) / 4;
        for i in 0..padding {
            writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
        }
    }

    writeln!(self.out, "}};")?;
    Ok(())
}
1046
1047 pub(super) fn write_global_type(
1052 &mut self,
1053 module: &Module,
1054 ty: Handle<crate::Type>,
1055 ) -> BackendResult {
1056 let matrix_data = get_inner_matrix_data(module, ty);
1057
1058 if let Some(MatrixType {
1061 columns,
1062 rows: crate::VectorSize::Bi,
1063 width: 4,
1064 }) = matrix_data
1065 {
1066 write!(self.out, "__mat{}x2", columns as u8)?;
1067 } else {
1068 if matrix_data.is_some() {
1072 write!(self.out, "row_major ")?;
1073 }
1074
1075 self.write_type(module, ty)?;
1076 }
1077
1078 Ok(())
1079 }
1080
1081 pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
1086 let inner = &module.types[ty].inner;
1087 match *inner {
1088 TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
1089 TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
1091 self.write_type(module, base)?
1092 }
1093 ref other => self.write_value_type(module, other)?,
1094 }
1095
1096 Ok(())
1097 }
1098
1099 pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
1104 match *inner {
1105 TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
1106 write!(self.out, "{}", scalar.to_hlsl_str()?)?;
1107 }
1108 TypeInner::Vector { size, scalar } => {
1109 write!(
1110 self.out,
1111 "{}{}",
1112 scalar.to_hlsl_str()?,
1113 back::vector_size_str(size)
1114 )?;
1115 }
1116 TypeInner::Matrix {
1117 columns,
1118 rows,
1119 scalar,
1120 } => {
1121 write!(
1126 self.out,
1127 "{}{}x{}",
1128 scalar.to_hlsl_str()?,
1129 back::vector_size_str(columns),
1130 back::vector_size_str(rows),
1131 )?;
1132 }
1133 TypeInner::Image {
1134 dim,
1135 arrayed,
1136 class,
1137 } => {
1138 self.write_image_type(dim, arrayed, class)?;
1139 }
1140 TypeInner::Sampler { comparison } => {
1141 let sampler = if comparison {
1142 "SamplerComparisonState"
1143 } else {
1144 "SamplerState"
1145 };
1146 write!(self.out, "{sampler}")?;
1147 }
1148 TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
1152 self.write_array_size(module, base, size)?;
1153 }
1154 _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
1155 }
1156
1157 Ok(())
1158 }
1159
/// Writes the HLSL definition of `func`, named `name`. `func_ctx.ty` says
/// whether it is a regular function or an entry point.
///
/// For entry points, this also:
/// - uses the synthesized output struct as the return type and the input
///   struct as the single parameter, when they exist;
/// - appends a `uint3 __local_invocation_id : SV_GroupThreadID` parameter
///   when workgroup memory must be zero-initialized;
/// - unpacks the input struct into the original arguments at the top of the
///   body.
///
/// Every local variable is initialized, either from its IR initializer or via
/// `write_default_init`.
fn write_function(
    &mut self,
    module: &Module,
    name: &str,
    func: &crate::Function,
    func_ctx: &back::FunctionCtx<'_>,
    info: &valid::FunctionInfo,
) -> BackendResult {
    self.update_expressions_to_bake(module, func, info);

    // An invariant `position` result puts `precise` before the return type.
    if let Some(crate::FunctionResult {
        binding:
            Some(
                ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position {
                    invariant: true,
                }),
            ),
        ..
    }) = func.result
    {
        self.write_modifier(binding)?;
    }

    // Return type. An entry point with a synthesized output struct returns
    // that struct instead of the IR result type.
    if let Some(ref result) = func.result {
        match func_ctx.ty {
            back::FunctionType::Function(_) => {
                self.write_type(module, result.ty)?;
            }
            back::FunctionType::EntryPoint(index) => {
                if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
                    write!(self.out, "{}", ep_output.ty_name)?;
                } else {
                    self.write_type(module, result.ty)?;
                }
            }
        }
    } else {
        write!(self.out, "void")?;
    }

    write!(self.out, " {name}(")?;

    let need_workgroup_variables_initialization =
        self.need_workgroup_variables_initialization(func_ctx, module);

    // Parameter list.
    match func_ctx.ty {
        back::FunctionType::Function(handle) => {
            for (index, arg) in func.arguments.iter().enumerate() {
                if index != 0 {
                    write!(self.out, ", ")?;
                }
                // Pointer arguments become `inout` parameters of the pointee type.
                let arg_ty = match module.types[arg.ty].inner {
                    TypeInner::Pointer { base, .. } => {
                        write!(self.out, "inout ")?;
                        base
                    }
                    _ => arg.ty,
                };
                self.write_type(module, arg_ty)?;

                let argument_name =
                    &self.names[&NameKey::FunctionArgument(handle, index as u32)];

                // Array dimensions follow the parameter name.
                write!(self.out, " {argument_name}")?;
                if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
                    self.write_array_size(module, base, size)?;
                }
            }
        }
        back::FunctionType::EntryPoint(ep_index) => {
            // Either the single synthesized input struct, or the original
            // arguments, each with its input semantic.
            if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
                write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
            } else {
                let stage = module.entry_points[ep_index as usize].stage;
                for (index, arg) in func.arguments.iter().enumerate() {
                    if index != 0 {
                        write!(self.out, ", ")?;
                    }
                    self.write_type(module, arg.ty)?;

                    let argument_name =
                        &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];

                    write!(self.out, " {argument_name}")?;
                    if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
                        self.write_array_size(module, base, size)?;
                    }

                    self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
                }
            }
            // Extra parameter used by the workgroup zero-initialization
            // prologue to select a single invocation.
            if need_workgroup_variables_initialization {
                if self.entry_point_io[ep_index as usize].input.is_some()
                    || !func.arguments.is_empty()
                {
                    write!(self.out, ", ")?;
                }
                write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
            }
        }
    }
    write!(self.out, ")")?;

    // Output semantic of an entry point's result, if it has one.
    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        let stage = module.entry_points[index as usize].stage;
        if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
            self.write_semantic(binding, Some((stage, Io::Output)))?;
        }
    }

    // Body.
    writeln!(self.out)?;
    writeln!(self.out, "{{")?;

    if need_workgroup_variables_initialization {
        self.write_workgroup_variables_initialization(func_ctx, module)?;
    }

    if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
        self.write_ep_arguments_initialization(module, func, index)?;
    }

    // Local variables, each with an initializer.
    for (handle, local) in func.local_variables.iter() {
        write!(self.out, "{}", back::INDENT)?;

        self.write_type(module, local.ty)?;
        write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
        if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
            self.write_array_size(module, base, size)?;
        }

        write!(self.out, " = ")?;
        if let Some(init) = local.init {
            self.write_expr(module, init, func_ctx)?;
        } else {
            // No IR initializer: fall back to the type's default value.
            self.write_default_init(module, local.ty)?;
        }

        writeln!(self.out, ";")?
    }

    if !func.local_variables.is_empty() {
        writeln!(self.out)?;
    }

    // Top-level statements are always written at indentation level 1.
    for sta in func.body.iter() {
        self.write_stmt(module, sta, func_ctx, back::Level(1))?;
    }

    writeln!(self.out, "}}")?;

    // Named expressions are per-function state.
    self.named_expressions.clear();

    Ok(())
}
1340
1341 fn need_workgroup_variables_initialization(
1342 &mut self,
1343 func_ctx: &back::FunctionCtx,
1344 module: &Module,
1345 ) -> bool {
1346 self.options.zero_initialize_workgroup_memory
1347 && func_ctx.ty.is_compute_entry_point(module)
1348 && module.global_variables.iter().any(|(handle, var)| {
1349 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1350 })
1351 }
1352
1353 fn write_workgroup_variables_initialization(
1354 &mut self,
1355 func_ctx: &back::FunctionCtx,
1356 module: &Module,
1357 ) -> BackendResult {
1358 let level = back::Level(1);
1359
1360 writeln!(
1361 self.out,
1362 "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
1363 )?;
1364
1365 let vars = module.global_variables.iter().filter(|&(handle, var)| {
1366 !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
1367 });
1368
1369 for (handle, var) in vars {
1370 let name = &self.names[&NameKey::GlobalVariable(handle)];
1371 write!(self.out, "{}{} = ", level.next(), name)?;
1372 self.write_default_init(module, var.ty)?;
1373 writeln!(self.out, ";")?;
1374 }
1375
1376 writeln!(self.out, "{level}}}")?;
1377 self.write_barrier(crate::Barrier::WORK_GROUP, level)
1378 }
1379
/// Writes a single IR statement as HLSL, indented at `level`.
///
/// Several statements (`Emit` of baked/named expressions, `Call` with a result,
/// `Atomic`, `WorkGroupUniformLoad` and the subgroup operations) bind their
/// result to a fresh identifier and record it in `self.named_expressions`. Later
/// uses of that expression then refer to the identifier instead of re-emitting it.
///
/// # Errors
/// Propagates writer/formatting errors and errors from helper methods. Returns
/// `Error::Custom` for an atomic on an unsupported address space and
/// `Error::Unimplemented` for an atomic compare-exchange.
///
/// # Panics
/// Panics on `Statement::RayQuery`, on subgroup collective operation/op pairs
/// not matched below, and on several `unwrap`s of invariants that are presumably
/// guaranteed by validation (TODO confirm).
fn write_stmt(
    &mut self,
    module: &Module,
    stmt: &crate::Statement,
    func_ctx: &back::FunctionCtx<'_>,
    level: back::Level,
) -> BackendResult {
    use crate::Statement;

    match *stmt {
        Statement::Emit(ref range) => {
            for handle in range.clone() {
                let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
                let expr_name = if ptr_class.is_some() {
                    // Pointer-typed expressions are never bound to a named temporary here.
                    None
                } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
                    // Expressions named in the IR get a unique identifier derived from that name.
                    Some(self.namer.call(name))
                } else if self.need_bake_expressions.contains(&handle) {
                    Some(format!("_expr{}", handle.index()))
                } else {
                    None
                };

                if let Some(name) = expr_name {
                    write!(self.out, "{level}")?;
                    self.write_named_expr(module, handle, name, handle, func_ctx)?;
                }
            }
        }
        Statement::Block(ref block) => {
            write!(self.out, "{level}")?;
            writeln!(self.out, "{{")?;
            for sta in block.iter() {
                self.write_stmt(module, sta, func_ctx, level.next())?
            }
            writeln!(self.out, "{level}}}")?
        }
        Statement::If {
            condition,
            ref accept,
            ref reject,
        } => {
            write!(self.out, "{level}")?;
            write!(self.out, "if (")?;
            self.write_expr(module, condition, func_ctx)?;
            writeln!(self.out, ") {{")?;

            let l2 = level.next();
            for sta in accept {
                self.write_stmt(module, sta, func_ctx, l2)?;
            }

            // The `else` branch is only emitted when it actually contains statements.
            if !reject.is_empty() {
                writeln!(self.out, "{level}}} else {{")?;

                for sta in reject {
                    self.write_stmt(module, sta, func_ctx, l2)?;
                }
            }

            writeln!(self.out, "{level}}}")?
        }
        Statement::Kill => writeln!(self.out, "{level}discard;")?,
        Statement::Return { value: None } => {
            writeln!(self.out, "{level}return;")?;
        }
        Statement::Return { value: Some(expr) } => {
            let base_ty_res = &func_ctx.info[expr].ty;
            let mut resolved = base_ty_res.inner_with(&module.types);
            if let TypeInner::Pointer { base, space: _ } = *resolved {
                resolved = &module.types[base].inner;
            }

            if let TypeInner::Struct { .. } = *resolved {
                // Struct results are first bound to a local `const` of the struct type.
                let ty = base_ty_res.handle().unwrap();
                let struct_name = &self.names[&NameKey::Type(ty)];
                let variable_name = self.namer.call(&struct_name.to_lowercase());
                write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
                self.write_expr(module, expr, func_ctx)?;
                writeln!(self.out, ";")?;

                // Entry points with an output interface struct repack the result, member by
                // member in `ep_output.members` order, into that interface struct before returning.
                let ep_output = match func_ctx.ty {
                    back::FunctionType::Function(_) => None,
                    back::FunctionType::EntryPoint(index) => {
                        self.entry_point_io[index as usize].output.as_ref()
                    }
                };
                let final_name = match ep_output {
                    Some(ep_output) => {
                        let final_name = self.namer.call(&variable_name);
                        write!(
                            self.out,
                            "{}const {} {} = {{ ",
                            level, ep_output.ty_name, final_name,
                        )?;
                        for (index, m) in ep_output.members.iter().enumerate() {
                            if index != 0 {
                                write!(self.out, ", ")?;
                            }
                            let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
                            write!(self.out, "{variable_name}.{member_name}")?;
                        }
                        writeln!(self.out, " }};")?;
                        final_name
                    }
                    None => variable_name,
                };
                writeln!(self.out, "{level}return {final_name};")?;
            } else {
                write!(self.out, "{level}return ")?;
                self.write_expr(module, expr, func_ctx)?;
                writeln!(self.out, ";")?
            }
        }
        Statement::Store { pointer, value } => {
            let ty_inner = func_ctx.resolve_type(pointer, &module.types);
            if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
                // Stores into storage buffers go through the access-chain helpers.
                let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
                self.write_storage_store(
                    module,
                    var_handle,
                    StoreValue::Expression(value),
                    func_ctx,
                    level,
                )?;
            } else {
                // Matrices with two rows (`matCx2`) need dedicated store handling. This
                // branch first looks for a store into such a matrix that is a struct member
                // without a binding, possibly through a column (vector) and/or element
                // (scalar) access.
                struct MatrixAccess {
                    base: Handle<crate::Expression>,
                    index: u32,
                }
                enum Index {
                    Expression(Handle<crate::Expression>),
                    Static(u32),
                }

                // Returns the members of the struct that `expr` points to, if any.
                let get_members = |expr: Handle<crate::Expression>| {
                    let resolved = func_ctx.resolve_type(expr, &module.types);
                    match *resolved {
                        TypeInner::Pointer { base, .. } => match module.types[base].inner {
                            TypeInner::Struct { ref members, .. } => Some(members),
                            _ => None,
                        },
                        _ => None,
                    }
                };

                let mut matrix = None;
                let mut vector = None;
                let mut scalar = None;

                // Walk at most three steps up the pointer chain: element -> column -> matrix.
                let mut current_expr = pointer;
                for _ in 0..3 {
                    let resolved = func_ctx.resolve_type(current_expr, &module.types);

                    match (resolved, &func_ctx.expressions[current_expr]) {
                        (
                            &TypeInner::Pointer { base: ty, .. },
                            &crate::Expression::AccessIndex { base, index },
                        ) if matches!(
                            module.types[ty].inner,
                            TypeInner::Matrix {
                                rows: crate::VectorSize::Bi,
                                ..
                            }
                        ) && get_members(base)
                            .map(|members| members[index as usize].binding.is_none())
                            == Some(true) =>
                        {
                            matrix = Some(MatrixAccess { base, index });
                            break;
                        }
                        (
                            &TypeInner::ValuePointer {
                                size: Some(crate::VectorSize::Bi),
                                ..
                            },
                            &crate::Expression::Access { base, index },
                        ) => {
                            vector = Some(Index::Expression(index));
                            current_expr = base;
                        }
                        (
                            &TypeInner::ValuePointer {
                                size: Some(crate::VectorSize::Bi),
                                ..
                            },
                            &crate::Expression::AccessIndex { base, index },
                        ) => {
                            vector = Some(Index::Static(index));
                            current_expr = base;
                        }
                        (
                            &TypeInner::ValuePointer { size: None, .. },
                            &crate::Expression::Access { base, index },
                        ) => {
                            scalar = Some(Index::Expression(index));
                            current_expr = base;
                        }
                        (
                            &TypeInner::ValuePointer { size: None, .. },
                            &crate::Expression::AccessIndex { base, index },
                        ) => {
                            scalar = Some(Index::Static(index));
                            current_expr = base;
                        }
                        _ => break,
                    }
                }

                write!(self.out, "{level}")?;

                if let Some(MatrixAccess { index, base }) = matrix {
                    let base_ty_res = &func_ctx.info[base].ty;
                    let resolved = base_ty_res.inner_with(&module.types);
                    let ty = match *resolved {
                        TypeInner::Pointer { base, .. } => base,
                        _ => base_ty_res.handle().unwrap(),
                    };

                    if let Some(Index::Static(vec_index)) = vector {
                        // A statically known column is addressed directly as the
                        // `<member>_<column>` field of the struct.
                        self.write_expr(module, base, func_ctx)?;
                        write!(
                            self.out,
                            ".{}_{}",
                            &self.names[&NameKey::StructMember(ty, index)],
                            vec_index
                        )?;

                        if let Some(scalar_index) = scalar {
                            write!(self.out, "[")?;
                            match scalar_index {
                                Index::Static(index) => {
                                    write!(self.out, "{index}")?;
                                }
                                Index::Expression(index) => {
                                    self.write_expr(module, index, func_ctx)?;
                                }
                            }
                            write!(self.out, "]")?;
                        }

                        write!(self.out, " = ")?;
                        self.write_expr(module, value, func_ctx)?;
                        writeln!(self.out, ";")?;
                    } else {
                        // Otherwise call a wrapped setter helper: whole matrix, one column,
                        // or one element, depending on which accesses were found above.
                        let access = WrappedStructMatrixAccess { ty, index };
                        match (&vector, &scalar) {
                            (&Some(_), &Some(_)) => {
                                self.write_wrapped_struct_matrix_set_scalar_function_name(
                                    access,
                                )?;
                            }
                            (&Some(_), &None) => {
                                self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
                            }
                            (&None, _) => {
                                self.write_wrapped_struct_matrix_set_function_name(access)?;
                            }
                        }

                        write!(self.out, "(")?;
                        self.write_expr(module, base, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, value, func_ctx)?;

                        if let Some(Index::Expression(vec_index)) = vector {
                            write!(self.out, ", ")?;
                            self.write_expr(module, vec_index, func_ctx)?;

                            if let Some(scalar_index) = scalar {
                                write!(self.out, ", ")?;
                                match scalar_index {
                                    Index::Static(index) => {
                                        write!(self.out, "{index}")?;
                                    }
                                    Index::Expression(index) => {
                                        self.write_expr(module, index, func_ctx)?;
                                    }
                                }
                            }
                        }
                        writeln!(self.out, ");")?;
                    }
                } else {
                    // Not a struct-member `matCx2` store. Check instead for a dynamically
                    // indexed column/element of a `matCx2` (with 4-byte width) found by
                    // `get_inner_matrix_of_struct_array_member`, which is stored through
                    // the `__set_col_of_matCx2` / `__set_el_of_matCx2` helpers.
                    struct MatrixData {
                        columns: crate::VectorSize,
                        base: Handle<crate::Expression>,
                    }

                    enum Index {
                        Expression(Handle<crate::Expression>),
                        Static(u32),
                    }

                    let mut matrix = None;
                    let mut vector = None;
                    let mut scalar = None;

                    let mut current_expr = pointer;
                    for _ in 0..3 {
                        let resolved = func_ctx.resolve_type(current_expr, &module.types);
                        match (resolved, &func_ctx.expressions[current_expr]) {
                            (
                                &TypeInner::ValuePointer {
                                    size: Some(crate::VectorSize::Bi),
                                    ..
                                },
                                &crate::Expression::Access { base, index },
                            ) => {
                                vector = Some(index);
                                current_expr = base;
                            }
                            (
                                &TypeInner::ValuePointer { size: None, .. },
                                &crate::Expression::Access { base, index },
                            ) => {
                                scalar = Some(Index::Expression(index));
                                current_expr = base;
                            }
                            (
                                &TypeInner::ValuePointer { size: None, .. },
                                &crate::Expression::AccessIndex { base, index },
                            ) => {
                                scalar = Some(Index::Static(index));
                                current_expr = base;
                            }
                            _ => {
                                if let Some(MatrixType {
                                    columns,
                                    rows: crate::VectorSize::Bi,
                                    width: 4,
                                }) = get_inner_matrix_of_struct_array_member(
                                    module,
                                    current_expr,
                                    func_ctx,
                                    true,
                                ) {
                                    matrix = Some(MatrixData {
                                        columns,
                                        base: current_expr,
                                    });
                                }

                                break;
                            }
                        }
                    }

                    if let (Some(MatrixData { columns, base }), Some(vec_index)) =
                        (matrix, vector)
                    {
                        if scalar.is_some() {
                            write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
                        } else {
                            write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
                        }
                        write!(self.out, "(")?;
                        self.write_expr(module, base, func_ctx)?;
                        write!(self.out, ", ")?;
                        self.write_expr(module, vec_index, func_ctx)?;

                        if let Some(scalar_index) = scalar {
                            write!(self.out, ", ")?;
                            match scalar_index {
                                Index::Static(index) => {
                                    write!(self.out, "{index}")?;
                                }
                                Index::Expression(index) => {
                                    self.write_expr(module, index, func_ctx)?;
                                }
                            }
                        }

                        write!(self.out, ", ")?;
                        self.write_expr(module, value, func_ctx)?;

                        writeln!(self.out, ");")?;
                    } else {
                        // Plain assignment. Whole `matCx2` values (or arrays of them) found by
                        // `get_inner_matrix_of_struct_array_member` get a `(__matCx2[...])`
                        // cast in front of the value.
                        self.write_expr(module, pointer, func_ctx)?;
                        write!(self.out, " = ")?;

                        if let Some(MatrixType {
                            columns,
                            rows: crate::VectorSize::Bi,
                            width: 4,
                        }) = get_inner_matrix_of_struct_array_member(
                            module, pointer, func_ctx, false,
                        ) {
                            let mut resolved = func_ctx.resolve_type(pointer, &module.types);
                            if let TypeInner::Pointer { base, .. } = *resolved {
                                resolved = &module.types[base].inner;
                            }

                            write!(self.out, "(__mat{}x2", columns as u8)?;
                            if let TypeInner::Array { base, size, .. } = *resolved {
                                self.write_array_size(module, base, size)?;
                            }
                            write!(self.out, ")")?;
                        }

                        self.write_expr(module, value, func_ctx)?;
                        writeln!(self.out, ";")?
                    }
                }
            }
        }
        Statement::Loop {
            ref body,
            ref continuing,
            break_if,
        } => {
            let l2 = level.next();
            if !continuing.is_empty() || break_if.is_some() {
                // The `continuing` block and `break_if` check run at the top of every
                // iteration except the first. A `loop_init` gate flag skips them on entry.
                // As a result, a `continue` in `body` still reaches the continuing code.
                let gate_name = self.namer.call("loop_init");
                writeln!(self.out, "{level}bool {gate_name} = true;")?;
                writeln!(self.out, "{level}while(true) {{")?;
                writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
                let l3 = l2.next();
                for sta in continuing.iter() {
                    self.write_stmt(module, sta, func_ctx, l3)?;
                }
                if let Some(condition) = break_if {
                    write!(self.out, "{l3}if (")?;
                    self.write_expr(module, condition, func_ctx)?;
                    writeln!(self.out, ") {{")?;
                    writeln!(self.out, "{}break;", l3.next())?;
                    writeln!(self.out, "{l3}}}")?;
                }
                writeln!(self.out, "{l2}}}")?;
                writeln!(self.out, "{l2}{gate_name} = false;")?;
            } else {
                writeln!(self.out, "{level}while(true) {{")?;
            }

            for sta in body.iter() {
                self.write_stmt(module, sta, func_ctx, l2)?;
            }
            writeln!(self.out, "{level}}}")?
        }
        Statement::Break => writeln!(self.out, "{level}break;")?,
        Statement::Continue => writeln!(self.out, "{level}continue;")?,
        Statement::Barrier(barrier) => {
            self.write_barrier(barrier, level)?;
        }
        Statement::ImageStore {
            image,
            coordinate,
            array_index,
            value,
        } => {
            write!(self.out, "{level}")?;
            self.write_expr(module, image, func_ctx)?;

            write!(self.out, "[")?;
            if let Some(index) = array_index {
                // The array layer is appended to the coordinate as a single `int3`.
                // NOTE(review): this assumes a 2D coordinate for arrayed storage images.
                write!(self.out, "int3(")?;
                self.write_expr(module, coordinate, func_ctx)?;
                write!(self.out, ", ")?;
                self.write_expr(module, index, func_ctx)?;
                write!(self.out, ")")?;
            } else {
                self.write_expr(module, coordinate, func_ctx)?;
            }
            write!(self.out, "]")?;

            write!(self.out, " = ")?;
            self.write_expr(module, value, func_ctx)?;
            writeln!(self.out, ";")?;
        }
        Statement::Call {
            function,
            ref arguments,
            result,
        } => {
            write!(self.out, "{level}")?;
            if let Some(expr) = result {
                // Bind the call result to a `const` so later uses refer to it by name.
                write!(self.out, "const ")?;
                let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
                let expr_ty = &func_ctx.info[expr].ty;
                match *expr_ty {
                    proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
                    proc::TypeResolution::Value(ref value) => {
                        self.write_value_type(module, value)?
                    }
                };
                write!(self.out, " {name} = ")?;
                self.named_expressions.insert(expr, name);
            }
            let func_name = &self.names[&NameKey::Function(function)];
            write!(self.out, "{func_name}(")?;
            for (index, argument) in arguments.iter().enumerate() {
                if index != 0 {
                    write!(self.out, ", ")?;
                }
                self.write_expr(module, *argument, func_ctx)?;
            }
            writeln!(self.out, ");")?
        }
        Statement::Atomic {
            pointer,
            ref fun,
            value,
            result,
        } => {
            // The result variable is declared first (uninitialized) and later passed as the
            // trailing argument of the `Interlocked*` call.
            write!(self.out, "{level}")?;
            let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
            match func_ctx.info[result].ty {
                proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
                proc::TypeResolution::Value(ref value) => {
                    self.write_value_type(module, value)?
                }
            };

            // Validation is expected to guarantee that `pointer` has a pointer type.
            let pointer_space = func_ctx
                .resolve_type(pointer, &module.types)
                .pointer_space()
                .unwrap();

            let fun_str = fun.to_hlsl_suffix();
            write!(self.out, " {res_name}; ")?;
            match pointer_space {
                crate::AddressSpace::WorkGroup => {
                    write!(self.out, "Interlocked{fun_str}(")?;
                    self.write_expr(module, pointer, func_ctx)?;
                }
                crate::AddressSpace::Storage { .. } => {
                    let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
                    // `write_storage_address` needs `&mut self`, so the reusable access-chain
                    // buffer is temporarily taken out of `self` and put back afterwards.
                    let chain = mem::take(&mut self.temp_access_chain);
                    let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                    write!(self.out, "{var_name}.Interlocked{fun_str}(")?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    self.temp_access_chain = chain;
                }
                ref other => {
                    return Err(Error::Custom(format!(
                        "invalid address space {other:?} for atomic statement"
                    )))
                }
            }
            write!(self.out, ", ")?;
            match *fun {
                crate::AtomicFunction::Subtract => {
                    // Presumably `to_hlsl_suffix` maps `Subtract` to `Add`, so the operand is
                    // negated here (TODO confirm against `to_hlsl_suffix`).
                    write!(self.out, "-")?;
                }
                crate::AtomicFunction::Exchange { compare: Some(_) } => {
                    return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
                }
                _ => {}
            }
            self.write_expr(module, value, func_ctx)?;
            writeln!(self.out, ", {res_name});")?;
            self.named_expressions.insert(result, res_name);
        }
        Statement::WorkGroupUniformLoad { pointer, result } => {
            // The load is bracketed by workgroup barriers on both sides.
            self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
            write!(self.out, "{level}")?;
            let name = format!("_expr{}", result.index());
            self.write_named_expr(module, pointer, name, result, func_ctx)?;

            self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
        }
        Statement::Switch {
            selector,
            ref cases,
        } => {
            write!(self.out, "{level}")?;
            write!(self.out, "switch(")?;
            self.write_expr(module, selector, func_ctx)?;
            writeln!(self.out, ") {{")?;

            let indent_level_1 = level.next();
            let indent_level_2 = indent_level_1.next();

            for (i, case) in cases.iter().enumerate() {
                match case.value {
                    crate::SwitchValue::I32(value) => {
                        write!(self.out, "{indent_level_1}case {value}:")?
                    }
                    crate::SwitchValue::U32(value) => {
                        write!(self.out, "{indent_level_1}case {value}u:")?
                    }
                    crate::SwitchValue::Default => {
                        write!(self.out, "{indent_level_1}default:")?
                    }
                }

                // An empty fall-through case is written as a bare label that shares the next
                // case's body. Every other case gets its own braced block.
                let write_block_braces = !(case.fall_through && case.body.is_empty());
                if write_block_braces {
                    writeln!(self.out, " {{")?;
                } else {
                    writeln!(self.out)?;
                }

                if case.fall_through && !case.body.is_empty() {
                    // A non-empty case that falls through. This case's body, and the bodies
                    // of all following cases up to and including the first one that does not
                    // fall through, are duplicated here, each in its own nested scope. That
                    // way the emitted HLSL never relies on falling through between labels.
                    // The `unwrap` assumes such a terminating case exists (presumably
                    // guaranteed by validation).
                    let curr_len = i + 1;
                    let end_case_idx = curr_len
                        + cases
                            .iter()
                            .skip(curr_len)
                            .position(|case| !case.fall_through)
                            .unwrap();
                    let indent_level_3 = indent_level_2.next();
                    for case in &cases[i..=end_case_idx] {
                        writeln!(self.out, "{indent_level_2}{{")?;
                        let prev_len = self.named_expressions.len();
                        for sta in case.body.iter() {
                            self.write_stmt(module, sta, func_ctx, indent_level_3)?;
                        }
                        // Forget names introduced inside this scope. They are not visible
                        // past the closing brace, and the same body may be emitted again.
                        self.named_expressions.truncate(prev_len);
                        writeln!(self.out, "{indent_level_2}}}")?;
                    }

                    let last_case = &cases[end_case_idx];
                    if last_case.body.last().map_or(true, |s| !s.is_terminator()) {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                } else {
                    for sta in case.body.iter() {
                        self.write_stmt(module, sta, func_ctx, indent_level_2)?;
                    }
                    // Add an explicit `break` unless the case falls through or already ends
                    // in a terminator statement.
                    if !case.fall_through
                        && case.body.last().map_or(true, |s| !s.is_terminator())
                    {
                        writeln!(self.out, "{indent_level_2}break;")?;
                    }
                }

                if write_block_braces {
                    writeln!(self.out, "{indent_level_1}}}")?;
                }
            }

            writeln!(self.out, "{level}}}")?
        }
        // NOTE(review): ray queries are presumably rejected before reaching this backend.
        Statement::RayQuery { .. } => unreachable!(),
        Statement::SubgroupBallot { result, predicate } => {
            write!(self.out, "{level}")?;
            let name = format!("{}{}", back::BAKE_PREFIX, result.index());
            write!(self.out, "const uint4 {name} = ")?;
            self.named_expressions.insert(result, name);

            // A missing predicate ballots `true`.
            write!(self.out, "WaveActiveBallot(")?;
            match predicate {
                Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
                None => write!(self.out, "true")?,
            }
            writeln!(self.out, ");")?;
        }
        Statement::SubgroupCollectiveOperation {
            op,
            collective_op,
            argument,
            result,
        } => {
            write!(self.out, "{level}")?;
            write!(self.out, "const ")?;
            let name = format!("{}{}", back::BAKE_PREFIX, result.index());
            match func_ctx.info[result].ty {
                proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
                proc::TypeResolution::Value(ref value) => {
                    self.write_value_type(module, value)?
                }
            };
            write!(self.out, " {name} = ")?;
            self.named_expressions.insert(result, name);

            match (collective_op, op) {
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
                    write!(self.out, "WaveActiveAllTrue(")?
                }
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
                    write!(self.out, "WaveActiveAnyTrue(")?
                }
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
                    write!(self.out, "WaveActiveSum(")?
                }
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
                    write!(self.out, "WaveActiveProduct(")?
                }
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
                    write!(self.out, "WaveActiveMax(")?
                }
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
                    write!(self.out, "WaveActiveMin(")?
                }
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
                    write!(self.out, "WaveActiveBitAnd(")?
                }
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
                    write!(self.out, "WaveActiveBitOr(")?
                }
                (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
                    write!(self.out, "WaveActiveBitXor(")?
                }
                (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
                    write!(self.out, "WavePrefixSum(")?
                }
                (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
                    write!(self.out, "WavePrefixProduct(")?
                }
                // Inclusive scans combine the invocation's own value with the exclusive prefix.
                (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
                    self.write_expr(module, argument, func_ctx)?;
                    write!(self.out, " + WavePrefixSum(")?;
                }
                (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
                    self.write_expr(module, argument, func_ctx)?;
                    write!(self.out, " * WavePrefixProduct(")?;
                }
                _ => unimplemented!(),
            }
            self.write_expr(module, argument, func_ctx)?;
            writeln!(self.out, ");")?;
        }
        Statement::SubgroupGather {
            mode,
            argument,
            result,
        } => {
            write!(self.out, "{level}")?;
            write!(self.out, "const ")?;
            let name = format!("{}{}", back::BAKE_PREFIX, result.index());
            match func_ctx.info[result].ty {
                proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
                proc::TypeResolution::Value(ref value) => {
                    self.write_value_type(module, value)?
                }
            };
            write!(self.out, " {name} = ")?;
            self.named_expressions.insert(result, name);

            if matches!(mode, crate::GatherMode::BroadcastFirst) {
                write!(self.out, "WaveReadLaneFirst(")?;
                self.write_expr(module, argument, func_ctx)?;
            } else {
                // All other modes read from a single lane. The shuffle variants compute that
                // lane from the current lane index.
                write!(self.out, "WaveReadLaneAt(")?;
                self.write_expr(module, argument, func_ctx)?;
                write!(self.out, ", ")?;
                match mode {
                    crate::GatherMode::BroadcastFirst => unreachable!(),
                    crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
                        self.write_expr(module, index, func_ctx)?;
                    }
                    crate::GatherMode::ShuffleDown(index) => {
                        write!(self.out, "WaveGetLaneIndex() + ")?;
                        self.write_expr(module, index, func_ctx)?;
                    }
                    crate::GatherMode::ShuffleUp(index) => {
                        write!(self.out, "WaveGetLaneIndex() - ")?;
                        self.write_expr(module, index, func_ctx)?;
                    }
                    crate::GatherMode::ShuffleXor(index) => {
                        write!(self.out, "WaveGetLaneIndex() ^ ")?;
                        self.write_expr(module, index, func_ctx)?;
                    }
                }
            }
            writeln!(self.out, ");")?;
        }
    }

    Ok(())
}
2207
2208 fn write_const_expression(
2209 &mut self,
2210 module: &Module,
2211 expr: Handle<crate::Expression>,
2212 ) -> BackendResult {
2213 self.write_possibly_const_expression(
2214 module,
2215 expr,
2216 &module.global_expressions,
2217 |writer, expr| writer.write_const_expression(module, expr),
2218 )
2219 }
2220
2221 fn write_possibly_const_expression<E>(
2222 &mut self,
2223 module: &Module,
2224 expr: Handle<crate::Expression>,
2225 expressions: &crate::Arena<crate::Expression>,
2226 write_expression: E,
2227 ) -> BackendResult
2228 where
2229 E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
2230 {
2231 use crate::Expression;
2232
2233 match expressions[expr] {
2234 Expression::Literal(literal) => match literal {
2235 crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2238 crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2239 crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
2240 crate::Literal::I32(value) => write!(self.out, "{}", value)?,
2241 crate::Literal::U64(value) => write!(self.out, "{}uL", value)?,
2242 crate::Literal::I64(value) => write!(self.out, "{}L", value)?,
2243 crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
2244 crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2245 return Err(Error::Custom(
2246 "Abstract types should not appear in IR presented to backends".into(),
2247 ));
2248 }
2249 },
2250 Expression::Constant(handle) => {
2251 let constant = &module.constants[handle];
2252 if constant.name.is_some() {
2253 write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
2254 } else {
2255 self.write_const_expression(module, constant.init)?;
2256 }
2257 }
2258 Expression::ZeroValue(ty) => {
2259 self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
2260 write!(self.out, "()")?;
2261 }
2262 Expression::Compose { ty, ref components } => {
2263 match module.types[ty].inner {
2264 TypeInner::Struct { .. } | TypeInner::Array { .. } => {
2265 self.write_wrapped_constructor_function_name(
2266 module,
2267 WrappedConstructor { ty },
2268 )?;
2269 }
2270 _ => {
2271 self.write_type(module, ty)?;
2272 }
2273 };
2274 write!(self.out, "(")?;
2275 for (index, component) in components.iter().enumerate() {
2276 if index != 0 {
2277 write!(self.out, ", ")?;
2278 }
2279 write_expression(self, *component)?;
2280 }
2281 write!(self.out, ")")?;
2282 }
2283 Expression::Splat { size, value } => {
2284 let number_of_components = match size {
2288 crate::VectorSize::Bi => "xx",
2289 crate::VectorSize::Tri => "xxx",
2290 crate::VectorSize::Quad => "xxxx",
2291 };
2292 write!(self.out, "(")?;
2293 write_expression(self, value)?;
2294 write!(self.out, ").{number_of_components}")?
2295 }
2296 _ => unreachable!(),
2297 }
2298
2299 Ok(())
2300 }
2301
2302 pub(super) fn write_expr(
2307 &mut self,
2308 module: &Module,
2309 expr: Handle<crate::Expression>,
2310 func_ctx: &back::FunctionCtx<'_>,
2311 ) -> BackendResult {
2312 use crate::Expression;
2313
2314 let ff_input = if self.options.special_constants_binding.is_some() {
2316 func_ctx.is_fixed_function_input(expr, module)
2317 } else {
2318 None
2319 };
2320 let closing_bracket = match ff_input {
2321 Some(crate::BuiltIn::VertexIndex) => {
2322 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
2323 ")"
2324 }
2325 Some(crate::BuiltIn::InstanceIndex) => {
2326 write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
2327 ")"
2328 }
2329 Some(crate::BuiltIn::NumWorkGroups) => {
2330 write!(
2334 self.out,
2335 "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
2336 )?;
2337 return Ok(());
2338 }
2339 _ => "",
2340 };
2341
2342 if let Some(name) = self.named_expressions.get(&expr) {
2343 write!(self.out, "{name}{closing_bracket}")?;
2344 return Ok(());
2345 }
2346
2347 let expression = &func_ctx.expressions[expr];
2348
2349 match *expression {
2350 Expression::Literal(_)
2351 | Expression::Constant(_)
2352 | Expression::ZeroValue(_)
2353 | Expression::Compose { .. }
2354 | Expression::Splat { .. } => {
2355 self.write_possibly_const_expression(
2356 module,
2357 expr,
2358 func_ctx.expressions,
2359 |writer, expr| writer.write_expr(module, expr, func_ctx),
2360 )?;
2361 }
2362 Expression::Override(_) => return Err(Error::Override),
2363 Expression::Binary {
2366 op: crate::BinaryOperator::Multiply,
2367 left,
2368 right,
2369 } if func_ctx.resolve_type(left, &module.types).is_matrix()
2370 || func_ctx.resolve_type(right, &module.types).is_matrix() =>
2371 {
2372 write!(self.out, "mul(")?;
2374 self.write_expr(module, right, func_ctx)?;
2375 write!(self.out, ", ")?;
2376 self.write_expr(module, left, func_ctx)?;
2377 write!(self.out, ")")?;
2378 }
2379
2380 Expression::Binary {
2396 op: crate::BinaryOperator::Modulo,
2397 left,
2398 right,
2399 } if func_ctx.resolve_type(left, &module.types).scalar_kind()
2400 == Some(crate::ScalarKind::Float) =>
2401 {
2402 write!(self.out, "fmod(")?;
2403 self.write_expr(module, left, func_ctx)?;
2404 write!(self.out, ", ")?;
2405 self.write_expr(module, right, func_ctx)?;
2406 write!(self.out, ")")?;
2407 }
2408 Expression::Binary { op, left, right } => {
2409 write!(self.out, "(")?;
2410 self.write_expr(module, left, func_ctx)?;
2411 write!(self.out, " {} ", crate::back::binary_operation_str(op))?;
2412 self.write_expr(module, right, func_ctx)?;
2413 write!(self.out, ")")?;
2414 }
2415 Expression::Access { base, index } => {
2416 if let Some(crate::AddressSpace::Storage { .. }) =
2417 func_ctx.resolve_type(expr, &module.types).pointer_space()
2418 {
2419 } else {
2421 if let Some(MatrixType {
2428 columns,
2429 rows: crate::VectorSize::Bi,
2430 width: 4,
2431 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
2432 {
2433 write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
2434 self.write_expr(module, base, func_ctx)?;
2435 write!(self.out, ", ")?;
2436 self.write_expr(module, index, func_ctx)?;
2437 write!(self.out, ")")?;
2438 return Ok(());
2439 }
2440
2441 let resolved = func_ctx.resolve_type(base, &module.types);
2442
2443 let non_uniform_qualifier = match *resolved {
2444 TypeInner::BindingArray { .. } => {
2445 let uniformity = &func_ctx.info[index].uniformity;
2446
2447 uniformity.non_uniform_result.is_some()
2448 }
2449 _ => false,
2450 };
2451
2452 self.write_expr(module, base, func_ctx)?;
2453 write!(self.out, "[")?;
2454 if non_uniform_qualifier {
2455 write!(self.out, "NonUniformResourceIndex(")?;
2456 }
2457 self.write_expr(module, index, func_ctx)?;
2458 if non_uniform_qualifier {
2459 write!(self.out, ")")?;
2460 }
2461 write!(self.out, "]")?;
2462 }
2463 }
2464 Expression::AccessIndex { base, index } => {
2465 if let Some(crate::AddressSpace::Storage { .. }) =
2466 func_ctx.resolve_type(expr, &module.types).pointer_space()
2467 {
2468 } else {
2470 fn write_access<W: fmt::Write>(
2471 writer: &mut super::Writer<'_, W>,
2472 resolved: &TypeInner,
2473 base_ty_handle: Option<Handle<crate::Type>>,
2474 index: u32,
2475 ) -> BackendResult {
2476 match *resolved {
2477 TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
2483 write!(writer.out, ".{}", back::COMPONENTS[index as usize])?
2485 }
2486 TypeInner::Matrix { .. }
2487 | TypeInner::Array { .. }
2488 | TypeInner::BindingArray { .. } => write!(writer.out, "[{index}]")?,
2489 TypeInner::Struct { .. } => {
2490 let ty = base_ty_handle.unwrap();
2493
2494 write!(
2495 writer.out,
2496 ".{}",
2497 &writer.names[&NameKey::StructMember(ty, index)]
2498 )?
2499 }
2500 ref other => {
2501 return Err(Error::Custom(format!("Cannot index {other:?}")))
2502 }
2503 }
2504 Ok(())
2505 }
2506
2507 if let Some(MatrixType {
2510 rows: crate::VectorSize::Bi,
2511 width: 4,
2512 ..
2513 }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
2514 {
2515 self.write_expr(module, base, func_ctx)?;
2516 write!(self.out, "._{index}")?;
2517 return Ok(());
2518 }
2519
2520 let base_ty_res = &func_ctx.info[base].ty;
2521 let mut resolved = base_ty_res.inner_with(&module.types);
2522 let base_ty_handle = match *resolved {
2523 TypeInner::Pointer { base, .. } => {
2524 resolved = &module.types[base].inner;
2525 Some(base)
2526 }
2527 _ => base_ty_res.handle(),
2528 };
2529
2530 if let TypeInner::Struct { ref members, .. } = *resolved {
2536 let member = &members[index as usize];
2537
2538 match module.types[member.ty].inner {
2539 TypeInner::Matrix {
2540 rows: crate::VectorSize::Bi,
2541 ..
2542 } if member.binding.is_none() => {
2543 let ty = base_ty_handle.unwrap();
2544 self.write_wrapped_struct_matrix_get_function_name(
2545 WrappedStructMatrixAccess { ty, index },
2546 )?;
2547 write!(self.out, "(")?;
2548 self.write_expr(module, base, func_ctx)?;
2549 write!(self.out, ")")?;
2550 return Ok(());
2551 }
2552 _ => {}
2553 }
2554 }
2555
2556 self.write_expr(module, base, func_ctx)?;
2557 write_access(self, resolved, base_ty_handle, index)?;
2558 }
2559 }
2560 Expression::FunctionArgument(pos) => {
2561 let key = func_ctx.argument_key(pos);
2562 let name = &self.names[&key];
2563 write!(self.out, "{name}")?;
2564 }
2565 Expression::ImageSample {
2566 image,
2567 sampler,
2568 gather,
2569 coordinate,
2570 array_index,
2571 offset,
2572 level,
2573 depth_ref,
2574 } => {
2575 use crate::SampleLevel as Sl;
2576 const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
2577
2578 let (base_str, component_str) = match gather {
2579 Some(component) => ("Gather", COMPONENTS[component as usize]),
2580 None => ("Sample", ""),
2581 };
2582 let cmp_str = match depth_ref {
2583 Some(_) => "Cmp",
2584 None => "",
2585 };
2586 let level_str = match level {
2587 Sl::Zero if gather.is_none() => "LevelZero",
2588 Sl::Auto | Sl::Zero => "",
2589 Sl::Exact(_) => "Level",
2590 Sl::Bias(_) => "Bias",
2591 Sl::Gradient { .. } => "Grad",
2592 };
2593
2594 self.write_expr(module, image, func_ctx)?;
2595 write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
2596 self.write_expr(module, sampler, func_ctx)?;
2597 write!(self.out, ", ")?;
2598 self.write_texture_coordinates(
2599 "float",
2600 coordinate,
2601 array_index,
2602 None,
2603 module,
2604 func_ctx,
2605 )?;
2606
2607 if let Some(depth_ref) = depth_ref {
2608 write!(self.out, ", ")?;
2609 self.write_expr(module, depth_ref, func_ctx)?;
2610 }
2611
2612 match level {
2613 Sl::Auto | Sl::Zero => {}
2614 Sl::Exact(expr) => {
2615 write!(self.out, ", ")?;
2616 self.write_expr(module, expr, func_ctx)?;
2617 }
2618 Sl::Bias(expr) => {
2619 write!(self.out, ", ")?;
2620 self.write_expr(module, expr, func_ctx)?;
2621 }
2622 Sl::Gradient { x, y } => {
2623 write!(self.out, ", ")?;
2624 self.write_expr(module, x, func_ctx)?;
2625 write!(self.out, ", ")?;
2626 self.write_expr(module, y, func_ctx)?;
2627 }
2628 }
2629
2630 if let Some(offset) = offset {
2631 write!(self.out, ", ")?;
2632 write!(self.out, "int2(")?; self.write_const_expression(module, offset)?;
2634 write!(self.out, ")")?;
2635 }
2636
2637 write!(self.out, ")")?;
2638 }
2639 Expression::ImageQuery { image, query } => {
2640 if let TypeInner::Image {
2642 dim,
2643 arrayed,
2644 class,
2645 } = *func_ctx.resolve_type(image, &module.types)
2646 {
2647 let wrapped_image_query = WrappedImageQuery {
2648 dim,
2649 arrayed,
2650 class,
2651 query: query.into(),
2652 };
2653
2654 self.write_wrapped_image_query_function_name(wrapped_image_query)?;
2655 write!(self.out, "(")?;
2656 self.write_expr(module, image, func_ctx)?;
2658 if let crate::ImageQuery::Size { level: Some(level) } = query {
2659 write!(self.out, ", ")?;
2660 self.write_expr(module, level, func_ctx)?;
2661 }
2662 write!(self.out, ")")?;
2663 }
2664 }
2665 Expression::ImageLoad {
2666 image,
2667 coordinate,
2668 array_index,
2669 sample,
2670 level,
2671 } => {
2672 self.write_expr(module, image, func_ctx)?;
2674 write!(self.out, ".Load(")?;
2675
2676 self.write_texture_coordinates(
2677 "int",
2678 coordinate,
2679 array_index,
2680 level,
2681 module,
2682 func_ctx,
2683 )?;
2684
2685 if let Some(sample) = sample {
2686 write!(self.out, ", ")?;
2687 self.write_expr(module, sample, func_ctx)?;
2688 }
2689
2690 write!(self.out, ")")?;
2692
2693 if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
2695 write!(self.out, ".x")?;
2696 }
2697 }
2698 Expression::GlobalVariable(handle) => match module.global_variables[handle].space {
2699 crate::AddressSpace::Storage { .. } => {}
2700 _ => {
2701 let name = &self.names[&NameKey::GlobalVariable(handle)];
2702 write!(self.out, "{name}")?;
2703 }
2704 },
2705 Expression::LocalVariable(handle) => {
2706 write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
2707 }
2708 Expression::Load { pointer } => {
2709 match func_ctx
2710 .resolve_type(pointer, &module.types)
2711 .pointer_space()
2712 {
2713 Some(crate::AddressSpace::Storage { .. }) => {
2714 let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2715 let result_ty = func_ctx.info[expr].ty.clone();
2716 self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
2717 }
2718 _ => {
2719 let mut close_paren = false;
2720
2721 if let Some(MatrixType {
2726 rows: crate::VectorSize::Bi,
2727 width: 4,
2728 ..
2729 }) = get_inner_matrix_of_struct_array_member(
2730 module, pointer, func_ctx, false,
2731 )
2732 .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
2733 {
2734 let mut resolved = func_ctx.resolve_type(pointer, &module.types);
2735 if let TypeInner::Pointer { base, .. } = *resolved {
2736 resolved = &module.types[base].inner;
2737 }
2738
2739 write!(self.out, "((")?;
2740 if let TypeInner::Array { base, size, .. } = *resolved {
2741 self.write_type(module, base)?;
2742 self.write_array_size(module, base, size)?;
2743 } else {
2744 self.write_value_type(module, resolved)?;
2745 }
2746 write!(self.out, ")")?;
2747 close_paren = true;
2748 }
2749
2750 self.write_expr(module, pointer, func_ctx)?;
2751
2752 if close_paren {
2753 write!(self.out, ")")?;
2754 }
2755 }
2756 }
2757 }
2758 Expression::Unary { op, expr } => {
2759 let op_str = match op {
2761 crate::UnaryOperator::Negate => "-",
2762 crate::UnaryOperator::LogicalNot => "!",
2763 crate::UnaryOperator::BitwiseNot => "~",
2764 };
2765 write!(self.out, "{op_str}(")?;
2766 self.write_expr(module, expr, func_ctx)?;
2767 write!(self.out, ")")?;
2768 }
2769 Expression::As {
2770 expr,
2771 kind,
2772 convert,
2773 } => {
2774 let inner = func_ctx.resolve_type(expr, &module.types);
2775 let close_paren = match convert {
2776 Some(dst_width) => {
2777 let scalar = crate::Scalar {
2778 kind,
2779 width: dst_width,
2780 };
2781 match *inner {
2782 TypeInner::Vector { size, .. } => {
2783 write!(
2784 self.out,
2785 "{}{}(",
2786 scalar.to_hlsl_str()?,
2787 back::vector_size_str(size)
2788 )?;
2789 }
2790 TypeInner::Scalar(_) => {
2791 write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
2792 }
2793 TypeInner::Matrix { columns, rows, .. } => {
2794 write!(
2795 self.out,
2796 "{}{}x{}(",
2797 scalar.to_hlsl_str()?,
2798 back::vector_size_str(columns),
2799 back::vector_size_str(rows)
2800 )?;
2801 }
2802 _ => {
2803 return Err(Error::Unimplemented(format!(
2804 "write_expr expression::as {inner:?}"
2805 )));
2806 }
2807 };
2808 true
2809 }
2810 None => {
2811 if inner.scalar_width() == Some(8) {
2812 false
2813 } else {
2814 write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
2815 true
2816 }
2817 }
2818 };
2819 self.write_expr(module, expr, func_ctx)?;
2820 if close_paren {
2821 write!(self.out, ")")?;
2822 }
2823 }
2824 Expression::Math {
2825 fun,
2826 arg,
2827 arg1,
2828 arg2,
2829 arg3,
2830 } => {
2831 use crate::MathFunction as Mf;
2832
2833 enum Function {
2834 Asincosh { is_sin: bool },
2835 Atanh,
2836 Pack2x16float,
2837 Pack2x16snorm,
2838 Pack2x16unorm,
2839 Pack4x8snorm,
2840 Pack4x8unorm,
2841 Unpack2x16float,
2842 Unpack2x16snorm,
2843 Unpack2x16unorm,
2844 Unpack4x8snorm,
2845 Unpack4x8unorm,
2846 Regular(&'static str),
2847 MissingIntOverload(&'static str),
2848 MissingIntReturnType(&'static str),
2849 CountTrailingZeros,
2850 CountLeadingZeros,
2851 }
2852
2853 let fun = match fun {
2854 Mf::Abs => Function::Regular("abs"),
2856 Mf::Min => Function::Regular("min"),
2857 Mf::Max => Function::Regular("max"),
2858 Mf::Clamp => Function::Regular("clamp"),
2859 Mf::Saturate => Function::Regular("saturate"),
2860 Mf::Cos => Function::Regular("cos"),
2862 Mf::Cosh => Function::Regular("cosh"),
2863 Mf::Sin => Function::Regular("sin"),
2864 Mf::Sinh => Function::Regular("sinh"),
2865 Mf::Tan => Function::Regular("tan"),
2866 Mf::Tanh => Function::Regular("tanh"),
2867 Mf::Acos => Function::Regular("acos"),
2868 Mf::Asin => Function::Regular("asin"),
2869 Mf::Atan => Function::Regular("atan"),
2870 Mf::Atan2 => Function::Regular("atan2"),
2871 Mf::Asinh => Function::Asincosh { is_sin: true },
2872 Mf::Acosh => Function::Asincosh { is_sin: false },
2873 Mf::Atanh => Function::Atanh,
2874 Mf::Radians => Function::Regular("radians"),
2875 Mf::Degrees => Function::Regular("degrees"),
2876 Mf::Ceil => Function::Regular("ceil"),
2878 Mf::Floor => Function::Regular("floor"),
2879 Mf::Round => Function::Regular("round"),
2880 Mf::Fract => Function::Regular("frac"),
2881 Mf::Trunc => Function::Regular("trunc"),
2882 Mf::Modf => Function::Regular(MODF_FUNCTION),
2883 Mf::Frexp => Function::Regular(FREXP_FUNCTION),
2884 Mf::Ldexp => Function::Regular("ldexp"),
2885 Mf::Exp => Function::Regular("exp"),
2887 Mf::Exp2 => Function::Regular("exp2"),
2888 Mf::Log => Function::Regular("log"),
2889 Mf::Log2 => Function::Regular("log2"),
2890 Mf::Pow => Function::Regular("pow"),
2891 Mf::Dot => Function::Regular("dot"),
2893 Mf::Cross => Function::Regular("cross"),
2895 Mf::Distance => Function::Regular("distance"),
2896 Mf::Length => Function::Regular("length"),
2897 Mf::Normalize => Function::Regular("normalize"),
2898 Mf::FaceForward => Function::Regular("faceforward"),
2899 Mf::Reflect => Function::Regular("reflect"),
2900 Mf::Refract => Function::Regular("refract"),
2901 Mf::Sign => Function::Regular("sign"),
2903 Mf::Fma => Function::Regular("mad"),
2904 Mf::Mix => Function::Regular("lerp"),
2905 Mf::Step => Function::Regular("step"),
2906 Mf::SmoothStep => Function::Regular("smoothstep"),
2907 Mf::Sqrt => Function::Regular("sqrt"),
2908 Mf::InverseSqrt => Function::Regular("rsqrt"),
2909 Mf::Transpose => Function::Regular("transpose"),
2911 Mf::Determinant => Function::Regular("determinant"),
2912 Mf::CountTrailingZeros => Function::CountTrailingZeros,
2914 Mf::CountLeadingZeros => Function::CountLeadingZeros,
2915 Mf::CountOneBits => Function::MissingIntOverload("countbits"),
2916 Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
2917 Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"),
2918 Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"),
2919 Mf::ExtractBits => Function::Regular(EXTRACT_BITS_FUNCTION),
2920 Mf::InsertBits => Function::Regular(INSERT_BITS_FUNCTION),
2921 Mf::Pack2x16float => Function::Pack2x16float,
2923 Mf::Pack2x16snorm => Function::Pack2x16snorm,
2924 Mf::Pack2x16unorm => Function::Pack2x16unorm,
2925 Mf::Pack4x8snorm => Function::Pack4x8snorm,
2926 Mf::Pack4x8unorm => Function::Pack4x8unorm,
2927 Mf::Unpack2x16float => Function::Unpack2x16float,
2929 Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
2930 Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
2931 Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
2932 Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
2933 _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
2934 };
2935
2936 match fun {
2937 Function::Asincosh { is_sin } => {
2938 write!(self.out, "log(")?;
2939 self.write_expr(module, arg, func_ctx)?;
2940 write!(self.out, " + sqrt(")?;
2941 self.write_expr(module, arg, func_ctx)?;
2942 write!(self.out, " * ")?;
2943 self.write_expr(module, arg, func_ctx)?;
2944 match is_sin {
2945 true => write!(self.out, " + 1.0))")?,
2946 false => write!(self.out, " - 1.0))")?,
2947 }
2948 }
2949 Function::Atanh => {
2950 write!(self.out, "0.5 * log((1.0 + ")?;
2951 self.write_expr(module, arg, func_ctx)?;
2952 write!(self.out, ") / (1.0 - ")?;
2953 self.write_expr(module, arg, func_ctx)?;
2954 write!(self.out, "))")?;
2955 }
2956 Function::Pack2x16float => {
2957 write!(self.out, "(f32tof16(")?;
2958 self.write_expr(module, arg, func_ctx)?;
2959 write!(self.out, "[0]) | f32tof16(")?;
2960 self.write_expr(module, arg, func_ctx)?;
2961 write!(self.out, "[1]) << 16)")?;
2962 }
2963 Function::Pack2x16snorm => {
2964 let scale = 32767;
2965
2966 write!(self.out, "uint((int(round(clamp(")?;
2967 self.write_expr(module, arg, func_ctx)?;
2968 write!(
2969 self.out,
2970 "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
2971 )?;
2972 self.write_expr(module, arg, func_ctx)?;
2973 write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
2974 }
2975 Function::Pack2x16unorm => {
2976 let scale = 65535;
2977
2978 write!(self.out, "(uint(round(clamp(")?;
2979 self.write_expr(module, arg, func_ctx)?;
2980 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
2981 self.write_expr(module, arg, func_ctx)?;
2982 write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
2983 }
2984 Function::Pack4x8snorm => {
2985 let scale = 127;
2986
2987 write!(self.out, "uint((int(round(clamp(")?;
2988 self.write_expr(module, arg, func_ctx)?;
2989 write!(
2990 self.out,
2991 "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
2992 )?;
2993 self.write_expr(module, arg, func_ctx)?;
2994 write!(
2995 self.out,
2996 "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
2997 )?;
2998 self.write_expr(module, arg, func_ctx)?;
2999 write!(
3000 self.out,
3001 "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
3002 )?;
3003 self.write_expr(module, arg, func_ctx)?;
3004 write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
3005 }
3006 Function::Pack4x8unorm => {
3007 let scale = 255;
3008
3009 write!(self.out, "(uint(round(clamp(")?;
3010 self.write_expr(module, arg, func_ctx)?;
3011 write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
3012 self.write_expr(module, arg, func_ctx)?;
3013 write!(
3014 self.out,
3015 "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
3016 )?;
3017 self.write_expr(module, arg, func_ctx)?;
3018 write!(
3019 self.out,
3020 "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
3021 )?;
3022 self.write_expr(module, arg, func_ctx)?;
3023 write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
3024 }
3025
3026 Function::Unpack2x16float => {
3027 write!(self.out, "float2(f16tof32(")?;
3028 self.write_expr(module, arg, func_ctx)?;
3029 write!(self.out, "), f16tof32((")?;
3030 self.write_expr(module, arg, func_ctx)?;
3031 write!(self.out, ") >> 16))")?;
3032 }
3033 Function::Unpack2x16snorm => {
3034 let scale = 32767;
3035
3036 write!(self.out, "(float2(int2(")?;
3037 self.write_expr(module, arg, func_ctx)?;
3038 write!(self.out, " << 16, ")?;
3039 self.write_expr(module, arg, func_ctx)?;
3040 write!(self.out, ") >> 16) / {scale}.0)")?;
3041 }
3042 Function::Unpack2x16unorm => {
3043 let scale = 65535;
3044
3045 write!(self.out, "(float2(")?;
3046 self.write_expr(module, arg, func_ctx)?;
3047 write!(self.out, " & 0xFFFF, ")?;
3048 self.write_expr(module, arg, func_ctx)?;
3049 write!(self.out, " >> 16) / {scale}.0)")?;
3050 }
3051 Function::Unpack4x8snorm => {
3052 let scale = 127;
3053
3054 write!(self.out, "(float4(int4(")?;
3055 self.write_expr(module, arg, func_ctx)?;
3056 write!(self.out, " << 24, ")?;
3057 self.write_expr(module, arg, func_ctx)?;
3058 write!(self.out, " << 16, ")?;
3059 self.write_expr(module, arg, func_ctx)?;
3060 write!(self.out, " << 8, ")?;
3061 self.write_expr(module, arg, func_ctx)?;
3062 write!(self.out, ") >> 24) / {scale}.0)")?;
3063 }
3064 Function::Unpack4x8unorm => {
3065 let scale = 255;
3066
3067 write!(self.out, "(float4(")?;
3068 self.write_expr(module, arg, func_ctx)?;
3069 write!(self.out, " & 0xFF, ")?;
3070 self.write_expr(module, arg, func_ctx)?;
3071 write!(self.out, " >> 8 & 0xFF, ")?;
3072 self.write_expr(module, arg, func_ctx)?;
3073 write!(self.out, " >> 16 & 0xFF, ")?;
3074 self.write_expr(module, arg, func_ctx)?;
3075 write!(self.out, " >> 24) / {scale}.0)")?;
3076 }
3077 Function::Regular(fun_name) => {
3078 write!(self.out, "{fun_name}(")?;
3079 self.write_expr(module, arg, func_ctx)?;
3080 if let Some(arg) = arg1 {
3081 write!(self.out, ", ")?;
3082 self.write_expr(module, arg, func_ctx)?;
3083 }
3084 if let Some(arg) = arg2 {
3085 write!(self.out, ", ")?;
3086 self.write_expr(module, arg, func_ctx)?;
3087 }
3088 if let Some(arg) = arg3 {
3089 write!(self.out, ", ")?;
3090 self.write_expr(module, arg, func_ctx)?;
3091 }
3092 write!(self.out, ")")?
3093 }
3094 Function::MissingIntOverload(fun_name) => {
3097 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3098 if let Some(crate::Scalar {
3099 kind: ScalarKind::Sint,
3100 width: 4,
3101 }) = scalar_kind
3102 {
3103 write!(self.out, "asint({fun_name}(asuint(")?;
3104 self.write_expr(module, arg, func_ctx)?;
3105 write!(self.out, ")))")?;
3106 } else {
3107 write!(self.out, "{fun_name}(")?;
3108 self.write_expr(module, arg, func_ctx)?;
3109 write!(self.out, ")")?;
3110 }
3111 }
3112 Function::MissingIntReturnType(fun_name) => {
3115 let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
3116 if let Some(crate::Scalar {
3117 kind: ScalarKind::Sint,
3118 width: 4,
3119 }) = scalar_kind
3120 {
3121 write!(self.out, "asint({fun_name}(")?;
3122 self.write_expr(module, arg, func_ctx)?;
3123 write!(self.out, "))")?;
3124 } else {
3125 write!(self.out, "{fun_name}(")?;
3126 self.write_expr(module, arg, func_ctx)?;
3127 write!(self.out, ")")?;
3128 }
3129 }
3130 Function::CountTrailingZeros => {
3131 match *func_ctx.resolve_type(arg, &module.types) {
3132 TypeInner::Vector { size, scalar } => {
3133 let s = match size {
3134 crate::VectorSize::Bi => ".xx",
3135 crate::VectorSize::Tri => ".xxx",
3136 crate::VectorSize::Quad => ".xxxx",
3137 };
3138
3139 let scalar_width_bits = scalar.width * 8;
3140
3141 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3142 write!(
3143 self.out,
3144 "min(({scalar_width_bits}u){s}, firstbitlow("
3145 )?;
3146 self.write_expr(module, arg, func_ctx)?;
3147 write!(self.out, "))")?;
3148 } else {
3149 write!(
3151 self.out,
3152 "asint(min(({scalar_width_bits}u){s}, firstbitlow("
3153 )?;
3154 self.write_expr(module, arg, func_ctx)?;
3155 write!(self.out, ")))")?;
3156 }
3157 }
3158 TypeInner::Scalar(scalar) => {
3159 let scalar_width_bits = scalar.width * 8;
3160
3161 if scalar.kind == ScalarKind::Uint || scalar.width != 4 {
3162 write!(self.out, "min({scalar_width_bits}u, firstbitlow(")?;
3163 self.write_expr(module, arg, func_ctx)?;
3164 write!(self.out, "))")?;
3165 } else {
3166 write!(
3168 self.out,
3169 "asint(min({scalar_width_bits}u, firstbitlow("
3170 )?;
3171 self.write_expr(module, arg, func_ctx)?;
3172 write!(self.out, ")))")?;
3173 }
3174 }
3175 _ => unreachable!(),
3176 }
3177
3178 return Ok(());
3179 }
3180 Function::CountLeadingZeros => {
3181 match *func_ctx.resolve_type(arg, &module.types) {
3182 TypeInner::Vector { size, scalar } => {
3183 let s = match size {
3184 crate::VectorSize::Bi => ".xx",
3185 crate::VectorSize::Tri => ".xxx",
3186 crate::VectorSize::Quad => ".xxxx",
3187 };
3188
3189 let constant = scalar.width * 8 - 1;
3191
3192 if scalar.kind == ScalarKind::Uint {
3193 write!(self.out, "(({constant}u){s} - firstbithigh(")?;
3194 self.write_expr(module, arg, func_ctx)?;
3195 write!(self.out, "))")?;
3196 } else {
3197 let conversion_func = match scalar.width {
3198 4 => "asint",
3199 _ => "",
3200 };
3201 write!(self.out, "(")?;
3202 self.write_expr(module, arg, func_ctx)?;
3203 write!(
3204 self.out,
3205 " < (0){s} ? (0){s} : ({constant}){s} - {conversion_func}(firstbithigh("
3206 )?;
3207 self.write_expr(module, arg, func_ctx)?;
3208 write!(self.out, ")))")?;
3209 }
3210 }
3211 TypeInner::Scalar(scalar) => {
3212 let constant = scalar.width * 8 - 1;
3214
3215 if let ScalarKind::Uint = scalar.kind {
3216 write!(self.out, "({constant}u - firstbithigh(")?;
3217 self.write_expr(module, arg, func_ctx)?;
3218 write!(self.out, "))")?;
3219 } else {
3220 let conversion_func = match scalar.width {
3221 4 => "asint",
3222 _ => "",
3223 };
3224 write!(self.out, "(")?;
3225 self.write_expr(module, arg, func_ctx)?;
3226 write!(
3227 self.out,
3228 " < 0 ? 0 : {constant} - {conversion_func}(firstbithigh("
3229 )?;
3230 self.write_expr(module, arg, func_ctx)?;
3231 write!(self.out, ")))")?;
3232 }
3233 }
3234 _ => unreachable!(),
3235 }
3236
3237 return Ok(());
3238 }
3239 }
3240 }
3241 Expression::Swizzle {
3242 size,
3243 vector,
3244 pattern,
3245 } => {
3246 self.write_expr(module, vector, func_ctx)?;
3247 write!(self.out, ".")?;
3248 for &sc in pattern[..size as usize].iter() {
3249 self.out.write_char(back::COMPONENTS[sc as usize])?;
3250 }
3251 }
3252 Expression::ArrayLength(expr) => {
3253 let var_handle = match func_ctx.expressions[expr] {
3254 Expression::AccessIndex { base, index: _ } => {
3255 match func_ctx.expressions[base] {
3256 Expression::GlobalVariable(handle) => handle,
3257 _ => unreachable!(),
3258 }
3259 }
3260 Expression::GlobalVariable(handle) => handle,
3261 _ => unreachable!(),
3262 };
3263
3264 let var = &module.global_variables[var_handle];
3265 let (offset, stride) = match module.types[var.ty].inner {
3266 TypeInner::Array { stride, .. } => (0, stride),
3267 TypeInner::Struct { ref members, .. } => {
3268 let last = members.last().unwrap();
3269 let stride = match module.types[last.ty].inner {
3270 TypeInner::Array { stride, .. } => stride,
3271 _ => unreachable!(),
3272 };
3273 (last.offset, stride)
3274 }
3275 _ => unreachable!(),
3276 };
3277
3278 let storage_access = match var.space {
3279 crate::AddressSpace::Storage { access } => access,
3280 _ => crate::StorageAccess::default(),
3281 };
3282 let wrapped_array_length = WrappedArrayLength {
3283 writable: storage_access.contains(crate::StorageAccess::STORE),
3284 };
3285
3286 write!(self.out, "((")?;
3287 self.write_wrapped_array_length_function_name(wrapped_array_length)?;
3288 let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
3289 write!(self.out, "({var_name}) - {offset}) / {stride})")?
3290 }
3291 Expression::Derivative { axis, ctrl, expr } => {
3292 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
3293 if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
3294 let tail = match ctrl {
3295 Ctrl::Coarse => "coarse",
3296 Ctrl::Fine => "fine",
3297 Ctrl::None => unreachable!(),
3298 };
3299 write!(self.out, "abs(ddx_{tail}(")?;
3300 self.write_expr(module, expr, func_ctx)?;
3301 write!(self.out, ")) + abs(ddy_{tail}(")?;
3302 self.write_expr(module, expr, func_ctx)?;
3303 write!(self.out, "))")?
3304 } else {
3305 let fun_str = match (axis, ctrl) {
3306 (Axis::X, Ctrl::Coarse) => "ddx_coarse",
3307 (Axis::X, Ctrl::Fine) => "ddx_fine",
3308 (Axis::X, Ctrl::None) => "ddx",
3309 (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
3310 (Axis::Y, Ctrl::Fine) => "ddy_fine",
3311 (Axis::Y, Ctrl::None) => "ddy",
3312 (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
3313 (Axis::Width, Ctrl::None) => "fwidth",
3314 };
3315 write!(self.out, "{fun_str}(")?;
3316 self.write_expr(module, expr, func_ctx)?;
3317 write!(self.out, ")")?
3318 }
3319 }
3320 Expression::Relational { fun, argument } => {
3321 use crate::RelationalFunction as Rf;
3322
3323 let fun_str = match fun {
3324 Rf::All => "all",
3325 Rf::Any => "any",
3326 Rf::IsNan => "isnan",
3327 Rf::IsInf => "isinf",
3328 };
3329 write!(self.out, "{fun_str}(")?;
3330 self.write_expr(module, argument, func_ctx)?;
3331 write!(self.out, ")")?
3332 }
3333 Expression::Select {
3334 condition,
3335 accept,
3336 reject,
3337 } => {
3338 write!(self.out, "(")?;
3339 self.write_expr(module, condition, func_ctx)?;
3340 write!(self.out, " ? ")?;
3341 self.write_expr(module, accept, func_ctx)?;
3342 write!(self.out, " : ")?;
3343 self.write_expr(module, reject, func_ctx)?;
3344 write!(self.out, ")")?
3345 }
3346 Expression::RayQueryGetIntersection { .. } => unreachable!(),
3348 Expression::CallResult(_)
3350 | Expression::AtomicResult { .. }
3351 | Expression::WorkGroupUniformLoadResult { .. }
3352 | Expression::RayQueryProceedResult
3353 | Expression::SubgroupBallotResult
3354 | Expression::SubgroupOperationResult { .. } => {}
3355 }
3356
3357 if !closing_bracket.is_empty() {
3358 write!(self.out, "{closing_bracket}")?;
3359 }
3360 Ok(())
3361 }
3362
3363 fn write_named_expr(
3364 &mut self,
3365 module: &Module,
3366 handle: Handle<crate::Expression>,
3367 name: String,
3368 named: Handle<crate::Expression>,
3371 ctx: &back::FunctionCtx,
3372 ) -> BackendResult {
3373 match ctx.info[named].ty {
3374 proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
3375 TypeInner::Struct { .. } => {
3376 let ty_name = &self.names[&NameKey::Type(ty_handle)];
3377 write!(self.out, "{ty_name}")?;
3378 }
3379 _ => {
3380 self.write_type(module, ty_handle)?;
3381 }
3382 },
3383 proc::TypeResolution::Value(ref inner) => {
3384 self.write_value_type(module, inner)?;
3385 }
3386 }
3387
3388 let resolved = ctx.resolve_type(named, &module.types);
3389
3390 write!(self.out, " {name}")?;
3391 if let TypeInner::Array { base, size, .. } = *resolved {
3393 self.write_array_size(module, base, size)?;
3394 }
3395 write!(self.out, " = ")?;
3396 self.write_expr(module, handle, ctx)?;
3397 writeln!(self.out, ";")?;
3398 self.named_expressions.insert(named, name);
3399
3400 Ok(())
3401 }
3402
3403 pub(super) fn write_default_init(
3405 &mut self,
3406 module: &Module,
3407 ty: Handle<crate::Type>,
3408 ) -> BackendResult {
3409 write!(self.out, "(")?;
3410 self.write_type(module, ty)?;
3411 if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
3412 self.write_array_size(module, base, size)?;
3413 }
3414 write!(self.out, ")0")?;
3415 Ok(())
3416 }
3417
3418 fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult {
3419 if barrier.contains(crate::Barrier::STORAGE) {
3420 writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
3421 }
3422 if barrier.contains(crate::Barrier::WORK_GROUP) {
3423 writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
3424 }
3425 if barrier.contains(crate::Barrier::SUB_GROUP) {
3426 }
3428 Ok(())
3429 }
3430}
3431
3432pub(super) struct MatrixType {
3433 pub(super) columns: crate::VectorSize,
3434 pub(super) rows: crate::VectorSize,
3435 pub(super) width: crate::Bytes,
3436}
3437
3438pub(super) fn get_inner_matrix_data(
3439 module: &Module,
3440 handle: Handle<crate::Type>,
3441) -> Option<MatrixType> {
3442 match module.types[handle].inner {
3443 TypeInner::Matrix {
3444 columns,
3445 rows,
3446 scalar,
3447 } => Some(MatrixType {
3448 columns,
3449 rows,
3450 width: scalar.width,
3451 }),
3452 TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
3453 _ => None,
3454 }
3455}
3456
3457pub(super) fn get_inner_matrix_of_struct_array_member(
3462 module: &Module,
3463 base: Handle<crate::Expression>,
3464 func_ctx: &back::FunctionCtx<'_>,
3465 direct: bool,
3466) -> Option<MatrixType> {
3467 let mut mat_data = None;
3468 let mut array_base = None;
3469
3470 let mut current_base = base;
3471 loop {
3472 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
3473 if let TypeInner::Pointer { base, .. } = *resolved {
3474 resolved = &module.types[base].inner;
3475 };
3476
3477 match *resolved {
3478 TypeInner::Matrix {
3479 columns,
3480 rows,
3481 scalar,
3482 } => {
3483 mat_data = Some(MatrixType {
3484 columns,
3485 rows,
3486 width: scalar.width,
3487 })
3488 }
3489 TypeInner::Array { base, .. } => {
3490 array_base = Some(base);
3491 }
3492 TypeInner::Struct { .. } => {
3493 if let Some(array_base) = array_base {
3494 if direct {
3495 return mat_data;
3496 } else {
3497 return get_inner_matrix_data(module, array_base);
3498 }
3499 }
3500
3501 break;
3502 }
3503 _ => break,
3504 }
3505
3506 current_base = match func_ctx.expressions[current_base] {
3507 crate::Expression::Access { base, .. } => base,
3508 crate::Expression::AccessIndex { base, .. } => base,
3509 _ => break,
3510 };
3511 }
3512 None
3513}
3514
3515fn get_inner_matrix_of_global_uniform(
3520 module: &Module,
3521 base: Handle<crate::Expression>,
3522 func_ctx: &back::FunctionCtx<'_>,
3523) -> Option<MatrixType> {
3524 let mut mat_data = None;
3525 let mut array_base = None;
3526
3527 let mut current_base = base;
3528 loop {
3529 let mut resolved = func_ctx.resolve_type(current_base, &module.types);
3530 if let TypeInner::Pointer { base, .. } = *resolved {
3531 resolved = &module.types[base].inner;
3532 };
3533
3534 match *resolved {
3535 TypeInner::Matrix {
3536 columns,
3537 rows,
3538 scalar,
3539 } => {
3540 mat_data = Some(MatrixType {
3541 columns,
3542 rows,
3543 width: scalar.width,
3544 })
3545 }
3546 TypeInner::Array { base, .. } => {
3547 array_base = Some(base);
3548 }
3549 _ => break,
3550 }
3551
3552 current_base = match func_ctx.expressions[current_base] {
3553 crate::Expression::Access { base, .. } => base,
3554 crate::Expression::AccessIndex { base, .. } => base,
3555 crate::Expression::GlobalVariable(handle)
3556 if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
3557 {
3558 return mat_data.or_else(|| {
3559 array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
3560 })
3561 }
3562 _ => break,
3563 };
3564 }
3565 None
3566}