1use std::iter;
2
3use arrayvec::ArrayVec;
4
5use crate::{
6 arena::{Arena, Handle, UniqueArena},
7 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
8 TypeInner, UnaryOperator,
9};
10
/// Invokes the `macro_rules!` rules given in the body with a literal `$` token
/// as the sole argument.
///
/// A macro whose expansion defines *another* macro cannot directly write the
/// inner macro's own `$metavariables` on stable Rust. Wrapping that definition
/// in this macro makes the `$` available as a `tt` (bound to e.g. `$d`), so the
/// inner macro can be written with `$d name` wherever it needs `$name`.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define a throwaway macro from the caller's rules, then call it with `$`.
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
22
/// Generates the machinery for evaluating a numeric built-in component-wise
/// over a chosen family of literal types.
///
/// `$ident -> $target` names the generated helper function and enum. `literals`
/// maps each accepted [`Literal`] variant to a `$target` variant and its Rust
/// element type. `scalar_kinds` lists the scalar kinds of vector arguments the
/// helper accepts. The expansion consists of:
///
/// - `enum $target<const N: usize>`, one variant per literal, holding `[ty; N]`;
/// - `impl From<$target<1>> for Expression`, mapping a single value back to an
///   [`Expression::Literal`];
/// - `fn $ident(eval, span, exprs, handler)`, which requires all `N` arguments
///   to be either scalar literals of one listed variant, or `Compose` vectors of
///   one identical vector type. It hands each group of matching components to
///   `handler`, recursing once per vector component, and appends the result;
/// - a macro, also named `$ident`, that lets a single handler body be written
///   once for every `$target` variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`] variants, each holding `N` values of the
        /// corresponding Rust type; used to implement numeric built-ins.
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single-valued `$target` converts straight back into a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            assert!(N > 0);
            // Any argument of an unexpected shape or type reports `InvalidMathArg`.
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Normalize an argument with `eval_zero_value_and_splat` (defined
            // elsewhere in this file), then look at the expression it produced.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // The first argument decides between the scalar and vector paths.
            let new_expr = match sanitize!(exprs.next().unwrap())? {
                // Scalar path: every argument must be a literal of the same
                // variant. Exactly `N` values are collected, so the `ArrayVec`
                // is full and `into_inner` cannot fail.
                $(
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector path: every argument must be a `Compose` whose type is
                // identical to the first argument's vector type.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // NOTE(review): the remaining `N - 1` groups are collected
                            // into an `ArrayVec` of capacity `VectorSize::MAX`, which
                            // assumes `N - 1` never exceeds it.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>()?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // Recurse once per vector lane, passing the `idx`-th
                            // component of every argument as a scalar group.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs[idx])
                                    .collect::<ArrayVec<_, N>>()
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // Rebuild a vector of the first argument's type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // Companion macro named `$ident`; `$d` is a literal `$` so the inner
        // macro can declare its own metavariables.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    // The handler body yields `Result<[T; M], _>`;
                                    // re-wrap the array in the same variant.
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
194
// `Scalar<N>` / `component_wise_scalar`: every numeric literal type (abstract
// and concrete floats, signed and unsigned integers), and vectors whose scalar
// kind is any of the numeric kinds listed below.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
214
// `Float<N>` / `component_wise_float`: floating-point only (abstract `f64` and
// concrete `f32`), plus float and abstract-float vectors.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
226
// `ConcreteInt<N>` / `component_wise_concrete_int`: 32-bit concrete integers
// only (`u32` and `i32`), plus signed/unsigned integer vectors.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
238
// `Signed<N>` / `component_wise_signed`: signed numeric types only (abstract
// and concrete floats, abstract int and `i32`), plus vectors of those kinds.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
254
/// Which front end's evaluation rules apply, and in what context.
///
/// The `Runtime` restriction of either front end carries the function-local
/// state used when evaluating inside a function body. The other restrictions
/// are used when evaluating into the module's global expression arena.
#[derive(Debug)]
enum Behavior<'a> {
    /// Follow WGSL rules.
    Wgsl(WgslRestrictions<'a>),
    /// Follow GLSL rules.
    Glsl(GlslRestrictions<'a>),
}
260
261impl Behavior<'_> {
262 const fn has_runtime_restrictions(&self) -> bool {
264 matches!(
265 self,
266 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
267 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
268 )
269 }
270}
271
/// Evaluates Naga [`Expression`]s at compile time where possible.
///
/// The evaluator contributes expressions to one arena (`expressions`). That
/// arena is either the module's global expression arena (see
/// [`Self::for_wgsl_module`] / [`Self::for_glsl_module`]) or a function's local
/// arena (see [`Self::for_wgsl_function`] / [`Self::for_glsl_function`]). The
/// main entry point is [`Self::try_eval_and_append`].
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which front end's rules to follow, and whether we are working in a
    /// global or a function-local arena.
    behavior: Behavior<'a>,

    /// The module's type arena. It is mutable because evaluation can create
    /// new types, such as the vector types built by `splat` and `swizzle`.
    types: &'a mut UniqueArena<Type>,

    /// The module's constants. `Constant` expressions are resolved through
    /// their `init` expression handles.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena this evaluator contributes expressions to.
    expressions: &'a mut Arena<Expression>,

    /// Records, for each expression in `expressions`, whether it is a const,
    /// override, or runtime expression.
    expression_kind_tracker: &'a mut ExpressionKindTracker,
}
314
/// The WGSL evaluation context.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// A const context in the module's global arena: override-expressions are
    /// rejected with `ConstantEvaluatorError::OverrideExpr`.
    Const,
    /// A context in the module's global arena where override-expressions are
    /// accepted and appended unevaluated.
    Override,
    /// Evaluation inside a function body, with the function-local state.
    Runtime(FunctionLocalData<'a>),
}
327
/// The GLSL evaluation context.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Evaluation in the module's global expression arena.
    Const,
    /// Evaluation inside a function body, with the function-local state.
    Runtime(FunctionLocalData<'a>),
}
337
/// State needed when evaluating expressions inside a function body.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers live here
    /// and are copied into the function's arena when referenced.
    global_expressions: &'a Arena<Expression>,
    // NOTE(review): `emitter` and `block` are not used in the code visible
    // here; presumably they are needed when appending runtime expressions to
    // the function body. Confirm against `append_expr`.
    emitter: &'a mut super::Emitter,
    block: &'a mut crate::Block,
}
345
/// How "constant" an expression is.
///
/// The derived `Ord` orders variants from most to least constant
/// (`Const < Override < Runtime`). The kind of a compound expression is
/// therefore the `max` of its operands' kinds.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Fully evaluable at shader-creation time.
    Const,
    /// Depends on a pipeline-overridable constant.
    Override,
    /// Only known at run time.
    Runtime,
}
352
/// Records the [`ExpressionKind`] of every expression in an arena.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// One entry per expression, indexed by the expression handle's index.
    inner: Vec<ExpressionKind>,
}
357
358impl ExpressionKindTracker {
359 pub const fn new() -> Self {
360 Self { inner: Vec::new() }
361 }
362
363 pub fn force_non_const(&mut self, value: Handle<Expression>) {
365 self.inner[value.index()] = ExpressionKind::Runtime;
366 }
367
368 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
369 assert_eq!(self.inner.len(), value.index());
370 self.inner.push(expr_type);
371 }
372 pub fn is_const(&self, h: Handle<Expression>) -> bool {
373 matches!(self.type_of(h), ExpressionKind::Const)
374 }
375
376 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
377 matches!(
378 self.type_of(h),
379 ExpressionKind::Const | ExpressionKind::Override
380 )
381 }
382
383 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
384 self.inner[value.index()]
385 }
386
387 pub fn from_arena(arena: &Arena<Expression>) -> Self {
388 let mut tracker = Self {
389 inner: Vec::with_capacity(arena.len()),
390 };
391 for (_, expr) in arena.iter() {
392 tracker.inner.push(tracker.type_of_with_expr(expr));
393 }
394 tracker
395 }
396
397 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
398 match *expr {
399 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
400 ExpressionKind::Const
401 }
402 Expression::Override(_) => ExpressionKind::Override,
403 Expression::Compose { ref components, .. } => {
404 let mut expr_type = ExpressionKind::Const;
405 for component in components {
406 expr_type = expr_type.max(self.type_of(*component))
407 }
408 expr_type
409 }
410 Expression::Splat { value, .. } => self.type_of(value),
411 Expression::AccessIndex { base, .. } => self.type_of(base),
412 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
413 Expression::Swizzle { vector, .. } => self.type_of(vector),
414 Expression::Unary { expr, .. } => self.type_of(expr),
415 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
416 Expression::Math {
417 arg,
418 arg1,
419 arg2,
420 arg3,
421 ..
422 } => self
423 .type_of(arg)
424 .max(
425 arg1.map(|arg| self.type_of(arg))
426 .unwrap_or(ExpressionKind::Const),
427 )
428 .max(
429 arg2.map(|arg| self.type_of(arg))
430 .unwrap_or(ExpressionKind::Const),
431 )
432 .max(
433 arg3.map(|arg| self.type_of(arg))
434 .unwrap_or(ExpressionKind::Const),
435 ),
436 Expression::As { expr, .. } => self.type_of(expr),
437 Expression::Select {
438 condition,
439 accept,
440 reject,
441 } => self
442 .type_of(condition)
443 .max(self.type_of(accept))
444 .max(self.type_of(reject)),
445 Expression::Relational { argument, .. } => self.type_of(argument),
446 Expression::ArrayLength(expr) => self.type_of(expr),
447 _ => ExpressionKind::Runtime,
448 }
449 }
450}
451
/// Errors reported while evaluating an expression at compile time.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    // Expressions that can never be constant.
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    // Invalid operands for an otherwise supported operation.
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    /// `arrayLength` is rejected outright under WGSL rules.
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// The built-in, the number of arguments it expects, and the number supplied.
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    /// An operand was recorded as an override or runtime expression.
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// Evaluation of this construct is not implemented. Inside function bodies,
    /// the expression is then appended unevaluated as a runtime expression.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("abstract floating-point values cannot be automatically converted to integers")]
    AutomaticConversionFloatToInt { to_type: &'static str },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    /// An `Override` expression reached the evaluator itself.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    /// A runtime expression was appended outside a function body.
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    /// An override-expression was appended in a WGSL const context.
    #[error("Unexpected override-expression")]
    OverrideExpr,
}
539
540impl<'a> ConstantEvaluator<'a> {
541 pub fn for_wgsl_module(
546 module: &'a mut crate::Module,
547 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
548 in_override_ctx: bool,
549 ) -> Self {
550 Self::for_module(
551 Behavior::Wgsl(if in_override_ctx {
552 WgslRestrictions::Override
553 } else {
554 WgslRestrictions::Const
555 }),
556 module,
557 global_expression_kind_tracker,
558 )
559 }
560
561 pub fn for_glsl_module(
566 module: &'a mut crate::Module,
567 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
568 ) -> Self {
569 Self::for_module(
570 Behavior::Glsl(GlslRestrictions::Const),
571 module,
572 global_expression_kind_tracker,
573 )
574 }
575
576 fn for_module(
577 behavior: Behavior<'a>,
578 module: &'a mut crate::Module,
579 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
580 ) -> Self {
581 Self {
582 behavior,
583 types: &mut module.types,
584 constants: &module.constants,
585 overrides: &module.overrides,
586 expressions: &mut module.global_expressions,
587 expression_kind_tracker: global_expression_kind_tracker,
588 }
589 }
590
591 pub fn for_wgsl_function(
596 module: &'a mut crate::Module,
597 expressions: &'a mut Arena<Expression>,
598 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
599 emitter: &'a mut super::Emitter,
600 block: &'a mut crate::Block,
601 ) -> Self {
602 Self {
603 behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData {
604 global_expressions: &module.global_expressions,
605 emitter,
606 block,
607 })),
608 types: &mut module.types,
609 constants: &module.constants,
610 overrides: &module.overrides,
611 expressions,
612 expression_kind_tracker: local_expression_kind_tracker,
613 }
614 }
615
616 pub fn for_glsl_function(
621 module: &'a mut crate::Module,
622 expressions: &'a mut Arena<Expression>,
623 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
624 emitter: &'a mut super::Emitter,
625 block: &'a mut crate::Block,
626 ) -> Self {
627 Self {
628 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
629 global_expressions: &module.global_expressions,
630 emitter,
631 block,
632 })),
633 types: &mut module.types,
634 constants: &module.constants,
635 overrides: &module.overrides,
636 expressions,
637 expression_kind_tracker: local_expression_kind_tracker,
638 }
639 }
640
641 pub fn to_ctx(&self) -> crate::proc::GlobalCtx {
642 crate::proc::GlobalCtx {
643 types: self.types,
644 constants: self.constants,
645 overrides: self.overrides,
646 global_expressions: match self.function_local_data() {
647 Some(data) => data.global_expressions,
648 None => self.expressions,
649 },
650 }
651 }
652
653 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
654 if !self.expression_kind_tracker.is_const(expr) {
655 log::debug!("check: SubexpressionsAreNotConstant");
656 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
657 }
658 Ok(())
659 }
660
661 fn check_and_get(
662 &mut self,
663 expr: Handle<Expression>,
664 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
665 match self.expressions[expr] {
666 Expression::Constant(c) => {
667 if let Some(function_local_data) = self.function_local_data() {
670 self.copy_from(
672 self.constants[c].init,
673 function_local_data.global_expressions,
674 )
675 } else {
676 Ok(self.constants[c].init)
678 }
679 }
680 _ => {
681 self.check(expr)?;
682 Ok(expr)
683 }
684 }
685 }
686
687 pub fn try_eval_and_append(
711 &mut self,
712 expr: Expression,
713 span: Span,
714 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
715 match self.expression_kind_tracker.type_of_with_expr(&expr) {
716 ExpressionKind::Const => {
717 let eval_result = self.try_eval_and_append_impl(&expr, span);
718 if self.behavior.has_runtime_restrictions()
723 && matches!(
724 eval_result,
725 Err(ConstantEvaluatorError::NotImplemented(_)
726 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
727 )
728 {
729 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
730 } else {
731 eval_result
732 }
733 }
734 ExpressionKind::Override => match self.behavior {
735 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
736 Ok(self.append_expr(expr, span, ExpressionKind::Override))
737 }
738 Behavior::Wgsl(WgslRestrictions::Const) => {
739 Err(ConstantEvaluatorError::OverrideExpr)
740 }
741 Behavior::Glsl(_) => {
742 unreachable!()
743 }
744 },
745 ExpressionKind::Runtime => {
746 if self.behavior.has_runtime_restrictions() {
747 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
748 } else {
749 Err(ConstantEvaluatorError::RuntimeExpr)
750 }
751 }
752 }
753 }
754
755 const fn is_global_arena(&self) -> bool {
757 matches!(
758 self.behavior,
759 Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override)
760 | Behavior::Glsl(GlslRestrictions::Const)
761 )
762 }
763
764 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
765 match self.behavior {
766 Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data))
767 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
768 Some(function_local_data)
769 }
770 _ => None,
771 }
772 }
773
    /// Evaluates the const-kind expression `expr` and appends the result.
    ///
    /// Operands are first resolved through `check_and_get`, then the work is
    /// dispatched to the per-operation helpers (`access`, `swizzle`,
    /// `unary_op`, `binary_op`, `math`, `cast`, `array_length`). Expression
    /// variants that can never be constant map to their specific
    /// `ConstantEvaluatorError`. Unsupported ones (bitcasts, `select`,
    /// relational built-ins) return `NotImplemented`.
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {:?}", expr);
        match *expr {
            Expression::Constant(c) if self.is_global_arena() => {
                // "See through" the constant and use its initializer directly,
                // which avoids constants pointing at other constants.
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            // Leaves (and, inside functions, constant references) are appended as-is.
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // A dynamic index must itself evaluate to a constant index.
                self.access(base, self.constant_index(index)?, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // `convert: None` is a bitcast, which is not evaluated here.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
                "select built-in function".into(),
            )),
            Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
                format!("{fun:?} built-in function"),
            )),
            // WGSL rejects `arrayLength` in constants; GLSL evaluates `.length()`.
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // Everything below can never be a constant.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
        }
    }
897
898 fn splat(
911 &mut self,
912 value: Handle<Expression>,
913 size: crate::VectorSize,
914 span: Span,
915 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
916 match self.expressions[value] {
917 Expression::Literal(literal) => {
918 let scalar = literal.scalar();
919 let ty = self.types.insert(
920 Type {
921 name: None,
922 inner: TypeInner::Vector { size, scalar },
923 },
924 span,
925 );
926 let expr = Expression::Compose {
927 ty,
928 components: vec![value; size as usize],
929 };
930 self.register_evaluated_expr(expr, span)
931 }
932 Expression::ZeroValue(ty) => {
933 let inner = match self.types[ty].inner {
934 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
935 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
936 };
937 let res_ty = self.types.insert(Type { name: None, inner }, span);
938 let expr = Expression::ZeroValue(res_ty);
939 self.register_evaluated_expr(expr, span)
940 }
941 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
942 }
943 }
944
    /// Evaluates a swizzle of the constant vector `src_constant`, selecting the
    /// first `size` entries of `pattern`.
    ///
    /// A `ZeroValue` vector yields a `ZeroValue` of the `size`-component vector
    /// type, and a `Splat` yields a `Splat` of `size`. A `Compose` is flattened
    /// and its selected components are recomposed into the new vector type.
    ///
    /// # Errors
    ///
    /// - `SwizzleVectorOnly` if the source is not a vector of one of those forms;
    /// - `SwizzleOutOfBounds` if a selected component does not exist in a
    ///   `Compose` source.
    fn swizzle(
        &mut self,
        size: crate::VectorSize,
        span: Span,
        src_constant: Handle<Expression>,
        pattern: [crate::SwizzleComponent; 4],
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        // Vector type with the source's scalar type but the swizzle's `size`.
        let mut get_dst_ty = |ty| match self.types[ty].inner {
            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
                Type {
                    name: None,
                    inner: TypeInner::Vector { size, scalar },
                },
                span,
            )),
            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
        };

        match self.expressions[src_constant] {
            Expression::ZeroValue(ty) => {
                let dst_ty = get_dst_ty(ty)?;
                let expr = Expression::ZeroValue(dst_ty);
                self.register_evaluated_expr(expr, span)
            }
            // NOTE(review): the pattern is not bounds-checked against the
            // source splat's size here.
            Expression::Splat { value, .. } => {
                let expr = Expression::Splat { size, value };
                self.register_evaluated_expr(expr, span)
            }
            Expression::Compose { ty, ref components } => {
                let dst_ty = get_dst_ty(ty)?;

                // Copy the flattened components into a fixed buffer;
                // `src_constant` is only a placeholder for unused slots.
                let mut flattened = [src_constant; 4]; // dummy value
                let len =
                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
                        .zip(flattened.iter_mut())
                        .map(|(component, elt)| *elt = component)
                        .count();
                let flattened = &flattened[..len];

                let swizzled_components = pattern[..size as usize]
                    .iter()
                    .map(|&sc| {
                        let sc = sc as usize;
                        if let Some(elt) = flattened.get(sc) {
                            Ok(*elt)
                        } else {
                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
                        }
                    })
                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
                let expr = Expression::Compose {
                    ty: dst_ty,
                    components: swizzled_components,
                };
                self.register_evaluated_expr(expr, span)
            }
            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
        }
    }
1004
1005 fn math(
1006 &mut self,
1007 arg: Handle<Expression>,
1008 arg1: Option<Handle<Expression>>,
1009 arg2: Option<Handle<Expression>>,
1010 arg3: Option<Handle<Expression>>,
1011 fun: crate::MathFunction,
1012 span: Span,
1013 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1014 let expected = fun.argument_count();
1015 let given = Some(arg)
1016 .into_iter()
1017 .chain(arg1)
1018 .chain(arg2)
1019 .chain(arg3)
1020 .count();
1021 if expected != given {
1022 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1023 fun, expected, given,
1024 ));
1025 }
1026
1027 match fun {
1029 crate::MathFunction::Abs => {
1031 component_wise_scalar(self, span, [arg], |args| match args {
1032 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1033 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1034 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
1035 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1036 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1038 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1039 })
1040 }
1041 crate::MathFunction::Min => {
1042 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1043 Ok([e1.min(e2)])
1044 })
1045 }
1046 crate::MathFunction::Max => {
1047 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1048 Ok([e1.max(e2)])
1049 })
1050 }
1051 crate::MathFunction::Clamp => {
1052 component_wise_scalar!(
1053 self,
1054 span,
1055 [arg, arg1.unwrap(), arg2.unwrap()],
1056 |e, low, high| {
1057 if low > high {
1058 Err(ConstantEvaluatorError::InvalidClamp)
1059 } else {
1060 Ok([e.clamp(low, high)])
1061 }
1062 }
1063 )
1064 }
1065 crate::MathFunction::Saturate => {
1066 component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
1067 }
1068
1069 crate::MathFunction::Cos => {
1071 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1072 }
1073 crate::MathFunction::Cosh => {
1074 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1075 }
1076 crate::MathFunction::Sin => {
1077 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1078 }
1079 crate::MathFunction::Sinh => {
1080 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1081 }
1082 crate::MathFunction::Tan => {
1083 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1084 }
1085 crate::MathFunction::Tanh => {
1086 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1087 }
1088 crate::MathFunction::Acos => {
1089 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1090 }
1091 crate::MathFunction::Asin => {
1092 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1093 }
1094 crate::MathFunction::Atan => {
1095 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1096 }
1097 crate::MathFunction::Asinh => {
1098 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1099 }
1100 crate::MathFunction::Acosh => {
1101 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1102 }
1103 crate::MathFunction::Atanh => {
1104 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1105 }
1106 crate::MathFunction::Radians => {
1107 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1108 }
1109 crate::MathFunction::Degrees => {
1110 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1111 }
1112
1113 crate::MathFunction::Ceil => {
1115 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1116 }
1117 crate::MathFunction::Floor => {
1118 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1119 }
1120 crate::MathFunction::Round => {
1121 fn round_ties_even(x: f64) -> f64 {
1128 let i = x as i64;
1129 let f = (x - i as f64).abs();
1130 if f == 0.5 {
1131 if i & 1 == 1 {
1132 (x.abs() + 0.5).copysign(x)
1134 } else {
1135 (x.abs() - 0.5).copysign(x)
1136 }
1137 } else {
1138 x.round()
1139 }
1140 }
1141 component_wise_float(self, span, [arg], |e| match e {
1142 Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])),
1143 Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
1144 })
1145 }
1146 crate::MathFunction::Fract => {
1147 component_wise_float!(self, span, [arg], |e| {
1148 Ok([e - e.floor()])
1151 })
1152 }
1153 crate::MathFunction::Trunc => {
1154 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1155 }
1156
1157 crate::MathFunction::Exp => {
1159 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1160 }
1161 crate::MathFunction::Exp2 => {
1162 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1163 }
1164 crate::MathFunction::Log => {
1165 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1166 }
1167 crate::MathFunction::Log2 => {
1168 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1169 }
1170 crate::MathFunction::Pow => {
1171 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1172 Ok([e1.powf(e2)])
1173 })
1174 }
1175
1176 crate::MathFunction::Sign => {
1178 component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1179 }
1180 crate::MathFunction::Fma => {
1181 component_wise_float!(
1182 self,
1183 span,
1184 [arg, arg1.unwrap(), arg2.unwrap()],
1185 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1186 )
1187 }
1188 crate::MathFunction::Step => {
1189 component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
1190 Ok([if edge <= x { 1.0 } else { 0.0 }])
1191 })
1192 }
1193 crate::MathFunction::Sqrt => {
1194 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1195 }
1196 crate::MathFunction::InverseSqrt => {
1197 component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
1198 }
1199
1200 crate::MathFunction::CountTrailingZeros => {
1202 component_wise_concrete_int!(self, span, [arg], |e| {
1203 #[allow(clippy::useless_conversion)]
1204 Ok([e
1205 .trailing_zeros()
1206 .try_into()
1207 .expect("bit count overflowed 32 bits, somehow!?")])
1208 })
1209 }
1210 crate::MathFunction::CountLeadingZeros => {
1211 component_wise_concrete_int!(self, span, [arg], |e| {
1212 #[allow(clippy::useless_conversion)]
1213 Ok([e
1214 .leading_zeros()
1215 .try_into()
1216 .expect("bit count overflowed 32 bits, somehow!?")])
1217 })
1218 }
1219 crate::MathFunction::CountOneBits => {
1220 component_wise_concrete_int!(self, span, [arg], |e| {
1221 #[allow(clippy::useless_conversion)]
1222 Ok([e
1223 .count_ones()
1224 .try_into()
1225 .expect("bit count overflowed 32 bits, somehow!?")])
1226 })
1227 }
1228 crate::MathFunction::ReverseBits => {
1229 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1230 }
1231
1232 fun => Err(ConstantEvaluatorError::NotImplemented(format!(
1233 "{fun:?} built-in function"
1234 ))),
1235 }
1236 }
1237
1238 fn array_length(
1239 &mut self,
1240 array: Handle<Expression>,
1241 span: Span,
1242 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1243 match self.expressions[array] {
1244 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1245 match self.types[ty].inner {
1246 TypeInner::Array { size, .. } => match size {
1247 ArraySize::Constant(len) => {
1248 let expr = Expression::Literal(Literal::U32(len.get()));
1249 self.register_evaluated_expr(expr, span)
1250 }
1251 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1252 },
1253 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1254 }
1255 }
1256 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1257 }
1258 }
1259
1260 fn access(
1261 &mut self,
1262 base: Handle<Expression>,
1263 index: usize,
1264 span: Span,
1265 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1266 match self.expressions[base] {
1267 Expression::ZeroValue(ty) => {
1268 let ty_inner = &self.types[ty].inner;
1269 let components = ty_inner
1270 .components()
1271 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1272
1273 if index >= components as usize {
1274 Err(ConstantEvaluatorError::InvalidAccessBase)
1275 } else {
1276 let ty_res = ty_inner
1277 .component_type(index)
1278 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1279 let ty = match ty_res {
1280 crate::proc::TypeResolution::Handle(ty) => ty,
1281 crate::proc::TypeResolution::Value(inner) => {
1282 self.types.insert(Type { name: None, inner }, span)
1283 }
1284 };
1285 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1286 }
1287 }
1288 Expression::Splat { size, value } => {
1289 if index >= size as usize {
1290 Err(ConstantEvaluatorError::InvalidAccessBase)
1291 } else {
1292 Ok(value)
1293 }
1294 }
1295 Expression::Compose { ty, ref components } => {
1296 let _ = self.types[ty]
1297 .inner
1298 .components()
1299 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1300
1301 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1302 .nth(index)
1303 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1304 }
1305 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1306 }
1307 }
1308
1309 fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1310 match self.expressions[expr] {
1311 Expression::ZeroValue(ty)
1312 if matches!(
1313 self.types[ty].inner,
1314 TypeInner::Scalar(crate::Scalar {
1315 kind: ScalarKind::Uint,
1316 ..
1317 })
1318 ) =>
1319 {
1320 Ok(0)
1321 }
1322 Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1323 _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1324 }
1325 }
1326
1327 fn eval_zero_value_and_splat(
1334 &mut self,
1335 expr: Handle<Expression>,
1336 span: Span,
1337 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1338 match self.expressions[expr] {
1339 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1340 Expression::Splat { size, value } => self.splat(value, size, span),
1341 _ => Ok(expr),
1342 }
1343 }
1344
1345 fn eval_zero_value(
1351 &mut self,
1352 expr: Handle<Expression>,
1353 span: Span,
1354 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1355 match self.expressions[expr] {
1356 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1357 _ => Ok(expr),
1358 }
1359 }
1360
1361 fn eval_zero_value_impl(
1367 &mut self,
1368 ty: Handle<Type>,
1369 span: Span,
1370 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1371 match self.types[ty].inner {
1372 TypeInner::Scalar(scalar) => {
1373 let expr = Expression::Literal(
1374 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1375 );
1376 self.register_evaluated_expr(expr, span)
1377 }
1378 TypeInner::Vector { size, scalar } => {
1379 let scalar_ty = self.types.insert(
1380 Type {
1381 name: None,
1382 inner: TypeInner::Scalar(scalar),
1383 },
1384 span,
1385 );
1386 let el = self.eval_zero_value_impl(scalar_ty, span)?;
1387 let expr = Expression::Compose {
1388 ty,
1389 components: vec![el; size as usize],
1390 };
1391 self.register_evaluated_expr(expr, span)
1392 }
1393 TypeInner::Matrix {
1394 columns,
1395 rows,
1396 scalar,
1397 } => {
1398 let vec_ty = self.types.insert(
1399 Type {
1400 name: None,
1401 inner: TypeInner::Vector { size: rows, scalar },
1402 },
1403 span,
1404 );
1405 let el = self.eval_zero_value_impl(vec_ty, span)?;
1406 let expr = Expression::Compose {
1407 ty,
1408 components: vec![el; columns as usize],
1409 };
1410 self.register_evaluated_expr(expr, span)
1411 }
1412 TypeInner::Array {
1413 base,
1414 size: ArraySize::Constant(size),
1415 ..
1416 } => {
1417 let el = self.eval_zero_value_impl(base, span)?;
1418 let expr = Expression::Compose {
1419 ty,
1420 components: vec![el; size.get() as usize],
1421 };
1422 self.register_evaluated_expr(expr, span)
1423 }
1424 TypeInner::Struct { ref members, .. } => {
1425 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1426 let mut components = Vec::with_capacity(members.len());
1427 for ty in types {
1428 components.push(self.eval_zero_value_impl(ty, span)?);
1429 }
1430 let expr = Expression::Compose { ty, components };
1431 self.register_evaluated_expr(expr, span)
1432 }
1433 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1434 }
1435 }
1436
    /// Converts the constant expression `expr` to the scalar type `target`.
    ///
    /// `expr` is first passed through `eval_zero_value`, so a `ZeroValue` operand
    /// is expanded before conversion. Then, depending on the expression:
    ///
    /// - `Literal`: converted directly according to the table below.
    /// - `Compose` of a vector or matrix type: each component is cast
    ///   recursively, and a new vector/matrix type with scalar `target` is
    ///   registered for the result.
    /// - `Splat`: the splatted value is cast. The result is still a `Splat`.
    ///
    /// Any other expression, or an unsupported source/target pairing, produces
    /// `InvalidCastArg`. Conversions from `AbstractInt`/`AbstractFloat` go
    /// through `TryFromAbstract`, so they can also fail with that trait's
    /// conversion errors.
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        let expr = self.eval_zero_value(expr, span)?;

        // Builds the `InvalidCastArg` error describing the (expanded) operand
        // and the requested target type.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            // With the WGSL frontend compiled in, spell the target in WGSL syntax.
            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                // NOTE: Rust float-to-int `as` casts below saturate at the target
                // range (NaN becomes 0); integer-to-integer `as` casts wrap.
                let literal = match target {
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        Literal::U32(v) => v as i32,
                        Literal::F32(v) => v as i32,
                        Literal::Bool(v) => v as i32,
                        // 64-bit sources are not converted to 32-bit integers here.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v as u32,
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v as i64,
                        Literal::I64(v) => v,
                        Literal::U64(v) => v as i64,
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v as u64,
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        // `bool` cannot be cast to a float directly; go through `u32`.
                        Literal::Bool(v) => v as u32 as f32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    Sc::BOOL => Literal::Bool(match literal {
                        // Non-zero numbers convert to `true`.
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::Bool(v) => v,
                        Literal::F64(_)
                        | Literal::I64(_)
                        | Literal::U64(_)
                        | Literal::AbstractInt(_)
                        | Literal::AbstractFloat(_) => {
                            return make_error();
                        }
                    }),
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // Cannot overflow: every `i64` magnitude is far below
                            // `f64::MAX`; the conversion may only round.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Only vectors and matrices are cast component-wise; the result
                // keeps the original shape with the new scalar type.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // Cast the splatted scalar, keeping the span of that sub-expression.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
1606
1607 pub fn cast_array(
1620 &mut self,
1621 expr: Handle<Expression>,
1622 target: crate::Scalar,
1623 span: Span,
1624 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1625 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1626 return self.cast(expr, target, span);
1627 };
1628
1629 let TypeInner::Array {
1630 base: _,
1631 size,
1632 stride: _,
1633 } = self.types[ty].inner
1634 else {
1635 return self.cast(expr, target, span);
1636 };
1637
1638 let mut components = components.clone();
1639 for component in &mut components {
1640 *component = self.cast_array(*component, target, span)?;
1641 }
1642
1643 let first = components.first().unwrap();
1644 let new_base = match self.resolve_type(*first)? {
1645 crate::proc::TypeResolution::Handle(ty) => ty,
1646 crate::proc::TypeResolution::Value(inner) => {
1647 self.types.insert(Type { name: None, inner }, span)
1648 }
1649 };
1650 let new_base_stride = self.types[new_base].inner.size(self.to_ctx());
1651 let new_array_ty = self.types.insert(
1652 Type {
1653 name: None,
1654 inner: TypeInner::Array {
1655 base: new_base,
1656 size,
1657 stride: new_base_stride,
1658 },
1659 },
1660 span,
1661 );
1662
1663 let compose = Expression::Compose {
1664 ty: new_array_ty,
1665 components,
1666 };
1667 self.register_evaluated_expr(compose, span)
1668 }
1669
1670 fn unary_op(
1671 &mut self,
1672 op: UnaryOperator,
1673 expr: Handle<Expression>,
1674 span: Span,
1675 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1676 let expr = self.eval_zero_value_and_splat(expr, span)?;
1677
1678 let expr = match self.expressions[expr] {
1679 Expression::Literal(value) => Expression::Literal(match op {
1680 UnaryOperator::Negate => match value {
1681 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
1682 Literal::F32(v) => Literal::F32(-v),
1683 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
1684 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
1685 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1686 },
1687 UnaryOperator::LogicalNot => match value {
1688 Literal::Bool(v) => Literal::Bool(!v),
1689 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1690 },
1691 UnaryOperator::BitwiseNot => match value {
1692 Literal::I32(v) => Literal::I32(!v),
1693 Literal::U32(v) => Literal::U32(!v),
1694 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
1695 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1696 },
1697 }),
1698 Expression::Compose {
1699 ty,
1700 components: ref src_components,
1701 } => {
1702 match self.types[ty].inner {
1703 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
1704 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1705 }
1706
1707 let mut components = src_components.clone();
1708 for component in &mut components {
1709 *component = self.unary_op(op, *component, span)?;
1710 }
1711
1712 Expression::Compose { ty, components }
1713 }
1714 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1715 };
1716
1717 self.register_evaluated_expr(expr, span)
1718 }
1719
    /// Evaluates the binary operator `op` on the constants `left` and `right`.
    ///
    /// Both operands have their `ZeroValue`/`Splat` forms expanded first. The
    /// supported operand shapes are:
    ///
    /// - literal with literal: folded directly (see the per-type arms below);
    /// - composition with literal, or literal with composition: `op` is applied
    ///   recursively between each component and the literal, and the
    ///   composition's type is kept;
    /// - vector with vector of the same size: both are flattened with
    ///   `crate::proc::flatten_compose`, paired up, and handed to
    ///   `binary_op_vector`.
    ///
    /// Every other combination yields `InvalidBinaryOpArgs`.
    fn binary_op(
        &mut self,
        op: BinaryOperator,
        left: Handle<Expression>,
        right: Handle<Expression>,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        let left = self.eval_zero_value_and_splat(left, span)?;
        let right = self.eval_zero_value_and_splat(right, span)?;

        let expr = match (&self.expressions[left], &self.expressions[right]) {
            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
                let literal = match op {
                    // Comparisons compare the `Literal` values directly.
                    // NOTE(review): this relies on `Literal`'s `PartialEq`/`PartialOrd`;
                    // operands of different literal kinds are presumably rejected
                    // before reaching here — confirm.
                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),

                    // All other operators dispatch on the pair of literal kinds.
                    // 64-bit (`I64`/`U64`/`F64`) operands are not handled and fall
                    // through to `InvalidBinaryOpArgs`.
                    _ => match (left_value, right_value) {
                        // Integer arithmetic uses checked operations: overflow is
                        // an error, as is division/remainder by zero.
                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("addition".into())
                            })?,
                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("subtraction".into())
                            })?,
                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("multiplication".into())
                            })?,
                            // For signed types `checked_div`/`checked_rem` also fail on
                            // `MIN / -1`, which is reported as overflow.
                            BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::DivisionByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("division".into())
                                }
                            })?,
                            BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                if b == 0 {
                                    ConstantEvaluatorError::RemainderByZero
                                } else {
                                    ConstantEvaluatorError::Overflow("remainder".into())
                                }
                            })?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Signed shifts take an unsigned shift amount; `checked_shl`/
                        // `checked_shr` fail when the amount is >= 32 bits.
                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
                            BinaryOperator::ShiftLeft => a
                                .checked_shl(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("addition".into())
                            })?,
                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("subtraction".into())
                            })?,
                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                ConstantEvaluatorError::Overflow("multiplication".into())
                            })?,
                            // Unsigned division/remainder can only fail on a zero divisor.
                            BinaryOperator::Divide => a
                                .checked_div(b)
                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
                            BinaryOperator::Modulo => a
                                .checked_rem(b)
                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
                            BinaryOperator::And => a & b,
                            BinaryOperator::ExclusiveOr => a ^ b,
                            BinaryOperator::InclusiveOr => a | b,
                            BinaryOperator::ShiftLeft => a
                                .checked_shl(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            BinaryOperator::ShiftRight => a
                                .checked_shr(b)
                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        // Float arithmetic is plain IEEE arithmetic. Non-finite results
                        // are presumably rejected later by `register_evaluated_expr`'s
                        // literal check — confirm against `check_literal_value`.
                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
                            BinaryOperator::Add => a + b,
                            BinaryOperator::Subtract => a - b,
                            BinaryOperator::Multiply => a * b,
                            BinaryOperator::Divide => a / b,
                            BinaryOperator::Modulo => a % b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
                            Literal::AbstractInt(match op {
                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("addition".into())
                                })?,
                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("subtraction".into())
                                })?,
                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
                                    ConstantEvaluatorError::Overflow("multiplication".into())
                                })?,
                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::DivisionByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("division".into())
                                    }
                                })?,
                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
                                    if b == 0 {
                                        ConstantEvaluatorError::RemainderByZero
                                    } else {
                                        ConstantEvaluatorError::Overflow("remainder".into())
                                    }
                                })?,
                                BinaryOperator::And => a & b,
                                BinaryOperator::ExclusiveOr => a ^ b,
                                BinaryOperator::InclusiveOr => a | b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
                            Literal::AbstractFloat(match op {
                                BinaryOperator::Add => a + b,
                                BinaryOperator::Subtract => a - b,
                                BinaryOperator::Multiply => a * b,
                                BinaryOperator::Divide => a / b,
                                BinaryOperator::Modulo => a % b,
                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                            })
                        }
                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
                            BinaryOperator::LogicalAnd => a && b,
                            BinaryOperator::LogicalOr => a || b,
                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                        }),
                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                    },
                };
                Expression::Literal(literal)
            }
            // Composition with scalar: apply `op` between each component and the scalar.
            (
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
                &Expression::Literal(_),
            ) => {
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, *component, right, span)?;
                }
                Expression::Compose { ty, components }
            }
            // Scalar with composition: same as above, with the operand order preserved.
            (
                &Expression::Literal(_),
                &Expression::Compose {
                    components: ref src_components,
                    ty,
                },
            ) => {
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.binary_op(op, left, *component, span)?;
                }
                Expression::Compose { ty, components }
            }
            // Composition with composition: only same-size vectors are supported.
            (
                &Expression::Compose {
                    components: ref left_components,
                    ty: left_ty,
                },
                &Expression::Compose {
                    components: ref right_components,
                    ty: right_ty,
                },
            ) => {
                // Flatten both sides so that nested compositions (e.g. a vec4 built
                // from two vec2s) line up element by element.
                let left_flattened = crate::proc::flatten_compose(
                    left_ty,
                    left_components,
                    self.expressions,
                    self.types,
                );
                let right_flattened = crate::proc::flatten_compose(
                    right_ty,
                    right_components,
                    self.expressions,
                    self.types,
                );

                // Collect the pairs eagerly: `binary_op_vector` needs `&mut self`, which
                // cannot coexist with the iterators' borrows of the arenas.
                let mut flattened = Vec::with_capacity(left_components.len());
                flattened.extend(left_flattened.zip(right_flattened));

                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
                    (
                        &TypeInner::Vector {
                            size: left_size, ..
                        },
                        &TypeInner::Vector {
                            size: right_size, ..
                        },
                    ) if left_size == right_size => {
                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
                    }
                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
                }
            }
            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        };

        self.register_evaluated_expr(expr, span)
    }
1941
1942 fn binary_op_vector(
1943 &mut self,
1944 op: BinaryOperator,
1945 size: crate::VectorSize,
1946 components: &[(Handle<Expression>, Handle<Expression>)],
1947 left_ty: Handle<Type>,
1948 span: Span,
1949 ) -> Result<Expression, ConstantEvaluatorError> {
1950 let ty = match op {
1951 BinaryOperator::Equal
1953 | BinaryOperator::NotEqual
1954 | BinaryOperator::Less
1955 | BinaryOperator::LessEqual
1956 | BinaryOperator::Greater
1957 | BinaryOperator::GreaterEqual => self.types.insert(
1958 Type {
1959 name: None,
1960 inner: TypeInner::Vector {
1961 size,
1962 scalar: crate::Scalar::BOOL,
1963 },
1964 },
1965 span,
1966 ),
1967
1968 BinaryOperator::Add
1971 | BinaryOperator::Subtract
1972 | BinaryOperator::Multiply
1973 | BinaryOperator::Divide
1974 | BinaryOperator::Modulo
1975 | BinaryOperator::And
1976 | BinaryOperator::ExclusiveOr
1977 | BinaryOperator::InclusiveOr
1978 | BinaryOperator::LogicalAnd
1979 | BinaryOperator::LogicalOr
1980 | BinaryOperator::ShiftLeft
1981 | BinaryOperator::ShiftRight => left_ty,
1982 };
1983
1984 let components = components
1985 .iter()
1986 .map(|&(left, right)| self.binary_op(op, left, right, span))
1987 .collect::<Result<Vec<_>, _>>()?;
1988
1989 Ok(Expression::Compose { ty, components })
1990 }
1991
1992 fn copy_from(
2000 &mut self,
2001 expr: Handle<Expression>,
2002 expressions: &Arena<Expression>,
2003 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2004 let span = expressions.get_span(expr);
2005 match expressions[expr] {
2006 ref expr @ (Expression::Literal(_)
2007 | Expression::Constant(_)
2008 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2009 Expression::Compose { ty, ref components } => {
2010 let mut components = components.clone();
2011 for component in &mut components {
2012 *component = self.copy_from(*component, expressions)?;
2013 }
2014 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2015 }
2016 Expression::Splat { size, value } => {
2017 let value = self.copy_from(value, expressions)?;
2018 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2019 }
2020 _ => {
2021 log::debug!("copy_from: SubexpressionsAreNotConstant");
2022 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2023 }
2024 }
2025 }
2026
2027 fn register_evaluated_expr(
2028 &mut self,
2029 expr: Expression,
2030 span: Span,
2031 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2032 if let Expression::Literal(literal) = expr {
2036 crate::valid::check_literal_value(literal)?;
2037 }
2038
2039 Ok(self.append_expr(expr, span, ExpressionKind::Const))
2040 }
2041
    /// Appends `expr` to the expression arena and records its kind `expr_type`
    /// in the expression-kind tracker, returning the new handle.
    ///
    /// When evaluating inside a function body (the `Runtime` restrictions for
    /// WGSL or GLSL) while that function's emitter is running, an expression for
    /// which `needs_pre_emit()` is true is kept out of the current emit range:
    /// 1. the emitter is finished and its output is pushed onto the block;
    /// 2. the expression is appended;
    /// 3. the emitter is restarted after it.
    ///
    /// In all other cases the expression is simply appended.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data))
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the current emit range (presumably yielding an `Emit`
                    // statement) so the new expression falls outside it.
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    // Reopen the range so later expressions are emitted as before.
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
2069
2070 fn resolve_type(
2071 &self,
2072 expr: Handle<Expression>,
2073 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2074 use crate::proc::TypeResolution as Tr;
2075 use crate::Expression as Ex;
2076 let resolution = match self.expressions[expr] {
2077 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2078 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2079 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2080 Ex::Splat { size, value } => {
2081 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2082 return Err(ConstantEvaluatorError::SplatScalarOnly);
2083 };
2084 Tr::Value(TypeInner::Vector { scalar, size })
2085 }
2086 _ => {
2087 log::debug!("resolve_type: SubexpressionsAreNotConstant");
2088 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2089 }
2090 };
2091
2092 Ok(resolution)
2093 }
2094}
2095
/// Conversion of an abstract literal value into a concrete scalar type.
///
/// `T` is the Rust representation of the abstract value: `i64` for
/// `AbstractInt` and `f64` for `AbstractFloat`.
trait TryFromAbstract<T>: Sized {
    /// Converts the abstract `value` to `Self`.
    ///
    /// The implementations in this module behave as follows:
    ///
    /// - `i64` to an integer type: succeeds when the value fits, otherwise
    ///   fails with `AutomaticConversionLossy`.
    /// - `i64` to a float type: always succeeds, possibly rounding.
    /// - `f64` to `f32`: rounding is accepted, but a result that overflows to
    ///   infinity fails with `AutomaticConversionLossy`.
    /// - `f64` to an integer type: always fails with
    ///   `AutomaticConversionFloatToInt`.
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
2115
2116impl TryFromAbstract<i64> for i32 {
2117 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2118 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2119 value: format!("{value:?}"),
2120 to_type: "i32",
2121 })
2122 }
2123}
2124
2125impl TryFromAbstract<i64> for u32 {
2126 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2127 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2128 value: format!("{value:?}"),
2129 to_type: "u32",
2130 })
2131 }
2132}
2133
2134impl TryFromAbstract<i64> for u64 {
2135 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2136 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2137 value: format!("{value:?}"),
2138 to_type: "u64",
2139 })
2140 }
2141}
2142
2143impl TryFromAbstract<i64> for i64 {
2144 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2145 Ok(value)
2146 }
2147}
2148
2149impl TryFromAbstract<i64> for f32 {
2150 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2151 let f = value as f32;
2152 Ok(f)
2156 }
2157}
2158
2159impl TryFromAbstract<f64> for f32 {
2160 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2161 let f = value as f32;
2162 if f.is_infinite() {
2163 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2164 value: format!("{value:?}"),
2165 to_type: "f32",
2166 });
2167 }
2168 Ok(f)
2169 }
2170}
2171
2172impl TryFromAbstract<i64> for f64 {
2173 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2174 let f = value as f64;
2175 Ok(f)
2179 }
2180}
2181
2182impl TryFromAbstract<f64> for f64 {
2183 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2184 Ok(value)
2185 }
2186}
2187
2188impl TryFromAbstract<f64> for i32 {
2189 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2190 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" })
2191 }
2192}
2193
2194impl TryFromAbstract<f64> for u32 {
2195 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2196 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" })
2197 }
2198}
2199
2200impl TryFromAbstract<f64> for i64 {
2201 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2202 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i64" })
2203 }
2204}
2205
2206impl TryFromAbstract<f64> for u64 {
2207 fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
2208 Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u64" })
2209 }
2210}
2211
2212#[cfg(test)]
2213mod tests {
2214 use std::vec;
2215
2216 use crate::{
2217 Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
2218 UniqueArena, VectorSize,
2219 };
2220
2221 use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
2222
2223 #[test]
2224 fn unary_op() {
2225 let mut types = UniqueArena::new();
2226 let mut constants = Arena::new();
2227 let overrides = Arena::new();
2228 let mut global_expressions = Arena::new();
2229
2230 let scalar_ty = types.insert(
2231 Type {
2232 name: None,
2233 inner: TypeInner::Scalar(crate::Scalar::I32),
2234 },
2235 Default::default(),
2236 );
2237
2238 let vec_ty = types.insert(
2239 Type {
2240 name: None,
2241 inner: TypeInner::Vector {
2242 size: VectorSize::Bi,
2243 scalar: crate::Scalar::I32,
2244 },
2245 },
2246 Default::default(),
2247 );
2248
2249 let h = constants.append(
2250 Constant {
2251 name: None,
2252 ty: scalar_ty,
2253 init: global_expressions
2254 .append(Expression::Literal(Literal::I32(4)), Default::default()),
2255 },
2256 Default::default(),
2257 );
2258
2259 let h1 = constants.append(
2260 Constant {
2261 name: None,
2262 ty: scalar_ty,
2263 init: global_expressions
2264 .append(Expression::Literal(Literal::I32(8)), Default::default()),
2265 },
2266 Default::default(),
2267 );
2268
2269 let vec_h = constants.append(
2270 Constant {
2271 name: None,
2272 ty: vec_ty,
2273 init: global_expressions.append(
2274 Expression::Compose {
2275 ty: vec_ty,
2276 components: vec![constants[h].init, constants[h1].init],
2277 },
2278 Default::default(),
2279 ),
2280 },
2281 Default::default(),
2282 );
2283
2284 let expr = global_expressions.append(Expression::Constant(h), Default::default());
2285 let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
2286
2287 let expr2 = Expression::Unary {
2288 op: UnaryOperator::Negate,
2289 expr,
2290 };
2291
2292 let expr3 = Expression::Unary {
2293 op: UnaryOperator::BitwiseNot,
2294 expr,
2295 };
2296
2297 let expr4 = Expression::Unary {
2298 op: UnaryOperator::BitwiseNot,
2299 expr: expr1,
2300 };
2301
2302 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
2303 let mut solver = ConstantEvaluator {
2304 behavior: Behavior::Wgsl(WgslRestrictions::Const),
2305 types: &mut types,
2306 constants: &constants,
2307 overrides: &overrides,
2308 expressions: &mut global_expressions,
2309 expression_kind_tracker,
2310 };
2311
2312 let res1 = solver
2313 .try_eval_and_append(expr2, Default::default())
2314 .unwrap();
2315 let res2 = solver
2316 .try_eval_and_append(expr3, Default::default())
2317 .unwrap();
2318 let res3 = solver
2319 .try_eval_and_append(expr4, Default::default())
2320 .unwrap();
2321
2322 assert_eq!(
2323 global_expressions[res1],
2324 Expression::Literal(Literal::I32(-4))
2325 );
2326
2327 assert_eq!(
2328 global_expressions[res2],
2329 Expression::Literal(Literal::I32(!4))
2330 );
2331
2332 let res3_inner = &global_expressions[res3];
2333
2334 match *res3_inner {
2335 Expression::Compose {
2336 ref ty,
2337 ref components,
2338 } => {
2339 assert_eq!(*ty, vec_ty);
2340 let mut components_iter = components.iter().copied();
2341 assert_eq!(
2342 global_expressions[components_iter.next().unwrap()],
2343 Expression::Literal(Literal::I32(!4))
2344 );
2345 assert_eq!(
2346 global_expressions[components_iter.next().unwrap()],
2347 Expression::Literal(Literal::I32(!8))
2348 );
2349 assert!(components_iter.next().is_none());
2350 }
2351 _ => panic!("Expected vector"),
2352 }
2353 }
2354
2355 #[test]
2356 fn cast() {
2357 let mut types = UniqueArena::new();
2358 let mut constants = Arena::new();
2359 let overrides = Arena::new();
2360 let mut global_expressions = Arena::new();
2361
2362 let scalar_ty = types.insert(
2363 Type {
2364 name: None,
2365 inner: TypeInner::Scalar(crate::Scalar::I32),
2366 },
2367 Default::default(),
2368 );
2369
2370 let h = constants.append(
2371 Constant {
2372 name: None,
2373 ty: scalar_ty,
2374 init: global_expressions
2375 .append(Expression::Literal(Literal::I32(4)), Default::default()),
2376 },
2377 Default::default(),
2378 );
2379
2380 let expr = global_expressions.append(Expression::Constant(h), Default::default());
2381
2382 let root = Expression::As {
2383 expr,
2384 kind: ScalarKind::Bool,
2385 convert: Some(crate::BOOL_WIDTH),
2386 };
2387
2388 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
2389 let mut solver = ConstantEvaluator {
2390 behavior: Behavior::Wgsl(WgslRestrictions::Const),
2391 types: &mut types,
2392 constants: &constants,
2393 overrides: &overrides,
2394 expressions: &mut global_expressions,
2395 expression_kind_tracker,
2396 };
2397
2398 let res = solver
2399 .try_eval_and_append(root, Default::default())
2400 .unwrap();
2401
2402 assert_eq!(
2403 global_expressions[res],
2404 Expression::Literal(Literal::Bool(true))
2405 );
2406 }
2407
2408 #[test]
2409 fn access() {
2410 let mut types = UniqueArena::new();
2411 let mut constants = Arena::new();
2412 let overrides = Arena::new();
2413 let mut global_expressions = Arena::new();
2414
2415 let matrix_ty = types.insert(
2416 Type {
2417 name: None,
2418 inner: TypeInner::Matrix {
2419 columns: VectorSize::Bi,
2420 rows: VectorSize::Tri,
2421 scalar: crate::Scalar::F32,
2422 },
2423 },
2424 Default::default(),
2425 );
2426
2427 let vec_ty = types.insert(
2428 Type {
2429 name: None,
2430 inner: TypeInner::Vector {
2431 size: VectorSize::Tri,
2432 scalar: crate::Scalar::F32,
2433 },
2434 },
2435 Default::default(),
2436 );
2437
2438 let mut vec1_components = Vec::with_capacity(3);
2439 let mut vec2_components = Vec::with_capacity(3);
2440
2441 for i in 0..3 {
2442 let h = global_expressions.append(
2443 Expression::Literal(Literal::F32(i as f32)),
2444 Default::default(),
2445 );
2446
2447 vec1_components.push(h)
2448 }
2449
2450 for i in 3..6 {
2451 let h = global_expressions.append(
2452 Expression::Literal(Literal::F32(i as f32)),
2453 Default::default(),
2454 );
2455
2456 vec2_components.push(h)
2457 }
2458
2459 let vec1 = constants.append(
2460 Constant {
2461 name: None,
2462 ty: vec_ty,
2463 init: global_expressions.append(
2464 Expression::Compose {
2465 ty: vec_ty,
2466 components: vec1_components,
2467 },
2468 Default::default(),
2469 ),
2470 },
2471 Default::default(),
2472 );
2473
2474 let vec2 = constants.append(
2475 Constant {
2476 name: None,
2477 ty: vec_ty,
2478 init: global_expressions.append(
2479 Expression::Compose {
2480 ty: vec_ty,
2481 components: vec2_components,
2482 },
2483 Default::default(),
2484 ),
2485 },
2486 Default::default(),
2487 );
2488
2489 let h = constants.append(
2490 Constant {
2491 name: None,
2492 ty: matrix_ty,
2493 init: global_expressions.append(
2494 Expression::Compose {
2495 ty: matrix_ty,
2496 components: vec![constants[vec1].init, constants[vec2].init],
2497 },
2498 Default::default(),
2499 ),
2500 },
2501 Default::default(),
2502 );
2503
2504 let base = global_expressions.append(Expression::Constant(h), Default::default());
2505
2506 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
2507 let mut solver = ConstantEvaluator {
2508 behavior: Behavior::Wgsl(WgslRestrictions::Const),
2509 types: &mut types,
2510 constants: &constants,
2511 overrides: &overrides,
2512 expressions: &mut global_expressions,
2513 expression_kind_tracker,
2514 };
2515
2516 let root1 = Expression::AccessIndex { base, index: 1 };
2517
2518 let res1 = solver
2519 .try_eval_and_append(root1, Default::default())
2520 .unwrap();
2521
2522 let root2 = Expression::AccessIndex {
2523 base: res1,
2524 index: 2,
2525 };
2526
2527 let res2 = solver
2528 .try_eval_and_append(root2, Default::default())
2529 .unwrap();
2530
2531 match global_expressions[res1] {
2532 Expression::Compose {
2533 ref ty,
2534 ref components,
2535 } => {
2536 assert_eq!(*ty, vec_ty);
2537 let mut components_iter = components.iter().copied();
2538 assert_eq!(
2539 global_expressions[components_iter.next().unwrap()],
2540 Expression::Literal(Literal::F32(3.))
2541 );
2542 assert_eq!(
2543 global_expressions[components_iter.next().unwrap()],
2544 Expression::Literal(Literal::F32(4.))
2545 );
2546 assert_eq!(
2547 global_expressions[components_iter.next().unwrap()],
2548 Expression::Literal(Literal::F32(5.))
2549 );
2550 assert!(components_iter.next().is_none());
2551 }
2552 _ => panic!("Expected vector"),
2553 }
2554
2555 assert_eq!(
2556 global_expressions[res2],
2557 Expression::Literal(Literal::F32(5.))
2558 );
2559 }
2560
2561 #[test]
2562 fn compose_of_constants() {
2563 let mut types = UniqueArena::new();
2564 let mut constants = Arena::new();
2565 let overrides = Arena::new();
2566 let mut global_expressions = Arena::new();
2567
2568 let i32_ty = types.insert(
2569 Type {
2570 name: None,
2571 inner: TypeInner::Scalar(crate::Scalar::I32),
2572 },
2573 Default::default(),
2574 );
2575
2576 let vec2_i32_ty = types.insert(
2577 Type {
2578 name: None,
2579 inner: TypeInner::Vector {
2580 size: VectorSize::Bi,
2581 scalar: crate::Scalar::I32,
2582 },
2583 },
2584 Default::default(),
2585 );
2586
2587 let h = constants.append(
2588 Constant {
2589 name: None,
2590 ty: i32_ty,
2591 init: global_expressions
2592 .append(Expression::Literal(Literal::I32(4)), Default::default()),
2593 },
2594 Default::default(),
2595 );
2596
2597 let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
2598
2599 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
2600 let mut solver = ConstantEvaluator {
2601 behavior: Behavior::Wgsl(WgslRestrictions::Const),
2602 types: &mut types,
2603 constants: &constants,
2604 overrides: &overrides,
2605 expressions: &mut global_expressions,
2606 expression_kind_tracker,
2607 };
2608
2609 let solved_compose = solver
2610 .try_eval_and_append(
2611 Expression::Compose {
2612 ty: vec2_i32_ty,
2613 components: vec![h_expr, h_expr],
2614 },
2615 Default::default(),
2616 )
2617 .unwrap();
2618 let solved_negate = solver
2619 .try_eval_and_append(
2620 Expression::Unary {
2621 op: UnaryOperator::Negate,
2622 expr: solved_compose,
2623 },
2624 Default::default(),
2625 )
2626 .unwrap();
2627
2628 let pass = match global_expressions[solved_negate] {
2629 Expression::Compose { ty, ref components } => {
2630 ty == vec2_i32_ty
2631 && components.iter().all(|&component| {
2632 let component = &global_expressions[component];
2633 matches!(*component, Expression::Literal(Literal::I32(-4)))
2634 })
2635 }
2636 _ => false,
2637 };
2638 if !pass {
2639 panic!("unexpected evaluation result")
2640 }
2641 }
2642
2643 #[test]
2644 fn splat_of_constant() {
2645 let mut types = UniqueArena::new();
2646 let mut constants = Arena::new();
2647 let overrides = Arena::new();
2648 let mut global_expressions = Arena::new();
2649
2650 let i32_ty = types.insert(
2651 Type {
2652 name: None,
2653 inner: TypeInner::Scalar(crate::Scalar::I32),
2654 },
2655 Default::default(),
2656 );
2657
2658 let vec2_i32_ty = types.insert(
2659 Type {
2660 name: None,
2661 inner: TypeInner::Vector {
2662 size: VectorSize::Bi,
2663 scalar: crate::Scalar::I32,
2664 },
2665 },
2666 Default::default(),
2667 );
2668
2669 let h = constants.append(
2670 Constant {
2671 name: None,
2672 ty: i32_ty,
2673 init: global_expressions
2674 .append(Expression::Literal(Literal::I32(4)), Default::default()),
2675 },
2676 Default::default(),
2677 );
2678
2679 let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
2680
2681 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
2682 let mut solver = ConstantEvaluator {
2683 behavior: Behavior::Wgsl(WgslRestrictions::Const),
2684 types: &mut types,
2685 constants: &constants,
2686 overrides: &overrides,
2687 expressions: &mut global_expressions,
2688 expression_kind_tracker,
2689 };
2690
2691 let solved_compose = solver
2692 .try_eval_and_append(
2693 Expression::Splat {
2694 size: VectorSize::Bi,
2695 value: h_expr,
2696 },
2697 Default::default(),
2698 )
2699 .unwrap();
2700 let solved_negate = solver
2701 .try_eval_and_append(
2702 Expression::Unary {
2703 op: UnaryOperator::Negate,
2704 expr: solved_compose,
2705 },
2706 Default::default(),
2707 )
2708 .unwrap();
2709
2710 let pass = match global_expressions[solved_negate] {
2711 Expression::Compose { ty, ref components } => {
2712 ty == vec2_i32_ty
2713 && components.iter().all(|&component| {
2714 let component = &global_expressions[component];
2715 matches!(*component, Expression::Literal(Literal::I32(-4)))
2716 })
2717 }
2718 _ => false,
2719 };
2720 if !pass {
2721 panic!("unexpected evaluation result")
2722 }
2723 }
2724}