1use super::{
2 compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ModuleInfo,
3 ShaderStages, TypeFlags,
4};
5use crate::arena::UniqueArena;
6
7use crate::{
8 arena::Handle,
9 proc::{IndexableLengthError, ResolveError},
10};
11
12#[derive(Clone, Debug, thiserror::Error)]
13#[cfg_attr(test, derive(PartialEq))]
14pub enum ExpressionError {
15 #[error("Doesn't exist")]
16 DoesntExist,
17 #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
18 NotInScope,
19 #[error("Base type {0:?} is not compatible with this expression")]
20 InvalidBaseType(Handle<crate::Expression>),
21 #[error("Accessing with index {0:?} can't be done")]
22 InvalidIndexType(Handle<crate::Expression>),
23 #[error("Accessing {0:?} via a negative index is invalid")]
24 NegativeIndex(Handle<crate::Expression>),
25 #[error("Accessing index {1} is out of {0:?} bounds")]
26 IndexOutOfBounds(Handle<crate::Expression>, u32),
27 #[error("The expression {0:?} may only be indexed by a constant")]
28 IndexMustBeConstant(Handle<crate::Expression>),
29 #[error("Function argument {0:?} doesn't exist")]
30 FunctionArgumentDoesntExist(u32),
31 #[error("Loading of {0:?} can't be done")]
32 InvalidPointerType(Handle<crate::Expression>),
33 #[error("Array length of {0:?} can't be done")]
34 InvalidArrayType(Handle<crate::Expression>),
35 #[error("Get intersection of {0:?} can't be done")]
36 InvalidRayQueryType(Handle<crate::Expression>),
37 #[error("Splatting {0:?} can't be done")]
38 InvalidSplatType(Handle<crate::Expression>),
39 #[error("Swizzling {0:?} can't be done")]
40 InvalidVectorType(Handle<crate::Expression>),
41 #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
42 InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
43 #[error(transparent)]
44 Compose(#[from] super::ComposeError),
45 #[error(transparent)]
46 IndexableLength(#[from] IndexableLengthError),
47 #[error("Operation {0:?} can't work with {1:?}")]
48 InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
49 #[error("Operation {0:?} can't work with {1:?} and {2:?}")]
50 InvalidBinaryOperandTypes(
51 crate::BinaryOperator,
52 Handle<crate::Expression>,
53 Handle<crate::Expression>,
54 ),
55 #[error("Selecting is not possible")]
56 InvalidSelectTypes,
57 #[error("Relational argument {0:?} is not a boolean vector")]
58 InvalidBooleanVector(Handle<crate::Expression>),
59 #[error("Relational argument {0:?} is not a float")]
60 InvalidFloatArgument(Handle<crate::Expression>),
61 #[error("Type resolution failed")]
62 Type(#[from] ResolveError),
63 #[error("Not a global variable")]
64 ExpectedGlobalVariable,
65 #[error("Not a global variable or a function argument")]
66 ExpectedGlobalOrArgument,
67 #[error("Needs to be an binding array instead of {0:?}")]
68 ExpectedBindingArrayType(Handle<crate::Type>),
69 #[error("Needs to be an image instead of {0:?}")]
70 ExpectedImageType(Handle<crate::Type>),
71 #[error("Needs to be an image instead of {0:?}")]
72 ExpectedSamplerType(Handle<crate::Type>),
73 #[error("Unable to operate on image class {0:?}")]
74 InvalidImageClass(crate::ImageClass),
75 #[error("Derivatives can only be taken from scalar and vector floats")]
76 InvalidDerivative,
77 #[error("Image array index parameter is misplaced")]
78 InvalidImageArrayIndex,
79 #[error("Inappropriate sample or level-of-detail index for texel access")]
80 InvalidImageOtherIndex,
81 #[error("Image array index type of {0:?} is not an integer scalar")]
82 InvalidImageArrayIndexType(Handle<crate::Expression>),
83 #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
84 InvalidImageOtherIndexType(Handle<crate::Expression>),
85 #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
86 InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
87 #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
88 ComparisonSamplingMismatch {
89 image: crate::ImageClass,
90 sampler: bool,
91 has_ref: bool,
92 },
93 #[error("Sample offset must be a const-expression")]
94 InvalidSampleOffsetExprType,
95 #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
96 InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
97 #[error("Depth reference {0:?} is not a scalar float")]
98 InvalidDepthReference(Handle<crate::Expression>),
99 #[error("Depth sample level can only be Auto or Zero")]
100 InvalidDepthSampleLevel,
101 #[error("Gather level can only be Zero")]
102 InvalidGatherLevel,
103 #[error("Gather component {0:?} doesn't exist in the image")]
104 InvalidGatherComponent(crate::SwizzleComponent),
105 #[error("Gather can't be done for image dimension {0:?}")]
106 InvalidGatherDimension(crate::ImageDimension),
107 #[error("Sample level (exact) type {0:?} is not a scalar float")]
108 InvalidSampleLevelExactType(Handle<crate::Expression>),
109 #[error("Sample level (bias) type {0:?} is not a scalar float")]
110 InvalidSampleLevelBiasType(Handle<crate::Expression>),
111 #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
112 InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
113 #[error("Unable to cast")]
114 InvalidCastArgument,
115 #[error("Invalid argument count for {0:?}")]
116 WrongArgumentCount(crate::MathFunction),
117 #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
118 InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
119 #[error("Atomic result type can't be {0:?}")]
120 InvalidAtomicResultType(Handle<crate::Type>),
121 #[error(
122 "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
123 )]
124 InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
125 #[error("Shader requires capability {0:?}")]
126 MissingCapabilities(super::Capabilities),
127 #[error(transparent)]
128 Literal(#[from] LiteralError),
129 #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
130 UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
131}
132
133#[derive(Clone, Debug, thiserror::Error)]
134#[cfg_attr(test, derive(PartialEq))]
135pub enum ConstExpressionError {
136 #[error("The expression is not a constant or override expression")]
137 NonConstOrOverride,
138 #[error("The expression is not a fully evaluated constant expression")]
139 NonFullyEvaluatedConst,
140 #[error(transparent)]
141 Compose(#[from] super::ComposeError),
142 #[error("Splatting {0:?} can't be done")]
143 InvalidSplatType(Handle<crate::Expression>),
144 #[error("Type resolution failed")]
145 Type(#[from] ResolveError),
146 #[error(transparent)]
147 Literal(#[from] LiteralError),
148 #[error(transparent)]
149 Width(#[from] super::r#type::WidthError),
150}
151
152#[derive(Clone, Debug, thiserror::Error)]
153#[cfg_attr(test, derive(PartialEq))]
154pub enum LiteralError {
155 #[error("Float literal is NaN")]
156 NaN,
157 #[error("Float literal is infinite")]
158 Infinity,
159 #[error(transparent)]
160 Width(#[from] super::r#type::WidthError),
161}
162
163struct ExpressionTypeResolver<'a> {
164 root: Handle<crate::Expression>,
165 types: &'a UniqueArena<crate::Type>,
166 info: &'a FunctionInfo,
167}
168
169impl<'a> std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> {
170 type Output = crate::TypeInner;
171
172 #[allow(clippy::panic)]
173 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
174 if handle < self.root {
175 self.info[handle].ty.inner_with(self.types)
176 } else {
177 panic!(
179 "Depends on {:?}, which has not been processed yet",
180 self.root
181 )
182 }
183 }
184}
185
186impl super::Validator {
187 pub(super) fn validate_const_expression(
188 &self,
189 handle: Handle<crate::Expression>,
190 gctx: crate::proc::GlobalCtx,
191 mod_info: &ModuleInfo,
192 global_expr_kind: &crate::proc::ExpressionKindTracker,
193 ) -> Result<(), ConstExpressionError> {
194 use crate::Expression as E;
195
196 if !global_expr_kind.is_const_or_override(handle) {
197 return Err(ConstExpressionError::NonConstOrOverride);
198 }
199
200 match gctx.global_expressions[handle] {
201 E::Literal(literal) => {
202 self.validate_literal(literal)?;
203 }
204 E::Constant(_) | E::ZeroValue(_) => {}
205 E::Compose { ref components, ty } => {
206 validate_compose(
207 ty,
208 gctx,
209 components.iter().map(|&handle| mod_info[handle].clone()),
210 )?;
211 }
212 E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
213 crate::TypeInner::Scalar { .. } => {}
214 _ => return Err(ConstExpressionError::InvalidSplatType(value)),
215 },
216 _ if global_expr_kind.is_const(handle) || !self.allow_overrides => {
217 return Err(ConstExpressionError::NonFullyEvaluatedConst)
218 }
219 _ => {}
221 }
222
223 Ok(())
224 }
225
226 #[allow(clippy::too_many_arguments)]
227 pub(super) fn validate_expression(
228 &self,
229 root: Handle<crate::Expression>,
230 expression: &crate::Expression,
231 function: &crate::Function,
232 module: &crate::Module,
233 info: &FunctionInfo,
234 mod_info: &ModuleInfo,
235 global_expr_kind: &crate::proc::ExpressionKindTracker,
236 ) -> Result<ShaderStages, ExpressionError> {
237 use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
238
239 let resolver = ExpressionTypeResolver {
240 root,
241 types: &module.types,
242 info,
243 };
244
245 let stages = match *expression {
246 E::Access { base, index } => {
247 let base_type = &resolver[base];
248 let dynamic_indexing_restricted = match *base_type {
250 Ti::Vector { .. } => false,
251 Ti::Matrix { .. } | Ti::Array { .. } => true,
252 Ti::Pointer { .. }
253 | Ti::ValuePointer { size: Some(_), .. }
254 | Ti::BindingArray { .. } => false,
255 ref other => {
256 log::error!("Indexing of {:?}", other);
257 return Err(ExpressionError::InvalidBaseType(base));
258 }
259 };
260 match resolver[index] {
261 Ti::Scalar(Sc {
263 kind: Sk::Sint | Sk::Uint,
264 ..
265 }) => {}
266 ref other => {
267 log::error!("Indexing by {:?}", other);
268 return Err(ExpressionError::InvalidIndexType(index));
269 }
270 }
271 if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index() {
272 return Err(ExpressionError::IndexMustBeConstant(base));
273 }
274
275 if let crate::proc::IndexableLength::Known(known_length) =
278 base_type.indexable_length(module)?
279 {
280 match module
281 .to_ctx()
282 .eval_expr_to_u32_from(index, &function.expressions)
283 {
284 Ok(value) => {
285 if value >= known_length {
286 return Err(ExpressionError::IndexOutOfBounds(base, value));
287 }
288 }
289 Err(crate::proc::U32EvalError::Negative) => {
290 return Err(ExpressionError::NegativeIndex(base))
291 }
292 Err(crate::proc::U32EvalError::NonConst) => {}
293 }
294 }
295
296 ShaderStages::all()
297 }
298 E::AccessIndex { base, index } => {
299 fn resolve_index_limit(
300 module: &crate::Module,
301 top: Handle<crate::Expression>,
302 ty: &crate::TypeInner,
303 top_level: bool,
304 ) -> Result<u32, ExpressionError> {
305 let limit = match *ty {
306 Ti::Vector { size, .. }
307 | Ti::ValuePointer {
308 size: Some(size), ..
309 } => size as u32,
310 Ti::Matrix { columns, .. } => columns as u32,
311 Ti::Array {
312 size: crate::ArraySize::Constant(len),
313 ..
314 } => len.get(),
315 Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, Ti::Pointer { base, .. } if top_level => {
317 resolve_index_limit(module, top, &module.types[base].inner, false)?
318 }
319 Ti::Struct { ref members, .. } => members.len() as u32,
320 ref other => {
321 log::error!("Indexing of {:?}", other);
322 return Err(ExpressionError::InvalidBaseType(top));
323 }
324 };
325 Ok(limit)
326 }
327
328 let limit = resolve_index_limit(module, base, &resolver[base], true)?;
329 if index >= limit {
330 return Err(ExpressionError::IndexOutOfBounds(base, limit));
331 }
332 ShaderStages::all()
333 }
334 E::Splat { size: _, value } => match resolver[value] {
335 Ti::Scalar { .. } => ShaderStages::all(),
336 ref other => {
337 log::error!("Splat scalar type {:?}", other);
338 return Err(ExpressionError::InvalidSplatType(value));
339 }
340 },
341 E::Swizzle {
342 size,
343 vector,
344 pattern,
345 } => {
346 let vec_size = match resolver[vector] {
347 Ti::Vector { size: vec_size, .. } => vec_size,
348 ref other => {
349 log::error!("Swizzle vector type {:?}", other);
350 return Err(ExpressionError::InvalidVectorType(vector));
351 }
352 };
353 for &sc in pattern[..size as usize].iter() {
354 if sc as u8 >= vec_size as u8 {
355 return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
356 }
357 }
358 ShaderStages::all()
359 }
360 E::Literal(literal) => {
361 self.validate_literal(literal)?;
362 ShaderStages::all()
363 }
364 E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
365 E::Compose { ref components, ty } => {
366 validate_compose(
367 ty,
368 module.to_ctx(),
369 components.iter().map(|&handle| info[handle].ty.clone()),
370 )?;
371 ShaderStages::all()
372 }
373 E::FunctionArgument(index) => {
374 if index >= function.arguments.len() as u32 {
375 return Err(ExpressionError::FunctionArgumentDoesntExist(index));
376 }
377 ShaderStages::all()
378 }
379 E::GlobalVariable(_handle) => ShaderStages::all(),
380 E::LocalVariable(_handle) => ShaderStages::all(),
381 E::Load { pointer } => {
382 match resolver[pointer] {
383 Ti::Pointer { base, .. }
384 if self.types[base.index()]
385 .flags
386 .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
387 Ti::ValuePointer { .. } => {}
388 ref other => {
389 log::error!("Loading {:?}", other);
390 return Err(ExpressionError::InvalidPointerType(pointer));
391 }
392 }
393 ShaderStages::all()
394 }
395 E::ImageSample {
396 image,
397 sampler,
398 gather,
399 coordinate,
400 array_index,
401 offset,
402 level,
403 depth_ref,
404 } => {
405 let image_ty = Self::global_var_ty(module, function, image)?;
407 let sampler_ty = Self::global_var_ty(module, function, sampler)?;
408
409 let comparison = match module.types[sampler_ty].inner {
410 Ti::Sampler { comparison } => comparison,
411 _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
412 };
413
414 let (class, dim) = match module.types[image_ty].inner {
415 Ti::Image {
416 class,
417 arrayed,
418 dim,
419 } => {
420 if arrayed != array_index.is_some() {
422 return Err(ExpressionError::InvalidImageArrayIndex);
423 }
424 if let Some(expr) = array_index {
425 match resolver[expr] {
426 Ti::Scalar(Sc {
427 kind: Sk::Sint | Sk::Uint,
428 ..
429 }) => {}
430 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
431 }
432 }
433 (class, dim)
434 }
435 _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
436 };
437
438 let image_depth = match class {
440 crate::ImageClass::Sampled {
441 kind: crate::ScalarKind::Float,
442 multi: false,
443 } => false,
444 crate::ImageClass::Sampled {
445 kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
446 multi: false,
447 } if gather.is_some() => false,
448 crate::ImageClass::Depth { multi: false } => true,
449 _ => return Err(ExpressionError::InvalidImageClass(class)),
450 };
451 if comparison != depth_ref.is_some() || (comparison && !image_depth) {
452 return Err(ExpressionError::ComparisonSamplingMismatch {
453 image: class,
454 sampler: comparison,
455 has_ref: depth_ref.is_some(),
456 });
457 }
458
459 let num_components = match dim {
461 crate::ImageDimension::D1 => 1,
462 crate::ImageDimension::D2 => 2,
463 crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
464 };
465 match resolver[coordinate] {
466 Ti::Scalar(Sc {
467 kind: Sk::Float, ..
468 }) if num_components == 1 => {}
469 Ti::Vector {
470 size,
471 scalar:
472 Sc {
473 kind: Sk::Float, ..
474 },
475 } if size as u32 == num_components => {}
476 _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
477 }
478
479 if let Some(const_expr) = offset {
481 if !global_expr_kind.is_const(const_expr) {
482 return Err(ExpressionError::InvalidSampleOffsetExprType);
483 }
484
485 match *mod_info[const_expr].inner_with(&module.types) {
486 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
487 Ti::Vector {
488 size,
489 scalar: Sc { kind: Sk::Sint, .. },
490 } if size as u32 == num_components => {}
491 _ => {
492 return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
493 }
494 }
495 }
496
497 if let Some(expr) = depth_ref {
499 match resolver[expr] {
500 Ti::Scalar(Sc {
501 kind: Sk::Float, ..
502 }) => {}
503 _ => return Err(ExpressionError::InvalidDepthReference(expr)),
504 }
505 match level {
506 crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
507 _ => return Err(ExpressionError::InvalidDepthSampleLevel),
508 }
509 }
510
511 if let Some(component) = gather {
512 match dim {
513 crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
514 crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
515 return Err(ExpressionError::InvalidGatherDimension(dim))
516 }
517 };
518 let max_component = match class {
519 crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
520 _ => crate::SwizzleComponent::W,
521 };
522 if component > max_component {
523 return Err(ExpressionError::InvalidGatherComponent(component));
524 }
525 match level {
526 crate::SampleLevel::Zero => {}
527 _ => return Err(ExpressionError::InvalidGatherLevel),
528 }
529 }
530
531 match level {
533 crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
534 crate::SampleLevel::Zero => ShaderStages::all(),
535 crate::SampleLevel::Exact(expr) => {
536 match resolver[expr] {
537 Ti::Scalar(Sc {
538 kind: Sk::Float, ..
539 }) => {}
540 _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
541 }
542 ShaderStages::all()
543 }
544 crate::SampleLevel::Bias(expr) => {
545 match resolver[expr] {
546 Ti::Scalar(Sc {
547 kind: Sk::Float, ..
548 }) => {}
549 _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
550 }
551 ShaderStages::FRAGMENT
552 }
553 crate::SampleLevel::Gradient { x, y } => {
554 match resolver[x] {
555 Ti::Scalar(Sc {
556 kind: Sk::Float, ..
557 }) if num_components == 1 => {}
558 Ti::Vector {
559 size,
560 scalar:
561 Sc {
562 kind: Sk::Float, ..
563 },
564 } if size as u32 == num_components => {}
565 _ => {
566 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
567 }
568 }
569 match resolver[y] {
570 Ti::Scalar(Sc {
571 kind: Sk::Float, ..
572 }) if num_components == 1 => {}
573 Ti::Vector {
574 size,
575 scalar:
576 Sc {
577 kind: Sk::Float, ..
578 },
579 } if size as u32 == num_components => {}
580 _ => {
581 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
582 }
583 }
584 ShaderStages::all()
585 }
586 }
587 }
588 E::ImageLoad {
589 image,
590 coordinate,
591 array_index,
592 sample,
593 level,
594 } => {
595 let ty = Self::global_var_ty(module, function, image)?;
596 match module.types[ty].inner {
597 Ti::Image {
598 class,
599 arrayed,
600 dim,
601 } => {
602 match resolver[coordinate].image_storage_coordinates() {
603 Some(coord_dim) if coord_dim == dim => {}
604 _ => {
605 return Err(ExpressionError::InvalidImageCoordinateType(
606 dim, coordinate,
607 ))
608 }
609 };
610 if arrayed != array_index.is_some() {
611 return Err(ExpressionError::InvalidImageArrayIndex);
612 }
613 if let Some(expr) = array_index {
614 match resolver[expr] {
615 Ti::Scalar(Sc {
616 kind: Sk::Sint | Sk::Uint,
617 width: _,
618 }) => {}
619 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
620 }
621 }
622
623 match (sample, class.is_multisampled()) {
624 (None, false) => {}
625 (Some(sample), true) => {
626 if resolver[sample].scalar_kind() != Some(Sk::Sint) {
627 return Err(ExpressionError::InvalidImageOtherIndexType(
628 sample,
629 ));
630 }
631 }
632 _ => {
633 return Err(ExpressionError::InvalidImageOtherIndex);
634 }
635 }
636
637 match (level, class.is_mipmapped()) {
638 (None, false) => {}
639 (Some(level), true) => {
640 if resolver[level].scalar_kind() != Some(Sk::Sint) {
641 return Err(ExpressionError::InvalidImageOtherIndexType(level));
642 }
643 }
644 _ => {
645 return Err(ExpressionError::InvalidImageOtherIndex);
646 }
647 }
648 }
649 _ => return Err(ExpressionError::ExpectedImageType(ty)),
650 }
651 ShaderStages::all()
652 }
653 E::ImageQuery { image, query } => {
654 let ty = Self::global_var_ty(module, function, image)?;
655 match module.types[ty].inner {
656 Ti::Image { class, arrayed, .. } => {
657 let good = match query {
658 crate::ImageQuery::NumLayers => arrayed,
659 crate::ImageQuery::Size { level: None } => true,
660 crate::ImageQuery::Size { level: Some(_) }
661 | crate::ImageQuery::NumLevels => class.is_mipmapped(),
662 crate::ImageQuery::NumSamples => class.is_multisampled(),
663 };
664 if !good {
665 return Err(ExpressionError::InvalidImageClass(class));
666 }
667 }
668 _ => return Err(ExpressionError::ExpectedImageType(ty)),
669 }
670 ShaderStages::all()
671 }
672 E::Unary { op, expr } => {
673 use crate::UnaryOperator as Uo;
674 let inner = &resolver[expr];
675 match (op, inner.scalar_kind()) {
676 (Uo::Negate, Some(Sk::Float | Sk::Sint))
677 | (Uo::LogicalNot, Some(Sk::Bool))
678 | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
679 other => {
680 log::error!("Op {:?} kind {:?}", op, other);
681 return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
682 }
683 }
684 ShaderStages::all()
685 }
686 E::Binary { op, left, right } => {
687 use crate::BinaryOperator as Bo;
688 let left_inner = &resolver[left];
689 let right_inner = &resolver[right];
690 let good = match op {
691 Bo::Add | Bo::Subtract => match *left_inner {
692 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
693 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
694 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
695 },
696 Ti::Matrix { .. } => left_inner == right_inner,
697 _ => false,
698 },
699 Bo::Divide | Bo::Modulo => match *left_inner {
700 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
701 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
702 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
703 },
704 _ => false,
705 },
706 Bo::Multiply => {
707 let kind_allowed = match left_inner.scalar_kind() {
708 Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
709 Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
710 };
711 let types_match = match (left_inner, right_inner) {
712 (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
714 | (
715 &Ti::Vector {
716 scalar: scalar1, ..
717 },
718 &Ti::Scalar(scalar2),
719 )
720 | (
721 &Ti::Scalar(scalar1),
722 &Ti::Vector {
723 scalar: scalar2, ..
724 },
725 ) => scalar1 == scalar2,
726 (
728 &Ti::Scalar(Sc {
729 kind: Sk::Float, ..
730 }),
731 &Ti::Matrix { .. },
732 )
733 | (
734 &Ti::Matrix { .. },
735 &Ti::Scalar(Sc {
736 kind: Sk::Float, ..
737 }),
738 ) => true,
739 (
741 &Ti::Vector {
742 size: size1,
743 scalar: scalar1,
744 },
745 &Ti::Vector {
746 size: size2,
747 scalar: scalar2,
748 },
749 ) => scalar1 == scalar2 && size1 == size2,
750 (
752 &Ti::Matrix { columns, .. },
753 &Ti::Vector {
754 size,
755 scalar:
756 Sc {
757 kind: Sk::Float, ..
758 },
759 },
760 ) => columns == size,
761 (
763 &Ti::Vector {
764 size,
765 scalar:
766 Sc {
767 kind: Sk::Float, ..
768 },
769 },
770 &Ti::Matrix { rows, .. },
771 ) => size == rows,
772 (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
773 columns == rows
774 }
775 _ => false,
776 };
777 let left_width = left_inner.scalar_width().unwrap_or(0);
778 let right_width = right_inner.scalar_width().unwrap_or(0);
779 kind_allowed && types_match && left_width == right_width
780 }
781 Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
782 Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
783 match *left_inner {
784 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
785 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
786 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
787 },
788 ref other => {
789 log::error!("Op {:?} left type {:?}", op, other);
790 false
791 }
792 }
793 }
794 Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
795 Ti::Scalar(Sc { kind: Sk::Bool, .. })
796 | Ti::Vector {
797 scalar: Sc { kind: Sk::Bool, .. },
798 ..
799 } => left_inner == right_inner,
800 ref other => {
801 log::error!("Op {:?} left type {:?}", op, other);
802 false
803 }
804 },
805 Bo::And | Bo::InclusiveOr => match *left_inner {
806 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
807 Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
808 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
809 },
810 ref other => {
811 log::error!("Op {:?} left type {:?}", op, other);
812 false
813 }
814 },
815 Bo::ExclusiveOr => match *left_inner {
816 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
817 Sk::Sint | Sk::Uint => left_inner == right_inner,
818 Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
819 },
820 ref other => {
821 log::error!("Op {:?} left type {:?}", op, other);
822 false
823 }
824 },
825 Bo::ShiftLeft | Bo::ShiftRight => {
826 let (base_size, base_scalar) = match *left_inner {
827 Ti::Scalar(scalar) => (Ok(None), scalar),
828 Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
829 ref other => {
830 log::error!("Op {:?} base type {:?}", op, other);
831 (Err(()), Sc::BOOL)
832 }
833 };
834 let shift_size = match *right_inner {
835 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
836 Ti::Vector {
837 size,
838 scalar: Sc { kind: Sk::Uint, .. },
839 } => Ok(Some(size)),
840 ref other => {
841 log::error!("Op {:?} shift type {:?}", op, other);
842 Err(())
843 }
844 };
845 match base_scalar.kind {
846 Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
847 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
848 }
849 }
850 };
851 if !good {
852 log::error!(
853 "Left: {:?} of type {:?}",
854 function.expressions[left],
855 left_inner
856 );
857 log::error!(
858 "Right: {:?} of type {:?}",
859 function.expressions[right],
860 right_inner
861 );
862 return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
863 }
864 ShaderStages::all()
865 }
866 E::Select {
867 condition,
868 accept,
869 reject,
870 } => {
871 let accept_inner = &resolver[accept];
872 let reject_inner = &resolver[reject];
873 let condition_good = match resolver[condition] {
874 Ti::Scalar(Sc {
875 kind: Sk::Bool,
876 width: _,
877 }) => {
878 match *accept_inner {
881 Ti::Scalar { .. } | Ti::Vector { .. } => true,
882 _ => false,
883 }
884 }
885 Ti::Vector {
886 size,
887 scalar:
888 Sc {
889 kind: Sk::Bool,
890 width: _,
891 },
892 } => match *accept_inner {
893 Ti::Vector {
894 size: other_size, ..
895 } => size == other_size,
896 _ => false,
897 },
898 _ => false,
899 };
900 if !condition_good || accept_inner != reject_inner {
901 return Err(ExpressionError::InvalidSelectTypes);
902 }
903 ShaderStages::all()
904 }
905 E::Derivative { expr, .. } => {
906 match resolver[expr] {
907 Ti::Scalar(Sc {
908 kind: Sk::Float, ..
909 })
910 | Ti::Vector {
911 scalar:
912 Sc {
913 kind: Sk::Float, ..
914 },
915 ..
916 } => {}
917 _ => return Err(ExpressionError::InvalidDerivative),
918 }
919 ShaderStages::FRAGMENT
920 }
921 E::Relational { fun, argument } => {
922 use crate::RelationalFunction as Rf;
923 let argument_inner = &resolver[argument];
924 match fun {
925 Rf::All | Rf::Any => match *argument_inner {
926 Ti::Vector {
927 scalar: Sc { kind: Sk::Bool, .. },
928 ..
929 } => {}
930 ref other => {
931 log::error!("All/Any of type {:?}", other);
932 return Err(ExpressionError::InvalidBooleanVector(argument));
933 }
934 },
935 Rf::IsNan | Rf::IsInf => match *argument_inner {
936 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
937 if scalar.kind == Sk::Float => {}
938 ref other => {
939 log::error!("Float test of type {:?}", other);
940 return Err(ExpressionError::InvalidFloatArgument(argument));
941 }
942 },
943 }
944 ShaderStages::all()
945 }
946 E::Math {
947 fun,
948 arg,
949 arg1,
950 arg2,
951 arg3,
952 } => {
953 use crate::MathFunction as Mf;
954
955 let resolve = |arg| &resolver[arg];
956 let arg_ty = resolve(arg);
957 let arg1_ty = arg1.map(resolve);
958 let arg2_ty = arg2.map(resolve);
959 let arg3_ty = arg3.map(resolve);
960 match fun {
961 Mf::Abs => {
962 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
963 return Err(ExpressionError::WrongArgumentCount(fun));
964 }
965 let good = match *arg_ty {
966 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
967 scalar.kind != Sk::Bool
968 }
969 _ => false,
970 };
971 if !good {
972 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
973 }
974 }
975 Mf::Min | Mf::Max => {
976 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
977 (Some(ty1), None, None) => ty1,
978 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
979 };
980 let good = match *arg_ty {
981 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
982 scalar.kind != Sk::Bool
983 }
984 _ => false,
985 };
986 if !good {
987 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
988 }
989 if arg1_ty != arg_ty {
990 return Err(ExpressionError::InvalidArgumentType(
991 fun,
992 1,
993 arg1.unwrap(),
994 ));
995 }
996 }
997 Mf::Clamp => {
998 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
999 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1000 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1001 };
1002 let good = match *arg_ty {
1003 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1004 scalar.kind != Sk::Bool
1005 }
1006 _ => false,
1007 };
1008 if !good {
1009 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1010 }
1011 if arg1_ty != arg_ty {
1012 return Err(ExpressionError::InvalidArgumentType(
1013 fun,
1014 1,
1015 arg1.unwrap(),
1016 ));
1017 }
1018 if arg2_ty != arg_ty {
1019 return Err(ExpressionError::InvalidArgumentType(
1020 fun,
1021 2,
1022 arg2.unwrap(),
1023 ));
1024 }
1025 }
1026 Mf::Saturate
1027 | Mf::Cos
1028 | Mf::Cosh
1029 | Mf::Sin
1030 | Mf::Sinh
1031 | Mf::Tan
1032 | Mf::Tanh
1033 | Mf::Acos
1034 | Mf::Asin
1035 | Mf::Atan
1036 | Mf::Asinh
1037 | Mf::Acosh
1038 | Mf::Atanh
1039 | Mf::Radians
1040 | Mf::Degrees
1041 | Mf::Ceil
1042 | Mf::Floor
1043 | Mf::Round
1044 | Mf::Fract
1045 | Mf::Trunc
1046 | Mf::Exp
1047 | Mf::Exp2
1048 | Mf::Log
1049 | Mf::Log2
1050 | Mf::Length
1051 | Mf::Sqrt
1052 | Mf::InverseSqrt => {
1053 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1054 return Err(ExpressionError::WrongArgumentCount(fun));
1055 }
1056 match *arg_ty {
1057 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1058 if scalar.kind == Sk::Float => {}
1059 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1060 }
1061 }
1062 Mf::Sign => {
1063 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1064 return Err(ExpressionError::WrongArgumentCount(fun));
1065 }
1066 match *arg_ty {
1067 Ti::Scalar(Sc {
1068 kind: Sk::Float | Sk::Sint,
1069 ..
1070 })
1071 | Ti::Vector {
1072 scalar:
1073 Sc {
1074 kind: Sk::Float | Sk::Sint,
1075 ..
1076 },
1077 ..
1078 } => {}
1079 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1080 }
1081 }
1082 Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
1083 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1084 (Some(ty1), None, None) => ty1,
1085 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1086 };
1087 match *arg_ty {
1088 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1089 if scalar.kind == Sk::Float => {}
1090 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1091 }
1092 if arg1_ty != arg_ty {
1093 return Err(ExpressionError::InvalidArgumentType(
1094 fun,
1095 1,
1096 arg1.unwrap(),
1097 ));
1098 }
1099 }
1100 Mf::Modf | Mf::Frexp => {
1101 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1102 return Err(ExpressionError::WrongArgumentCount(fun));
1103 }
1104 if !matches!(*arg_ty,
1105 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1106 if scalar.kind == Sk::Float)
1107 {
1108 return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
1109 }
1110 }
1111 Mf::Ldexp => {
1112 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1113 (Some(ty1), None, None) => ty1,
1114 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1115 };
1116 let size0 = match *arg_ty {
1117 Ti::Scalar(Sc {
1118 kind: Sk::Float, ..
1119 }) => None,
1120 Ti::Vector {
1121 scalar:
1122 Sc {
1123 kind: Sk::Float, ..
1124 },
1125 size,
1126 } => Some(size),
1127 _ => {
1128 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1129 }
1130 };
1131 let good = match *arg1_ty {
1132 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
1133 Ti::Vector {
1134 size,
1135 scalar: Sc { kind: Sk::Sint, .. },
1136 } if Some(size) == size0 => true,
1137 _ => false,
1138 };
1139 if !good {
1140 return Err(ExpressionError::InvalidArgumentType(
1141 fun,
1142 1,
1143 arg1.unwrap(),
1144 ));
1145 }
1146 }
1147 Mf::Dot => {
1148 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1149 (Some(ty1), None, None) => ty1,
1150 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1151 };
1152 match *arg_ty {
1153 Ti::Vector {
1154 scalar:
1155 Sc {
1156 kind: Sk::Float | Sk::Sint | Sk::Uint,
1157 ..
1158 },
1159 ..
1160 } => {}
1161 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1162 }
1163 if arg1_ty != arg_ty {
1164 return Err(ExpressionError::InvalidArgumentType(
1165 fun,
1166 1,
1167 arg1.unwrap(),
1168 ));
1169 }
1170 }
1171 Mf::Outer | Mf::Cross | Mf::Reflect => {
1172 let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
1173 (Some(ty1), None, None) => ty1,
1174 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1175 };
1176 match *arg_ty {
1177 Ti::Vector {
1178 scalar:
1179 Sc {
1180 kind: Sk::Float, ..
1181 },
1182 ..
1183 } => {}
1184 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1185 }
1186 if arg1_ty != arg_ty {
1187 return Err(ExpressionError::InvalidArgumentType(
1188 fun,
1189 1,
1190 arg1.unwrap(),
1191 ));
1192 }
1193 }
1194 Mf::Refract => {
1195 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1196 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1197 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1198 };
1199
1200 match *arg_ty {
1201 Ti::Vector {
1202 scalar:
1203 Sc {
1204 kind: Sk::Float, ..
1205 },
1206 ..
1207 } => {}
1208 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1209 }
1210
1211 if arg1_ty != arg_ty {
1212 return Err(ExpressionError::InvalidArgumentType(
1213 fun,
1214 1,
1215 arg1.unwrap(),
1216 ));
1217 }
1218
1219 match (arg_ty, arg2_ty) {
1220 (
1221 &Ti::Vector {
1222 scalar:
1223 Sc {
1224 width: vector_width,
1225 ..
1226 },
1227 ..
1228 },
1229 &Ti::Scalar(Sc {
1230 width: scalar_width,
1231 kind: Sk::Float,
1232 }),
1233 ) if vector_width == scalar_width => {}
1234 _ => {
1235 return Err(ExpressionError::InvalidArgumentType(
1236 fun,
1237 2,
1238 arg2.unwrap(),
1239 ))
1240 }
1241 }
1242 }
1243 Mf::Normalize => {
1244 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1245 return Err(ExpressionError::WrongArgumentCount(fun));
1246 }
1247 match *arg_ty {
1248 Ti::Vector {
1249 scalar:
1250 Sc {
1251 kind: Sk::Float, ..
1252 },
1253 ..
1254 } => {}
1255 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1256 }
1257 }
1258 Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
1259 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1260 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1261 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1262 };
1263 match *arg_ty {
1264 Ti::Scalar(Sc {
1265 kind: Sk::Float, ..
1266 })
1267 | Ti::Vector {
1268 scalar:
1269 Sc {
1270 kind: Sk::Float, ..
1271 },
1272 ..
1273 } => {}
1274 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1275 }
1276 if arg1_ty != arg_ty {
1277 return Err(ExpressionError::InvalidArgumentType(
1278 fun,
1279 1,
1280 arg1.unwrap(),
1281 ));
1282 }
1283 if arg2_ty != arg_ty {
1284 return Err(ExpressionError::InvalidArgumentType(
1285 fun,
1286 2,
1287 arg2.unwrap(),
1288 ));
1289 }
1290 }
1291 Mf::Mix => {
1292 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1293 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1294 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1295 };
1296 let arg_width = match *arg_ty {
1297 Ti::Scalar(Sc {
1298 kind: Sk::Float,
1299 width,
1300 })
1301 | Ti::Vector {
1302 scalar:
1303 Sc {
1304 kind: Sk::Float,
1305 width,
1306 },
1307 ..
1308 } => width,
1309 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1310 };
1311 if arg1_ty != arg_ty {
1312 return Err(ExpressionError::InvalidArgumentType(
1313 fun,
1314 1,
1315 arg1.unwrap(),
1316 ));
1317 }
1318 match *arg2_ty {
1320 Ti::Scalar(Sc {
1321 kind: Sk::Float,
1322 width,
1323 }) if width == arg_width => {}
1324 _ if arg2_ty == arg_ty => {}
1325 _ => {
1326 return Err(ExpressionError::InvalidArgumentType(
1327 fun,
1328 2,
1329 arg2.unwrap(),
1330 ));
1331 }
1332 }
1333 }
1334 Mf::Inverse | Mf::Determinant => {
1335 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1336 return Err(ExpressionError::WrongArgumentCount(fun));
1337 }
1338 let good = match *arg_ty {
1339 Ti::Matrix { columns, rows, .. } => columns == rows,
1340 _ => false,
1341 };
1342 if !good {
1343 return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
1344 }
1345 }
1346 Mf::Transpose => {
1347 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1348 return Err(ExpressionError::WrongArgumentCount(fun));
1349 }
1350 match *arg_ty {
1351 Ti::Matrix { .. } => {}
1352 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1353 }
1354 }
1355 Mf::CountLeadingZeros
1357 | Mf::CountTrailingZeros
1358 | Mf::CountOneBits
1359 | Mf::ReverseBits
1360 | Mf::FindMsb
1361 | Mf::FindLsb => {
1362 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1363 return Err(ExpressionError::WrongArgumentCount(fun));
1364 }
1365 match *arg_ty {
1366 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
1367 Sk::Sint | Sk::Uint => {
1368 if scalar.width != 4 {
1369 return Err(ExpressionError::UnsupportedWidth(
1370 fun,
1371 scalar.kind,
1372 scalar.width,
1373 ));
1374 }
1375 }
1376 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1377 },
1378 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1379 }
1380 }
1381 Mf::InsertBits => {
1382 let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1383 (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
1384 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1385 };
1386 match *arg_ty {
1387 Ti::Scalar(Sc {
1388 kind: Sk::Sint | Sk::Uint,
1389 ..
1390 })
1391 | Ti::Vector {
1392 scalar:
1393 Sc {
1394 kind: Sk::Sint | Sk::Uint,
1395 ..
1396 },
1397 ..
1398 } => {}
1399 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1400 }
1401 if arg1_ty != arg_ty {
1402 return Err(ExpressionError::InvalidArgumentType(
1403 fun,
1404 1,
1405 arg1.unwrap(),
1406 ));
1407 }
1408 match *arg2_ty {
1409 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1410 _ => {
1411 return Err(ExpressionError::InvalidArgumentType(
1412 fun,
1413 2,
1414 arg2.unwrap(),
1415 ))
1416 }
1417 }
1418 match *arg3_ty {
1419 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1420 _ => {
1421 return Err(ExpressionError::InvalidArgumentType(
1422 fun,
1423 2,
1424 arg3.unwrap(),
1425 ))
1426 }
1427 }
1428 for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() {
1430 match *arg {
1431 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1432 if scalar.width != 4 {
1433 return Err(ExpressionError::UnsupportedWidth(
1434 fun,
1435 scalar.kind,
1436 scalar.width,
1437 ));
1438 }
1439 }
1440 _ => {}
1441 }
1442 }
1443 }
1444 Mf::ExtractBits => {
1445 let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
1446 (Some(ty1), Some(ty2), None) => (ty1, ty2),
1447 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1448 };
1449 match *arg_ty {
1450 Ti::Scalar(Sc {
1451 kind: Sk::Sint | Sk::Uint,
1452 ..
1453 })
1454 | Ti::Vector {
1455 scalar:
1456 Sc {
1457 kind: Sk::Sint | Sk::Uint,
1458 ..
1459 },
1460 ..
1461 } => {}
1462 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1463 }
1464 match *arg1_ty {
1465 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1466 _ => {
1467 return Err(ExpressionError::InvalidArgumentType(
1468 fun,
1469 2,
1470 arg1.unwrap(),
1471 ))
1472 }
1473 }
1474 match *arg2_ty {
1475 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1476 _ => {
1477 return Err(ExpressionError::InvalidArgumentType(
1478 fun,
1479 2,
1480 arg2.unwrap(),
1481 ))
1482 }
1483 }
1484 for &arg in [arg_ty, arg1_ty, arg2_ty].iter() {
1486 match *arg {
1487 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
1488 if scalar.width != 4 {
1489 return Err(ExpressionError::UnsupportedWidth(
1490 fun,
1491 scalar.kind,
1492 scalar.width,
1493 ));
1494 }
1495 }
1496 _ => {}
1497 }
1498 }
1499 }
1500 Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
1501 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1502 return Err(ExpressionError::WrongArgumentCount(fun));
1503 }
1504 match *arg_ty {
1505 Ti::Vector {
1506 size: crate::VectorSize::Bi,
1507 scalar:
1508 Sc {
1509 kind: Sk::Float, ..
1510 },
1511 } => {}
1512 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1513 }
1514 }
1515 Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
1516 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1517 return Err(ExpressionError::WrongArgumentCount(fun));
1518 }
1519 match *arg_ty {
1520 Ti::Vector {
1521 size: crate::VectorSize::Quad,
1522 scalar:
1523 Sc {
1524 kind: Sk::Float, ..
1525 },
1526 } => {}
1527 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1528 }
1529 }
1530 Mf::Unpack2x16float
1531 | Mf::Unpack2x16snorm
1532 | Mf::Unpack2x16unorm
1533 | Mf::Unpack4x8snorm
1534 | Mf::Unpack4x8unorm => {
1535 if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
1536 return Err(ExpressionError::WrongArgumentCount(fun));
1537 }
1538 match *arg_ty {
1539 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
1540 _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
1541 }
1542 }
1543 }
1544 ShaderStages::all()
1545 }
1546 E::As {
1547 expr,
1548 kind,
1549 convert,
1550 } => {
1551 let mut base_scalar = match resolver[expr] {
1552 crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1553 scalar
1554 }
1555 crate::TypeInner::Matrix { scalar, .. } => scalar,
1556 _ => return Err(ExpressionError::InvalidCastArgument),
1557 };
1558 base_scalar.kind = kind;
1559 if let Some(width) = convert {
1560 base_scalar.width = width;
1561 }
1562 if self.check_width(base_scalar).is_err() {
1563 return Err(ExpressionError::InvalidCastArgument);
1564 }
1565 ShaderStages::all()
1566 }
1567 E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1568 E::AtomicResult { ty, comparison } => {
1569 let scalar_predicate = |ty: &crate::TypeInner| match ty {
1570 &crate::TypeInner::Scalar(
1571 scalar @ Sc {
1572 kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
1573 ..
1574 },
1575 ) => self.check_width(scalar).is_ok(),
1576 _ => false,
1577 };
1578 let good = match &module.types[ty].inner {
1579 ty if !comparison => scalar_predicate(ty),
1580 &crate::TypeInner::Struct { ref members, .. } if comparison => {
1581 validate_atomic_compare_exchange_struct(
1582 &module.types,
1583 members,
1584 scalar_predicate,
1585 )
1586 }
1587 _ => false,
1588 };
1589 if !good {
1590 return Err(ExpressionError::InvalidAtomicResultType(ty));
1591 }
1592 ShaderStages::all()
1593 }
1594 E::WorkGroupUniformLoadResult { ty } => {
1595 if self.types[ty.index()]
1596 .flags
1597 .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1600 {
1601 ShaderStages::COMPUTE
1602 } else {
1603 return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1604 }
1605 }
1606 E::ArrayLength(expr) => match resolver[expr] {
1607 Ti::Pointer { base, .. } => {
1608 let base_ty = &resolver.types[base];
1609 if let Ti::Array {
1610 size: crate::ArraySize::Dynamic,
1611 ..
1612 } = base_ty.inner
1613 {
1614 ShaderStages::all()
1615 } else {
1616 return Err(ExpressionError::InvalidArrayType(expr));
1617 }
1618 }
1619 ref other => {
1620 log::error!("Array length of {:?}", other);
1621 return Err(ExpressionError::InvalidArrayType(expr));
1622 }
1623 },
1624 E::RayQueryProceedResult => ShaderStages::all(),
1625 E::RayQueryGetIntersection {
1626 query,
1627 committed: _,
1628 } => match resolver[query] {
1629 Ti::Pointer {
1630 base,
1631 space: crate::AddressSpace::Function,
1632 } => match resolver.types[base].inner {
1633 Ti::RayQuery => ShaderStages::all(),
1634 ref other => {
1635 log::error!("Intersection result of a pointer to {:?}", other);
1636 return Err(ExpressionError::InvalidRayQueryType(query));
1637 }
1638 },
1639 ref other => {
1640 log::error!("Intersection result of {:?}", other);
1641 return Err(ExpressionError::InvalidRayQueryType(query));
1642 }
1643 },
1644 E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1645 };
1646 Ok(stages)
1647 }
1648
1649 fn global_var_ty(
1650 module: &crate::Module,
1651 function: &crate::Function,
1652 expr: Handle<crate::Expression>,
1653 ) -> Result<Handle<crate::Type>, ExpressionError> {
1654 use crate::Expression as Ex;
1655
1656 match function.expressions[expr] {
1657 Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1658 Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1659 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1660 match function.expressions[base] {
1661 Ex::GlobalVariable(var_handle) => {
1662 let array_ty = module.global_variables[var_handle].ty;
1663
1664 match module.types[array_ty].inner {
1665 crate::TypeInner::BindingArray { base, .. } => Ok(base),
1666 _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1667 }
1668 }
1669 _ => Err(ExpressionError::ExpectedGlobalVariable),
1670 }
1671 }
1672 _ => Err(ExpressionError::ExpectedGlobalVariable),
1673 }
1674 }
1675
1676 pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1677 self.check_width(literal.scalar())?;
1678 check_literal_value(literal)?;
1679
1680 Ok(())
1681 }
1682}
1683
1684pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1685 let is_nan = match literal {
1686 crate::Literal::F64(v) => v.is_nan(),
1687 crate::Literal::F32(v) => v.is_nan(),
1688 _ => false,
1689 };
1690 if is_nan {
1691 return Err(LiteralError::NaN);
1692 }
1693
1694 let is_infinite = match literal {
1695 crate::Literal::F64(v) => v.is_infinite(),
1696 crate::Literal::F32(v) => v.is_infinite(),
1697 _ => false,
1698 };
1699 if is_infinite {
1700 return Err(LiteralError::Infinity);
1701 }
1702
1703 Ok(())
1704}
1705
#[cfg(all(test, feature = "validate"))]
/// Test helper: wraps `expr` in an otherwise empty function inside an
/// otherwise empty module and runs the validator over that module with only
/// `ValidationFlags::EXPRESSIONS` enabled and the given capabilities.
///
/// The function body consists of a single `Emit` statement whose range is
/// `function.expressions.range_from(0)`, taken after `expr` has been appended.
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut function = crate::Function::default();
    function.expressions.append(expr, Span::default());
    // The range is taken after the append above.
    let emitted = function.expressions.range_from(0);
    function
        .body
        .push(crate::Statement::Emit(emitted), Span::default());

    let mut module = crate::Module::default();
    module.functions.append(function, Span::default());

    let mut checker = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);
    checker.validate(&module)
}
1728
#[cfg(all(test, feature = "validate"))]
/// Test helper: places `expr` in the global expression arena of an otherwise
/// empty module and runs the validator over that module with only
/// `ValidationFlags::CONSTANTS` enabled and the given capabilities.
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let module = {
        let mut module = crate::Module::default();
        module.global_expressions.append(expr, Span::default());
        module
    };

    let mut checker = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);
    checker.validate(&module)
}
1744
#[cfg(feature = "validate")]
/// An `F64` literal used as a function expression must be rejected with a
/// `MissingCapability { name: "f64", flag: "FLOAT64" }` width error under the
/// default capabilities, and accepted once `Capabilities::FLOAT64` is added.
#[test]
fn f64_runtime_literals() {
    // Validates the same f64 literal under whichever capabilities are given.
    let validate = |caps: super::Capabilities| {
        validate_with_expression(
            crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
            caps,
        )
    };

    // Default capabilities: validation must fail with the width error.
    let error = validate(super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: super::ExpressionError::Literal(super::LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                ),),
                ..
            },
            ..
        }
    ));

    // With FLOAT64 added, the same literal validates.
    let result = validate(super::Capabilities::default() | super::Capabilities::FLOAT64);
    assert!(result.is_ok());
}
1776
#[cfg(feature = "validate")]
/// An `F64` literal used as a global (constant) expression must be rejected
/// with a `MissingCapability { name: "f64", flag: "FLOAT64" }` width error
/// under the default capabilities, and accepted once `Capabilities::FLOAT64`
/// is added.
#[test]
fn f64_const_literals() {
    // Validates the same f64 literal under whichever capabilities are given.
    let validate = |caps: super::Capabilities| {
        validate_with_const_expression(
            crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
            caps,
        )
    };

    // Default capabilities: validation must fail with the width error.
    let error = validate(super::Capabilities::default())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: super::ConstExpressionError::Literal(super::LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    // With FLOAT64 added, the same literal validates.
    let result = validate(super::Capabilities::default() | super::Capabilities::FLOAT64);
    assert!(result.is_ok());
}
1805
#[cfg(feature = "validate")]
/// An `I64` literal used as a function expression is expected to be rejected
/// with `WidthError::Unsupported64Bit` even when every capability is enabled.
#[test]
fn i64_runtime_literals() {
    let literal = crate::Expression::Literal(crate::Literal::I64(1729));

    // `Capabilities::all()` is passed on purpose: the test expects a failure
    // regardless of which capabilities are enabled.
    let error = validate_with_expression(literal, super::Capabilities::all())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: super::ExpressionError::Literal(super::LiteralError::Width(
                    super::r#type::WidthError::Unsupported64Bit
                ),),
                ..
            },
            ..
        }
    ));
}
1829
#[cfg(feature = "validate")]
/// An `I64` literal used as a global (constant) expression is expected to be
/// rejected with `WidthError::Unsupported64Bit` even when every capability is
/// enabled.
#[test]
fn i64_const_literals() {
    let literal = crate::Expression::Literal(crate::Literal::I64(1729));

    // `Capabilities::all()` is passed on purpose: the test expects a failure
    // regardless of which capabilities are enabled.
    let error = validate_with_const_expression(literal, super::Capabilities::all())
        .unwrap_err()
        .into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: super::ConstExpressionError::Literal(super::LiteralError::Width(
                super::r#type::WidthError::Unsupported64Bit,
            ),),
            ..
        }
    ));
}
1849}