1use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
10use crate::span::{AddSpan as _, WithSpan};
11use crate::{
12 arena::{Arena, Handle},
13 proc::{ResolveContext, TypeResolution},
14};
15use std::ops;
16
/// Handle of the expression recorded as the source of non-uniformity, if any.
///
/// Throughout this file `None` is used when no non-uniform source was
/// recorded, and `Some(handle)` names the expression that introduced it.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
18
/// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// [`UniformityRequirements`] are defined with the value `0`, so they carry
/// no bits and never make a requirement set non-empty.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    // Set of uniformity requirements attached to an expression or statement.
    // `process_block` rejects an emitted expression whose requirement set is
    // non-empty while a uniformity disruptor is active (and the
    // `CONTROL_FLOW_UNIFORMITY` validation flag is set).
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        // Produced by `Barrier` and `WorkGroupUniformLoad` statements.
        const WORK_GROUP_BARRIER = 0x1;
        // Produced by `Derivative` expressions; has no bits when
        // `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        // Produced by `ImageSample` expressions whose level reports implicit
        // derivatives; has no bits when
        // `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
34
/// Uniformity summary of an expression, a statement, or a whole function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// The expression that makes the value non-uniform, or `None` if no such
    /// expression was recorded.
    pub non_uniform_result: NonUniformResult,
    /// Requirements that evaluating this item places on the surrounding
    /// control flow.
    pub requirements: UniformityRequirements,
}
56
57impl Uniformity {
58 const fn new() -> Self {
59 Uniformity {
60 non_uniform_result: None,
61 requirements: UniformityRequirements::empty(),
62 }
63 }
64}
65
bitflags::bitflags! {
    // Ways in which a piece of control flow may leave the function early.
    // `FunctionUniformity::exit_disruptor` turns these into a
    // `UniformityDisruptor` for the statements that follow.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        // Set by a `Return` statement processed while a disruptor is active.
        const MAY_RETURN = 0x1;
        // Set by a `Kill` statement processed while a disruptor is active,
        // or by a call to a function whose `may_kill` is set.
        const MAY_KILL = 0x2;
    }
}
78
/// Uniformity information accumulated while walking statements: the
/// uniformity of the produced result plus the early-exit flags.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Uniformity of the value/requirements produced so far.
    result: Uniformity,
    /// Early exits that may have happened so far.
    exit: ExitFlags,
}
85
86impl ops::BitOr for FunctionUniformity {
87 type Output = Self;
88 fn bitor(self, other: Self) -> Self {
89 FunctionUniformity {
90 result: Uniformity {
91 non_uniform_result: self
92 .result
93 .non_uniform_result
94 .or(other.result.non_uniform_result),
95 requirements: self.result.requirements | other.result.requirements,
96 },
97 exit: self.exit | other.exit,
98 }
99 }
100}
101
102impl FunctionUniformity {
103 const fn new() -> Self {
104 FunctionUniformity {
105 result: Uniformity::new(),
106 exit: ExitFlags::empty(),
107 }
108 }
109
110 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
112 if self.exit.contains(ExitFlags::MAY_RETURN) {
113 Some(UniformityDisruptor::Return)
114 } else if self.exit.contains(ExitFlags::MAY_KILL) {
115 Some(UniformityDisruptor::Discard)
116 } else {
117 None
118 }
119 }
120}
121
bitflags::bitflags! {
    // How a function uses a particular global variable. One entry per global
    // is stored in `FunctionInfo::global_uses`.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        // Recorded by `add_ref` for ordinary references.
        const READ = 0x1;
        // Recorded for the pointer/image operand of `Store`, `ImageStore`
        // and `Atomic` statements.
        const WRITE = 0x2;
        // Recorded for the operand of `ImageQuery` and `ArrayLength`
        // expressions.
        const QUERY = 0x4;
    }
}
136
/// An (image, sampler) pair of global variables that are sampled together.
///
/// Entries are inserted into `FunctionInfo::sampling_set` once both operands
/// of a sampling operation are known to be global variables.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// Global variable holding the image.
    pub image: Handle<crate::GlobalVariable>,
    /// Global variable holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
144
/// Per-expression analysis results stored in `FunctionInfo::expressions`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity of the expression's value and its control-flow requirements.
    pub uniformity: Uniformity,

    /// Number of times the expression has been referenced; incremented by
    /// `add_ref_impl` and `add_assignable_ref`.
    pub ref_count: usize,

    /// The global variable this expression refers to, if any. Set directly
    /// for `GlobalVariable` expressions and inherited from the base of
    /// `Access` / `AccessIndex` expressions.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// Resolved type of the expression.
    pub ty: TypeResolution,
}
179
180impl ExpressionInfo {
181 const fn new() -> Self {
182 ExpressionInfo {
183 uniformity: Uniformity::new(),
184 ref_count: 0,
185 assignable_global: None,
186 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
188 kind: crate::ScalarKind::Bool,
189 width: 0,
190 })),
191 }
192 }
193}
194
/// Origin of an image or sampler operand: either a global variable or one of
/// the enclosing function's arguments (identified by its index).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// The operand is (an element of) this global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The operand is the function argument with this index.
    Argument(u32),
}
202
203impl GlobalOrArgument {
204 fn from_expression(
205 expression_arena: &Arena<crate::Expression>,
206 expression: Handle<crate::Expression>,
207 ) -> Result<GlobalOrArgument, ExpressionError> {
208 Ok(match expression_arena[expression] {
209 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
210 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
211 crate::Expression::Access { base, .. }
212 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
213 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
214 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
215 },
216 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
217 })
218 }
219}
220
/// An (image, sampler) pair in which at least one side may still be a
/// function argument rather than a global variable.
///
/// Stored in `FunctionInfo::sampling`; `process_call` tries to resolve the
/// argument sides against the caller's actual arguments.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Origin of the image operand.
    image: GlobalOrArgument,
    /// Origin of the sampler operand.
    sampler: GlobalOrArgument,
}
228
/// Analysis results for a single function.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the analysis was run with; consulted when deciding
    /// whether to enforce control-flow uniformity.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages the function may be used in. Initialised to all stages
    /// by `process_function`; NOTE(review): narrowed elsewhere, not visible here.
    pub available_stages: ShaderStages,
    /// Uniformity of the function's result and its accumulated requirements.
    pub uniformity: Uniformity,
    /// Whether processing the function body reported a possible kill.
    pub may_kill: bool,

    /// All (image, sampler) global pairs sampled by this function, including
    /// those inherited from callees.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// Usage flags per global variable, indexed by the global's handle index.
    global_uses: Box<[GlobalUse]>,

    /// Analysis info per expression, indexed by the expression's handle index.
    expressions: Box<[ExpressionInfo]>,

    /// Sampling pairs that still involve at least one function argument.
    sampling: crate::FastHashSet<Sampling>,

    /// Initialised to `false` by `process_function`; NOTE(review): presumably
    /// set elsewhere when dual-source blending is detected — confirm.
    pub dual_source_blending: bool,
}
292
293impl FunctionInfo {
294 pub const fn global_variable_count(&self) -> usize {
295 self.global_uses.len()
296 }
297 pub const fn expression_count(&self) -> usize {
298 self.expressions.len()
299 }
300 pub fn dominates_global_use(&self, other: &Self) -> bool {
301 for (self_global_uses, other_global_uses) in
302 self.global_uses.iter().zip(other.global_uses.iter())
303 {
304 if !self_global_uses.contains(*other_global_uses) {
305 return false;
306 }
307 }
308 true
309 }
310}
311
312impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
313 type Output = GlobalUse;
314 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
315 &self.global_uses[handle.index()]
316 }
317}
318
319impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
320 type Output = ExpressionInfo;
321 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
322 &self.expressions[handle.index()]
323 }
324}
325
/// The reason control flow is considered non-uniform at some point of a
/// function body. Reported as part of `FunctionError::NonUniformControlFlow`.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// A branch or switch depends on the non-uniform value of this expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// A `Return` may already have been taken earlier in the function.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// A `Kill` may already have been executed, here or in a called function.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
337
338impl FunctionInfo {
339 #[must_use]
347 fn add_ref_impl(
348 &mut self,
349 expr: Handle<crate::Expression>,
350 global_use: GlobalUse,
351 ) -> NonUniformResult {
352 let info = &mut self.expressions[expr.index()];
353 info.ref_count += 1;
354 if let Some(global) = info.assignable_global {
356 self.global_uses[global.index()] |= global_use;
357 }
358 info.uniformity.non_uniform_result
359 }
360
361 #[must_use]
368 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
369 self.add_ref_impl(expr, GlobalUse::READ)
370 }
371
372 #[must_use]
387 fn add_assignable_ref(
388 &mut self,
389 expr: Handle<crate::Expression>,
390 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
391 ) -> NonUniformResult {
392 let info = &mut self.expressions[expr.index()];
393 info.ref_count += 1;
394 if let Some(global) = info.assignable_global {
397 if let Some(_old) = assignable_global.replace(global) {
398 unreachable!()
399 }
400 }
401 info.uniformity.non_uniform_result
402 }
403
404 fn process_call(
406 &mut self,
407 callee: &Self,
408 arguments: &[Handle<crate::Expression>],
409 expression_arena: &Arena<crate::Expression>,
410 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
411 self.sampling_set
412 .extend(callee.sampling_set.iter().cloned());
413 for sampling in callee.sampling.iter() {
414 let image_storage = match sampling.image {
417 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
418 GlobalOrArgument::Argument(i) => {
419 let handle = arguments[i as usize];
420 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
421 |source| {
422 FunctionError::Expression { handle, source }
423 .with_span_handle(handle, expression_arena)
424 },
425 )?
426 }
427 };
428
429 let sampler_storage = match sampling.sampler {
430 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
431 GlobalOrArgument::Argument(i) => {
432 let handle = arguments[i as usize];
433 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
434 |source| {
435 FunctionError::Expression { handle, source }
436 .with_span_handle(handle, expression_arena)
437 },
438 )?
439 }
440 };
441
442 match (image_storage, sampler_storage) {
447 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
448 self.sampling_set.insert(SamplingKey { image, sampler });
449 }
450 (image, sampler) => {
451 self.sampling.insert(Sampling { image, sampler });
452 }
453 }
454 }
455
456 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
458 *mine |= *other;
459 }
460
461 Ok(FunctionUniformity {
462 result: callee.uniformity.clone(),
463 exit: if callee.may_kill {
464 ExitFlags::MAY_KILL
465 } else {
466 ExitFlags::empty()
467 },
468 })
469 }
470
    /// Computes and stores the [`ExpressionInfo`] for the expression `handle`.
    ///
    /// For each operand the matching `add_ref*` helper is called, which bumps
    /// the operand's reference count, records global usage and yields the
    /// operand's non-uniform source. The resulting entry is written with a
    /// reference count of zero.
    ///
    /// `Option::or` evaluates its argument eagerly, so every `add_ref` inside
    /// an `.or(...)` chain runs for its side effects even when an earlier
    /// operand is already non-uniform; this is why `clippy::or_fun_call` is
    /// allowed here.
    ///
    /// # Errors
    ///
    /// * `ExpressionError::MissingCapabilities` when a binding array is
    ///   indexed by a non-uniform value and `capabilities` lacks the needed
    ///   capability.
    /// * Any error from `GlobalOrArgument::from_expression` for the image or
    ///   sampler of an `ImageSample`.
    /// * Any error from `resolve_context.resolve` while resolving the type.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!()` when a binding array of non-image, non-sampler
    /// elements is not based directly on a global variable in the `Uniform`
    /// or `Storage` address space.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        // Filled in by `GlobalVariable` directly, or inherited from the base
        // of `Access` / `AccessIndex` via `add_assignable_ref`.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Capabilities required if the index turns out to be non-uniform.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        // Short aliases for the long capability names.
                        let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;

                        // The element type of the binding array decides which
                        // capability applies.
                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => ub_st,
                                _ => st_sb,
                            },
                            crate::TypeInner::Sampler { .. } => sampler,
                            // Anything else is treated as a buffer: pick the
                            // capability from the global's address space.
                            _ => {
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => ub_st,
                                        crate::AddressSpace::Storage { .. } => st_sb,
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                // Non-uniform indexing of a binding array needs the capability.
                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            // Constant-like expressions have no operands and are uniform.
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                // References every component; the first non-uniform one wins.
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                // An argument is uniform only for the listed built-ins and
                // for flat-interpolated locations.
                let arg = &resolve_context.arguments[index as usize];
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(built_in)) => match built_in {
                        crate::BuiltIn::FrontFacing
                        | crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups => true,
                        _ => false,
                    },
                    Some(crate::Binding::Location {
                        interpolation: Some(crate::Interpolation::Flat),
                        ..
                    }) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                // Uniformity is decided purely by the address space; storage
                // is uniform only when it cannot be stored to.
                let uniform = match var.space {
                    As::Function | As::Private => false,
                    As::WorkGroup => true,
                    As::Uniform | As::PushConstant => true,
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // Local variables are always recorded as non-uniform.
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset: _,
                level,
                depth_ref,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Fully-global pairs are final; pairs involving an argument
                // are kept for resolution at call sites (see `process_call`).
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                // "nur" stands for "non-uniform result".
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur),
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // The image is only queried, not read.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                // Derivatives carry the `DERIVATIVE` requirement.
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call result inherits the uniformity recorded for the callee.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            // Recorded as uniform.
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                // The array is only queried, not read.
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
        };

        // Operand types are taken from the entries already stored in `self`.
        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
809
    /// Analyses the statements of a block in order and returns the combined
    /// uniformity of the block.
    ///
    /// `disruptor` is the reason control flow is already considered
    /// non-uniform on entry, if any. After each statement it is extended with
    /// that statement's exit disruptor, so an earlier possible `Return` or
    /// `Kill` affects the statements that follow.
    ///
    /// Expression operands of statements are registered through the
    /// `add_ref*` helpers, updating reference counts and global usage.
    ///
    /// # Errors
    ///
    /// * `FunctionError::NonUniformControlFlow` (with span) when the
    ///   `CONTROL_FLOW_UNIFORMITY` flag is set, an emitted expression has a
    ///   non-empty requirement set, and a disruptor is active.
    /// * Any error from nested blocks or from `process_call`.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    // Collects the requirements of all emitted expressions and
                    // checks each one against the active disruptor.
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                return Err(FunctionError::NonUniformControlFlow(req, expr, cause)
                                    .with_span_handle(expr, expression_arena));
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill only counts as a possible early exit when control
                // flow is already disrupted.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer's non-uniformity is deliberately not used.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => {
                    self.process_block(b, other_functions, disruptor, expression_arena)?
                }
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // A non-uniform condition disrupts both branches, unless a
                    // disruptor is already active (the existing one is kept).
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                        )?;
                        // A fall-through case passes its exit disruptor on to
                        // the next case; otherwise the next case starts again
                        // from the switch-level disruptor.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity =
                        self.process_block(body, other_functions, disruptor, expression_arena)?;
                    // Early exits in the body disrupt the continuing block.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                    )?;
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // A return only counts as a possible early exit when control
                // flow is already disrupted; the returned value determines the
                // non-uniform result.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                // For the statements below, operands are only registered
                // (reference counts and global usage); their non-uniformity is
                // discarded.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    if let crate::RayQueryFunction::Initialize {
                        acceleration_structure,
                        descriptor,
                    } = *fun
                    {
                        let _ = self.add_ref(acceleration_structure);
                        let _ = self.add_ref(descriptor);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index) => {
                            let _ = self.add_ref(index);
                        }
                    }
                    FunctionUniformity::new()
                }
            };

            // Early exits of this statement disrupt everything after it.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1083}
1084
1085impl ModuleInfo {
1086 pub(super) fn process_const_expression(
1088 &mut self,
1089 handle: Handle<crate::Expression>,
1090 resolve_context: &ResolveContext,
1091 gctx: crate::proc::GlobalCtx,
1092 ) -> Result<(), super::ConstExpressionError> {
1093 self.const_expression_types[handle.index()] =
1094 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1095 Ok(())
1096 }
1097
1098 pub(super) fn process_function(
1101 &self,
1102 fun: &crate::Function,
1103 module: &crate::Module,
1104 flags: ValidationFlags,
1105 capabilities: super::Capabilities,
1106 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1107 let mut info = FunctionInfo {
1108 flags,
1109 available_stages: ShaderStages::all(),
1110 uniformity: Uniformity::new(),
1111 may_kill: false,
1112 sampling_set: crate::FastHashSet::default(),
1113 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1114 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1115 sampling: crate::FastHashSet::default(),
1116 dual_source_blending: false,
1117 };
1118 let resolve_context =
1119 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1120
1121 for (handle, _) in fun.expressions.iter() {
1122 if let Err(source) = info.process_expression(
1123 handle,
1124 &fun.expressions,
1125 &self.functions,
1126 &resolve_context,
1127 capabilities,
1128 ) {
1129 return Err(FunctionError::Expression { handle, source }
1130 .with_span_handle(handle, &fun.expressions));
1131 }
1132 }
1133
1134 for (_, expr) in fun.local_variables.iter() {
1135 if let Some(init) = expr.init {
1136 let _ = info.add_ref(init);
1137 }
1138 }
1139
1140 let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
1141 info.uniformity = uniformity.result;
1142 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1143
1144 Ok(info)
1145 }
1146
1147 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1148 &self.entry_points[index]
1149 }
1150}
1151
// End-to-end check of expression and block processing: builds a tiny set of
// globals and expressions by hand, then verifies reference counts, global
// usage flags and the resulting `FunctionUniformity` after several blocks.
// State in `info` accumulates across the blocks, so the order matters.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let mut global_var_arena = Arena::new();
    // `Handle` address space is treated as non-uniform by `process_expression`.
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    // `Uniform` address space is treated as uniform.
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    // checks the uniform control flow
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    // checks the non-uniform control flow
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    // NOTE(review): `range_from(i)` presumably spans from index `i` to the
    // current end of the arena — confirm against `Arena::range_from`.
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    let emit_range_globals = expressions.range_from(2);

    // checks the QUERY flag
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // checks the transitive WRITE flag
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    // Each global expression is referenced once so far: the non-uniform one
    // by `access_expr`, the uniform one by `query_expr`.
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    // `AccessIndex` uses `add_assignable_ref`, which records no global usage;
    // `ArrayLength` records `QUERY`.
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    // Block 1: a branch on a uniform condition containing a derivative.
    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    // One reference from the derivative expression, one from the store.
    assert_eq!(info[constant_expr].ref_count, 2);
    // The `If` condition added a READ of the uniform global.
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    // Block 2: the same body, but behind a non-uniform condition.
    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2, stmt_if_non_uniform].into(),
            &[],
            None,
            &expressions,
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // The derivative requirement has no bits, so no error is raised
            // and the store in the branch references the derivative again.
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            assert_eq!(info[derivative_expr].ref_count, 1);
        }
    }
    // The `If` condition added a READ of the non-uniform global.
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    // Block 3: returning a non-uniform value while a disruptor is active.
    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    // References: `access_expr`, the `If` condition, and this return.
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Block 4: store through an access, kill, then return the pointer.
    // Check that uniformity requirements reach through a pointer
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    // The store through `access_expr` marks its underlying global as written.
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}