naga/back/
pipeline_constants.rs

1use super::PipelineConstants;
2use crate::{
3    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
4    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
5    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
6    Span, Statement, TypeInner, WithSpan,
7};
8use std::{borrow::Cow, collections::HashSet, mem};
9use thiserror::Error;
10
11#[derive(Error, Debug, Clone)]
12#[cfg_attr(test, derive(PartialEq))]
13pub enum PipelineConstantError {
14    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
15    MissingValue(String),
16    #[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")]
17    SrcNeedsToBeFinite,
18    #[error("Source f64 value doesn't fit in destination")]
19    DstRangeTooSmall,
20    #[error(transparent)]
21    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
22    #[error(transparent)]
23    ValidationError(#[from] WithSpan<ValidationError>),
24}
25
26/// Replace all overrides in `module` with constants.
27///
28/// If no changes are needed, this just returns `Cow::Borrowed`
29/// references to `module` and `module_info`. Otherwise, it clones
30/// `module`, edits its [`global_expressions`] arena to contain only
31/// fully-evaluated expressions, and returns `Cow::Owned` values
32/// holding the simplified module and its validation results.
33///
34/// In either case, the module returned has an empty `overrides`
35/// arena, and the `global_expressions` arena contains only
36/// fully-evaluated expressions.
37///
38/// [`global_expressions`]: Module::global_expressions
39pub fn process_overrides<'a>(
40    module: &'a Module,
41    module_info: &'a ModuleInfo,
42    pipeline_constants: &PipelineConstants,
43) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
44    if module.overrides.is_empty() {
45        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
46    }
47
48    let mut module = module.clone();
49
50    // A map from override handles to the handles of the constants
51    // we've replaced them with.
52    let mut override_map = Vec::with_capacity(module.overrides.len());
53
54    // A map from `module`'s original global expression handles to
55    // handles in the new, simplified global expression arena.
56    let mut adjusted_global_expressions = Vec::with_capacity(module.global_expressions.len());
57
58    // The set of constants whose initializer handles we've already
59    // updated to refer to the newly built global expression arena.
60    //
61    // All constants in `module` must have their `init` handles
62    // updated to point into the new, simplified global expression
63    // arena. Some of these we can most easily handle as a side effect
64    // during the simplification process, but we must handle the rest
65    // in a final fixup pass, guided by `adjusted_global_expressions`. We
66    // add their handles to this set, so that the final fixup step can
67    // leave them alone.
68    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
69
70    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
71
72    // An iterator through the original overrides table, consumed in
73    // approximate tandem with the global expressions.
74    let mut override_iter = module.overrides.drain();
75
76    // Do two things in tandem:
77    //
78    // - Rebuild the global expression arena from scratch, fully
79    //   evaluating all expressions, and replacing each `Override`
80    //   expression in `module.global_expressions` with a `Constant`
81    //   expression.
82    //
83    // - Build a new `Constant` in `module.constants` to take the
84    //   place of each `Override`.
85    //
86    // Build a map from old global expression handles to their
87    // fully-evaluated counterparts in `adjusted_global_expressions` as we
88    // go.
89    //
90    // Why in tandem? Overrides refer to expressions, and expressions
91    // refer to overrides, so we can't disentangle the two into
92    // separate phases. However, we can take advantage of the fact
93    // that the overrides and expressions must form a DAG, and work
94    // our way from the leaves to the roots, replacing and evaluating
95    // as we go.
96    //
97    // Although the two loops are nested, this is really two
98    // alternating phases: we adjust and evaluate constant expressions
99    // until we hit an `Override` expression, at which point we switch
100    // to building `Constant`s for `Overrides` until we've handled the
101    // one used by the expression. Then we switch back to processing
102    // expressions. Because we know they form a DAG, we know the
103    // `Override` expressions we encounter can only have initializers
104    // referring to global expressions we've already simplified.
105    for (old_h, expr, span) in module.global_expressions.drain() {
106        let mut expr = match expr {
107            Expression::Override(h) => {
108                let c_h = if let Some(new_h) = override_map.get(h.index()) {
109                    *new_h
110                } else {
111                    let mut new_h = None;
112                    for entry in override_iter.by_ref() {
113                        let stop = entry.0 == h;
114                        new_h = Some(process_override(
115                            entry,
116                            pipeline_constants,
117                            &mut module,
118                            &mut override_map,
119                            &adjusted_global_expressions,
120                            &mut adjusted_constant_initializers,
121                            &mut global_expression_kind_tracker,
122                        )?);
123                        if stop {
124                            break;
125                        }
126                    }
127                    new_h.unwrap()
128                };
129                Expression::Constant(c_h)
130            }
131            Expression::Constant(c_h) => {
132                if adjusted_constant_initializers.insert(c_h) {
133                    let init = &mut module.constants[c_h].init;
134                    *init = adjusted_global_expressions[init.index()];
135                }
136                expr
137            }
138            expr => expr,
139        };
140        let mut evaluator = ConstantEvaluator::for_wgsl_module(
141            &mut module,
142            &mut global_expression_kind_tracker,
143            false,
144        );
145        adjust_expr(&adjusted_global_expressions, &mut expr);
146        let h = evaluator.try_eval_and_append(expr, span)?;
147        debug_assert_eq!(old_h.index(), adjusted_global_expressions.len());
148        adjusted_global_expressions.push(h);
149    }
150
151    // Finish processing any overrides we didn't visit in the loop above.
152    for entry in override_iter {
153        process_override(
154            entry,
155            pipeline_constants,
156            &mut module,
157            &mut override_map,
158            &adjusted_global_expressions,
159            &mut adjusted_constant_initializers,
160            &mut global_expression_kind_tracker,
161        )?;
162    }
163
164    // Update the initialization expression handles of all `Constant`s
165    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
166    // passant.
167    for (_, c) in module
168        .constants
169        .iter_mut()
170        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
171    {
172        c.init = adjusted_global_expressions[c.init.index()];
173    }
174
175    for (_, v) in module.global_variables.iter_mut() {
176        if let Some(ref mut init) = v.init {
177            *init = adjusted_global_expressions[init.index()];
178        }
179    }
180
181    let mut functions = mem::take(&mut module.functions);
182    for (_, function) in functions.iter_mut() {
183        process_function(&mut module, &override_map, function)?;
184    }
185    module.functions = functions;
186
187    let mut entry_points = mem::take(&mut module.entry_points);
188    for ep in entry_points.iter_mut() {
189        process_function(&mut module, &override_map, &mut ep.function)?;
190    }
191    module.entry_points = entry_points;
192
193    // Now that we've rewritten all the expressions, we need to
194    // recompute their types and other metadata. For the time being,
195    // do a full re-validation.
196    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
197    let module_info = validator.validate_no_overrides(&module)?;
198
199    Ok((Cow::Owned(module), Cow::Owned(module_info)))
200}
201
202/// Add a [`Constant`] to `module` for the override `old_h`.
203///
204/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
205fn process_override(
206    (old_h, override_, span): (Handle<Override>, Override, Span),
207    pipeline_constants: &PipelineConstants,
208    module: &mut Module,
209    override_map: &mut Vec<Handle<Constant>>,
210    adjusted_global_expressions: &[Handle<Expression>],
211    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
212    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
213) -> Result<Handle<Constant>, PipelineConstantError> {
214    // Determine which key to use for `override_` in `pipeline_constants`.
215    let key = if let Some(id) = override_.id {
216        Cow::Owned(id.to_string())
217    } else if let Some(ref name) = override_.name {
218        Cow::Borrowed(name)
219    } else {
220        unreachable!();
221    };
222
223    // Generate a global expression for `override_`'s value, either
224    // from the provided `pipeline_constants` table or its initializer
225    // in the module.
226    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
227        let literal = match module.types[override_.ty].inner {
228            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
229            _ => unreachable!(),
230        };
231        let expr = module
232            .global_expressions
233            .append(Expression::Literal(literal), Span::UNDEFINED);
234        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
235        expr
236    } else if let Some(init) = override_.init {
237        adjusted_global_expressions[init.index()]
238    } else {
239        return Err(PipelineConstantError::MissingValue(key.to_string()));
240    };
241
242    // Generate a new `Constant` to represent the override's value.
243    let constant = Constant {
244        name: override_.name,
245        ty: override_.ty,
246        init,
247    };
248    let h = module.constants.append(constant, span);
249    debug_assert_eq!(old_h.index(), override_map.len());
250    override_map.push(h);
251    adjusted_constant_initializers.insert(h);
252    Ok(h)
253}
254
255/// Replace all override expressions in `function` with fully-evaluated constants.
256///
257/// Replace all `Expression::Override`s in `function`'s expression arena with
258/// the corresponding `Expression::Constant`s, as given in `override_map`.
259/// Replace any expressions whose values are now known with their fully
260/// evaluated form.
261///
262/// If `h` is a `Handle<Override>`, then `override_map[h.index()]` is the
263/// `Handle<Constant>` for the override's final value.
264fn process_function(
265    module: &mut Module,
266    override_map: &[Handle<Constant>],
267    function: &mut Function,
268) -> Result<(), ConstantEvaluatorError> {
269    // A map from original local expression handles to
270    // handles in the new, local expression arena.
271    let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len());
272
273    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
274
275    let mut expressions = mem::take(&mut function.expressions);
276
277    // Dummy `emitter` and `block` for the constant evaluator.
278    // We can ignore the concept of emitting expressions here since
279    // expressions have already been covered by a `Statement::Emit`
280    // in the frontend.
281    // The only thing we might have to do is remove some expressions
282    // that have been covered by a `Statement::Emit`. See the docs of
283    // `filter_emits_in_block` for the reasoning.
284    let mut emitter = Emitter::default();
285    let mut block = Block::new();
286
287    let mut evaluator = ConstantEvaluator::for_wgsl_function(
288        module,
289        &mut function.expressions,
290        &mut local_expression_kind_tracker,
291        &mut emitter,
292        &mut block,
293    );
294
295    for (old_h, mut expr, span) in expressions.drain() {
296        if let Expression::Override(h) = expr {
297            expr = Expression::Constant(override_map[h.index()]);
298        }
299        adjust_expr(&adjusted_local_expressions, &mut expr);
300        let h = evaluator.try_eval_and_append(expr, span)?;
301        debug_assert_eq!(old_h.index(), adjusted_local_expressions.len());
302        adjusted_local_expressions.push(h);
303    }
304
305    adjust_block(&adjusted_local_expressions, &mut function.body);
306
307    filter_emits_in_block(&mut function.body, &function.expressions);
308
309    // Update local expression initializers.
310    for (_, local) in function.local_variables.iter_mut() {
311        if let &mut Some(ref mut init) = &mut local.init {
312            *init = adjusted_local_expressions[init.index()];
313        }
314    }
315
316    // We've changed the keys of `function.named_expression`, so we have to
317    // rebuild it from scratch.
318    let named_expressions = mem::take(&mut function.named_expressions);
319    for (expr_h, name) in named_expressions {
320        function
321            .named_expressions
322            .insert(adjusted_local_expressions[expr_h.index()], name);
323    }
324
325    Ok(())
326}
327
328/// Replace every expression handle in `expr` with its counterpart
329/// given by `new_pos`.
330fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
331    let adjust = |expr: &mut Handle<Expression>| {
332        *expr = new_pos[expr.index()];
333    };
334    match *expr {
335        Expression::Compose {
336            ref mut components,
337            ty: _,
338        } => {
339            for c in components.iter_mut() {
340                adjust(c);
341            }
342        }
343        Expression::Access {
344            ref mut base,
345            ref mut index,
346        } => {
347            adjust(base);
348            adjust(index);
349        }
350        Expression::AccessIndex {
351            ref mut base,
352            index: _,
353        } => {
354            adjust(base);
355        }
356        Expression::Splat {
357            ref mut value,
358            size: _,
359        } => {
360            adjust(value);
361        }
362        Expression::Swizzle {
363            ref mut vector,
364            size: _,
365            pattern: _,
366        } => {
367            adjust(vector);
368        }
369        Expression::Load { ref mut pointer } => {
370            adjust(pointer);
371        }
372        Expression::ImageSample {
373            ref mut image,
374            ref mut sampler,
375            ref mut coordinate,
376            ref mut array_index,
377            ref mut offset,
378            ref mut level,
379            ref mut depth_ref,
380            gather: _,
381        } => {
382            adjust(image);
383            adjust(sampler);
384            adjust(coordinate);
385            if let Some(e) = array_index.as_mut() {
386                adjust(e);
387            }
388            if let Some(e) = offset.as_mut() {
389                adjust(e);
390            }
391            match *level {
392                crate::SampleLevel::Exact(ref mut expr)
393                | crate::SampleLevel::Bias(ref mut expr) => {
394                    adjust(expr);
395                }
396                crate::SampleLevel::Gradient {
397                    ref mut x,
398                    ref mut y,
399                } => {
400                    adjust(x);
401                    adjust(y);
402                }
403                _ => {}
404            }
405            if let Some(e) = depth_ref.as_mut() {
406                adjust(e);
407            }
408        }
409        Expression::ImageLoad {
410            ref mut image,
411            ref mut coordinate,
412            ref mut array_index,
413            ref mut sample,
414            ref mut level,
415        } => {
416            adjust(image);
417            adjust(coordinate);
418            if let Some(e) = array_index.as_mut() {
419                adjust(e);
420            }
421            if let Some(e) = sample.as_mut() {
422                adjust(e);
423            }
424            if let Some(e) = level.as_mut() {
425                adjust(e);
426            }
427        }
428        Expression::ImageQuery {
429            ref mut image,
430            ref mut query,
431        } => {
432            adjust(image);
433            match *query {
434                crate::ImageQuery::Size { ref mut level } => {
435                    if let Some(e) = level.as_mut() {
436                        adjust(e);
437                    }
438                }
439                crate::ImageQuery::NumLevels
440                | crate::ImageQuery::NumLayers
441                | crate::ImageQuery::NumSamples => {}
442            }
443        }
444        Expression::Unary {
445            ref mut expr,
446            op: _,
447        } => {
448            adjust(expr);
449        }
450        Expression::Binary {
451            ref mut left,
452            ref mut right,
453            op: _,
454        } => {
455            adjust(left);
456            adjust(right);
457        }
458        Expression::Select {
459            ref mut condition,
460            ref mut accept,
461            ref mut reject,
462        } => {
463            adjust(condition);
464            adjust(accept);
465            adjust(reject);
466        }
467        Expression::Derivative {
468            ref mut expr,
469            axis: _,
470            ctrl: _,
471        } => {
472            adjust(expr);
473        }
474        Expression::Relational {
475            ref mut argument,
476            fun: _,
477        } => {
478            adjust(argument);
479        }
480        Expression::Math {
481            ref mut arg,
482            ref mut arg1,
483            ref mut arg2,
484            ref mut arg3,
485            fun: _,
486        } => {
487            adjust(arg);
488            if let Some(e) = arg1.as_mut() {
489                adjust(e);
490            }
491            if let Some(e) = arg2.as_mut() {
492                adjust(e);
493            }
494            if let Some(e) = arg3.as_mut() {
495                adjust(e);
496            }
497        }
498        Expression::As {
499            ref mut expr,
500            kind: _,
501            convert: _,
502        } => {
503            adjust(expr);
504        }
505        Expression::ArrayLength(ref mut expr) => {
506            adjust(expr);
507        }
508        Expression::RayQueryGetIntersection {
509            ref mut query,
510            committed: _,
511        } => {
512            adjust(query);
513        }
514        Expression::Literal(_)
515        | Expression::FunctionArgument(_)
516        | Expression::GlobalVariable(_)
517        | Expression::LocalVariable(_)
518        | Expression::CallResult(_)
519        | Expression::RayQueryProceedResult
520        | Expression::Constant(_)
521        | Expression::Override(_)
522        | Expression::ZeroValue(_)
523        | Expression::AtomicResult {
524            ty: _,
525            comparison: _,
526        }
527        | Expression::WorkGroupUniformLoadResult { ty: _ }
528        | Expression::SubgroupBallotResult
529        | Expression::SubgroupOperationResult { .. } => {}
530    }
531}
532
533/// Replace every expression handle in `block` with its counterpart
534/// given by `new_pos`.
535fn adjust_block(new_pos: &[Handle<Expression>], block: &mut Block) {
536    for stmt in block.iter_mut() {
537        adjust_stmt(new_pos, stmt);
538    }
539}
540
541/// Replace every expression handle in `stmt` with its counterpart
542/// given by `new_pos`.
543fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
544    let adjust = |expr: &mut Handle<Expression>| {
545        *expr = new_pos[expr.index()];
546    };
547    match *stmt {
548        Statement::Emit(ref mut range) => {
549            if let Some((mut first, mut last)) = range.first_and_last() {
550                adjust(&mut first);
551                adjust(&mut last);
552                *range = Range::new_from_bounds(first, last);
553            }
554        }
555        Statement::Block(ref mut block) => {
556            adjust_block(new_pos, block);
557        }
558        Statement::If {
559            ref mut condition,
560            ref mut accept,
561            ref mut reject,
562        } => {
563            adjust(condition);
564            adjust_block(new_pos, accept);
565            adjust_block(new_pos, reject);
566        }
567        Statement::Switch {
568            ref mut selector,
569            ref mut cases,
570        } => {
571            adjust(selector);
572            for case in cases.iter_mut() {
573                adjust_block(new_pos, &mut case.body);
574            }
575        }
576        Statement::Loop {
577            ref mut body,
578            ref mut continuing,
579            ref mut break_if,
580        } => {
581            adjust_block(new_pos, body);
582            adjust_block(new_pos, continuing);
583            if let Some(e) = break_if.as_mut() {
584                adjust(e);
585            }
586        }
587        Statement::Return { ref mut value } => {
588            if let Some(e) = value.as_mut() {
589                adjust(e);
590            }
591        }
592        Statement::Store {
593            ref mut pointer,
594            ref mut value,
595        } => {
596            adjust(pointer);
597            adjust(value);
598        }
599        Statement::ImageStore {
600            ref mut image,
601            ref mut coordinate,
602            ref mut array_index,
603            ref mut value,
604        } => {
605            adjust(image);
606            adjust(coordinate);
607            if let Some(e) = array_index.as_mut() {
608                adjust(e);
609            }
610            adjust(value);
611        }
612        crate::Statement::Atomic {
613            ref mut pointer,
614            ref mut value,
615            ref mut result,
616            ref mut fun,
617        } => {
618            adjust(pointer);
619            adjust(value);
620            adjust(result);
621            match *fun {
622                crate::AtomicFunction::Exchange {
623                    compare: Some(ref mut compare),
624                } => {
625                    adjust(compare);
626                }
627                crate::AtomicFunction::Add
628                | crate::AtomicFunction::Subtract
629                | crate::AtomicFunction::And
630                | crate::AtomicFunction::ExclusiveOr
631                | crate::AtomicFunction::InclusiveOr
632                | crate::AtomicFunction::Min
633                | crate::AtomicFunction::Max
634                | crate::AtomicFunction::Exchange { compare: None } => {}
635            }
636        }
637        Statement::WorkGroupUniformLoad {
638            ref mut pointer,
639            ref mut result,
640        } => {
641            adjust(pointer);
642            adjust(result);
643        }
644        Statement::SubgroupBallot {
645            ref mut result,
646            ref mut predicate,
647        } => {
648            if let Some(ref mut predicate) = *predicate {
649                adjust(predicate);
650            }
651            adjust(result);
652        }
653        Statement::SubgroupCollectiveOperation {
654            ref mut argument,
655            ref mut result,
656            ..
657        } => {
658            adjust(argument);
659            adjust(result);
660        }
661        Statement::SubgroupGather {
662            ref mut mode,
663            ref mut argument,
664            ref mut result,
665        } => {
666            match *mode {
667                crate::GatherMode::BroadcastFirst => {}
668                crate::GatherMode::Broadcast(ref mut index)
669                | crate::GatherMode::Shuffle(ref mut index)
670                | crate::GatherMode::ShuffleDown(ref mut index)
671                | crate::GatherMode::ShuffleUp(ref mut index)
672                | crate::GatherMode::ShuffleXor(ref mut index) => {
673                    adjust(index);
674                }
675            }
676            adjust(argument);
677            adjust(result)
678        }
679        Statement::Call {
680            ref mut arguments,
681            ref mut result,
682            function: _,
683        } => {
684            for argument in arguments.iter_mut() {
685                adjust(argument);
686            }
687            if let Some(e) = result.as_mut() {
688                adjust(e);
689            }
690        }
691        Statement::RayQuery {
692            ref mut query,
693            ref mut fun,
694        } => {
695            adjust(query);
696            match *fun {
697                crate::RayQueryFunction::Initialize {
698                    ref mut acceleration_structure,
699                    ref mut descriptor,
700                } => {
701                    adjust(acceleration_structure);
702                    adjust(descriptor);
703                }
704                crate::RayQueryFunction::Proceed { ref mut result } => {
705                    adjust(result);
706                }
707                crate::RayQueryFunction::Terminate => {}
708            }
709        }
710        Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
711    }
712}
713
714/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
715///
716/// According to validation, [`Emit`] statements must not cover any expressions
717/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
718/// by successful constant evaluation fall into that category, meaning that
719/// `process_function` will usually rewrite [`Override`] expressions and those
720/// that use their values into pre-emitted expressions, leaving any [`Emit`]
721/// statements that cover them invalid.
722///
723/// This function rewrites all [`Emit`] statements into zero or more new
724/// [`Emit`] statements covering only those expressions in the original range
725/// that are not pre-emitted.
726///
727/// [`Emit`]: Statement::Emit
728/// [`needs_pre_emit`]: Expression::needs_pre_emit
729/// [`Override`]: Expression::Override
730fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
731    let original = std::mem::replace(block, Block::with_capacity(block.len()));
732    for (stmt, span) in original.span_into_iter() {
733        match stmt {
734            Statement::Emit(range) => {
735                let mut current = None;
736                for expr_h in range {
737                    if expressions[expr_h].needs_pre_emit() {
738                        if let Some((first, last)) = current {
739                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
740                        }
741
742                        current = None;
743                    } else if let Some((_, ref mut last)) = current {
744                        *last = expr_h;
745                    } else {
746                        current = Some((expr_h, expr_h));
747                    }
748                }
749                if let Some((first, last)) = current {
750                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
751                }
752            }
753            Statement::Block(mut child) => {
754                filter_emits_in_block(&mut child, expressions);
755                block.push(Statement::Block(child), span);
756            }
757            Statement::If {
758                condition,
759                mut accept,
760                mut reject,
761            } => {
762                filter_emits_in_block(&mut accept, expressions);
763                filter_emits_in_block(&mut reject, expressions);
764                block.push(
765                    Statement::If {
766                        condition,
767                        accept,
768                        reject,
769                    },
770                    span,
771                );
772            }
773            Statement::Switch {
774                selector,
775                mut cases,
776            } => {
777                for case in &mut cases {
778                    filter_emits_in_block(&mut case.body, expressions);
779                }
780                block.push(Statement::Switch { selector, cases }, span);
781            }
782            Statement::Loop {
783                mut body,
784                mut continuing,
785                break_if,
786            } => {
787                filter_emits_in_block(&mut body, expressions);
788                filter_emits_in_block(&mut continuing, expressions);
789                block.push(
790                    Statement::Loop {
791                        body,
792                        continuing,
793                        break_if,
794                    },
795                    span,
796                );
797            }
798            stmt => block.push(stmt.clone(), span),
799        }
800    }
801}
802
803fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
804    // note that in rust 0.0 == -0.0
805    match scalar {
806        Scalar::BOOL => {
807            // https://webidl.spec.whatwg.org/#js-boolean
808            let value = value != 0.0 && !value.is_nan();
809            Ok(Literal::Bool(value))
810        }
811        Scalar::I32 => {
812            // https://webidl.spec.whatwg.org/#js-long
813            if !value.is_finite() {
814                return Err(PipelineConstantError::SrcNeedsToBeFinite);
815            }
816
817            let value = value.trunc();
818            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
819                return Err(PipelineConstantError::DstRangeTooSmall);
820            }
821
822            let value = value as i32;
823            Ok(Literal::I32(value))
824        }
825        Scalar::U32 => {
826            // https://webidl.spec.whatwg.org/#js-unsigned-long
827            if !value.is_finite() {
828                return Err(PipelineConstantError::SrcNeedsToBeFinite);
829            }
830
831            let value = value.trunc();
832            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
833                return Err(PipelineConstantError::DstRangeTooSmall);
834            }
835
836            let value = value as u32;
837            Ok(Literal::U32(value))
838        }
839        Scalar::F32 => {
840            // https://webidl.spec.whatwg.org/#js-float
841            if !value.is_finite() {
842                return Err(PipelineConstantError::SrcNeedsToBeFinite);
843            }
844
845            let value = value as f32;
846            if !value.is_finite() {
847                return Err(PipelineConstantError::DstRangeTooSmall);
848            }
849
850            Ok(Literal::F32(value))
851        }
852        Scalar::F64 => {
853            // https://webidl.spec.whatwg.org/#js-double
854            if !value.is_finite() {
855                return Err(PipelineConstantError::SrcNeedsToBeFinite);
856            }
857
858            Ok(Literal::F64(value))
859        }
860        _ => unreachable!(),
861    }
862}
863
#[test]
fn test_map_value_to_literal() {
    // bool: only ±0 and NaN are falsy.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (-1.0, true),
        (f64::MIN_POSITIVE, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination rejects non-finite sources.
    for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractions are truncated toward zero, not rounded.
    assert_eq!(map_value_to_literal(2.9, Scalar::I32), Ok(Literal::I32(2)));
    assert_eq!(map_value_to_literal(-2.9, Scalar::I32), Ok(Literal::I32(-2)));
    assert_eq!(map_value_to_literal(-0.0, Scalar::I32), Ok(Literal::I32(0)));
    // Truncation happens before the range check.
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Small negative fractions truncate to zero and are therefore accepted.
    assert_eq!(map_value_to_literal(-0.9, Scalar::U32), Ok(Literal::U32(0)));
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 0.5, Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // In-range values that are not exactly representable are rounded.
    assert_eq!(map_value_to_literal(0.1, Scalar::F32), Ok(Literal::F32(0.1)));

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}