naga/back/msl/
mod.rs

/*!
Backend for [MSL][msl] (Metal Shading Language).

## Binding model

Metal's bindings are flat per resource. Since there isn't an obvious mapping
from SPIR-V's descriptor sets, we require a separate mapping provided in the options.
This mapping may have one or more resource endpoints for each descriptor set + index
pair.

## Entry points

Even though MSL and our IR appear to be similar in that the entry points in both can
accept arguments and return values, the restrictions are different.
MSL allows the varyings to be either in separate arguments, or inside a single
`[[stage_in]]` struct. We gather input varyings and form this artificial structure.
We also add all the (non-Private) globals into the arguments.
At the beginning of the entry point, we assign the local constants and re-compose
the arguments as they are declared on the IR side, so that the rest of the logic can
pretend that MSL doesn't have all the restrictions it has.

For the result type, if it's a structure, we re-compose it with a temporary value
holding the result.

[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
*/
28
29use crate::{arena::Handle, proc::index, valid::ModuleInfo};
30use std::fmt::{Error as FmtError, Write};
31
32mod keywords;
33pub mod sampler;
34mod writer;
35
36pub use writer::Writer;
37
/// A Metal resource binding index, emitted as e.g. `[[buffer(N)]]`.
pub type Slot = u8;
/// An index into [`Options::inline_samplers`].
pub type InlineSamplerIndex = u8;

/// Where a sampler binding comes from.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum BindSamplerTarget {
    /// A sampler bound at this slot, emitted as `[[sampler(N)]]`.
    Resource(Slot),
    /// An inline sampler, taken from [`Options::inline_samplers`] at this index.
    Inline(InlineSamplerIndex),
}

/// The Metal slots that a single IR resource binding maps to.
///
/// Every field is optional; missing fields fall back to their defaults when
/// deserializing.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct BindTarget {
    /// Buffer slot, emitted as `[[buffer(N)]]`.
    pub buffer: Option<Slot>,
    /// Texture slot, emitted as `[[texture(N)]]`.
    pub texture: Option<Slot>,
    /// Sampler source: either a slot or an inline sampler.
    pub sampler: Option<BindSamplerTarget>,
    /// For an unsized binding array, overrides the array's size.
    pub binding_array_size: Option<u32>,
    /// Mutability flag for the resource.
    /// NOTE(review): interpreted by the writer; presumably means "writable" — confirm.
    pub mutable: bool,
}
61
62// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
63pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
64
65#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
66#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
67#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
68#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
69pub struct EntryPointResources {
70    pub resources: BindingMap,
71
72    pub push_constant_buffer: Option<Slot>,
73
74    /// The slot of a buffer that contains an array of `u32`,
75    /// one for the size of each bound buffer that contains a runtime array,
76    /// in order of [`crate::GlobalVariable`] declarations.
77    pub sizes_buffer: Option<Slot>,
78}
79
80pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
81
82enum ResolvedBinding {
83    BuiltIn(crate::BuiltIn),
84    Attribute(u32),
85    Color {
86        location: u32,
87        second_blend_source: bool,
88    },
89    User {
90        prefix: &'static str,
91        index: u32,
92        interpolation: Option<ResolvedInterpolation>,
93    },
94    Resource(BindTarget),
95}
96
/// MSL's combined interpolation-and-sampling qualifier for a varying.
#[derive(Clone, Copy)]
enum ResolvedInterpolation {
    /// `center_perspective`
    CenterPerspective,
    /// `center_no_perspective`
    CenterNoPerspective,
    /// `centroid_perspective`
    CentroidPerspective,
    /// `centroid_no_perspective`
    CentroidNoPerspective,
    /// `sample_perspective`
    SamplePerspective,
    /// `sample_no_perspective`
    SampleNoPerspective,
    /// `flat`
    Flat,
}
107
108// Note: some of these should be removed in favor of proper IR validation.
109
110#[derive(Debug, thiserror::Error)]
111pub enum Error {
112    #[error(transparent)]
113    Format(#[from] FmtError),
114    #[error("bind target {0:?} is empty")]
115    UnimplementedBindTarget(BindTarget),
116    #[error("composing of {0:?} is not implemented yet")]
117    UnsupportedCompose(Handle<crate::Type>),
118    #[error("operation {0:?} is not implemented yet")]
119    UnsupportedBinaryOp(crate::BinaryOperator),
120    #[error("standard function '{0}' is not implemented yet")]
121    UnsupportedCall(String),
122    #[error("feature '{0}' is not implemented yet")]
123    FeatureNotImplemented(String),
124    #[error("internal naga error: module should not have validated: {0}")]
125    GenericValidation(String),
126    #[error("BuiltIn {0:?} is not supported")]
127    UnsupportedBuiltIn(crate::BuiltIn),
128    #[error("capability {0:?} is not supported")]
129    CapabilityNotSupported(crate::valid::Capabilities),
130    #[error("attribute '{0}' is not supported for target MSL version")]
131    UnsupportedAttribute(String),
132    #[error("function '{0}' is not supported for target MSL version")]
133    UnsupportedFunction(String),
134    #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
135    UnsupportedWriteableStorageBuffer,
136    #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
137    UnsupportedWriteableStorageTexture(crate::ShaderStage),
138    #[error("can not use read-write storage textures prior to MSL 1.2")]
139    UnsupportedRWStorageTexture,
140    #[error("array of '{0}' is not supported for target MSL version")]
141    UnsupportedArrayOf(String),
142    #[error("array of type '{0:?}' is not supported")]
143    UnsupportedArrayOfType(Handle<crate::Type>),
144    #[error("ray tracing is not supported prior to MSL 2.3")]
145    UnsupportedRayTracing,
146    #[error("overrides should not be present at this stage")]
147    Override,
148}
149
150#[derive(Clone, Debug, PartialEq, thiserror::Error)]
151#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
152#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
153pub enum EntryPointError {
154    #[error("global '{0}' doesn't have a binding")]
155    MissingBinding(String),
156    #[error("mapping of {0:?} is missing")]
157    MissingBindTarget(crate::ResourceBinding),
158    #[error("mapping for push constants is missing")]
159    MissingPushConstants,
160    #[error("mapping for sizes buffer is missing")]
161    MissingSizesBuffer,
162}
163
/// The places in generated MSL where a pipeline input or output can appear.
///
/// A vertex shader's outputs feed the fragment shader's inputs, but the two
/// sides still need separate variants. This is because [`ResolvedBinding`]s are
/// written differently on each side (for example, `invariant` is only kept
/// on vertex outputs).
#[derive(Clone, Copy, Debug)]
enum LocationMode {
    /// A vertex shader input.
    VertexInput,

    /// A vertex shader output.
    VertexOutput,

    /// A fragment shader input.
    FragmentInput,

    /// A fragment shader output.
    FragmentOutput,

    /// A compute shader input or output.
    Uniform,
}
189
190#[derive(Clone, Debug, Hash, PartialEq, Eq)]
191#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
192#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
193pub struct Options {
194    /// (Major, Minor) target version of the Metal Shading Language.
195    pub lang_version: (u8, u8),
196    /// Map of entry-point resources, indexed by entry point function name, to slots.
197    pub per_entry_point_map: EntryPointResourceMap,
198    /// Samplers to be inlined into the code.
199    pub inline_samplers: Vec<sampler::InlineSampler>,
200    /// Make it possible to link different stages via SPIRV-Cross.
201    pub spirv_cross_compatibility: bool,
202    /// Don't panic on missing bindings, instead generate invalid MSL.
203    pub fake_missing_bindings: bool,
204    /// Bounds checking policies.
205    #[cfg_attr(feature = "deserialize", serde(default))]
206    pub bounds_check_policies: index::BoundsCheckPolicies,
207    /// Should workgroup variables be zero initialized (by polyfilling)?
208    pub zero_initialize_workgroup_memory: bool,
209}
210
211impl Default for Options {
212    fn default() -> Self {
213        Options {
214            lang_version: (1, 0),
215            per_entry_point_map: EntryPointResourceMap::default(),
216            inline_samplers: Vec::new(),
217            spirv_cross_compatibility: false,
218            fake_missing_bindings: true,
219            bounds_check_policies: index::BoundsCheckPolicies::default(),
220            zero_initialize_workgroup_memory: true,
221        }
222    }
223}
224
/// Options that are expected to vary from pipeline to pipeline.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
    /// Allow `BuiltIn::PointSize`, and inject it if the shader doesn't write it.
    ///
    /// Metal rejects a point size for non-point primitive topologies and
    /// requires one for point topologies. Set this only for vertex shaders
    /// used with a point primitive topology.
    pub allow_and_force_point_size: bool,
}
238
239impl Options {
240    fn resolve_local_binding(
241        &self,
242        binding: &crate::Binding,
243        mode: LocationMode,
244    ) -> Result<ResolvedBinding, Error> {
245        match *binding {
246            crate::Binding::BuiltIn(mut built_in) => {
247                match built_in {
248                    crate::BuiltIn::Position { ref mut invariant } => {
249                        if *invariant && self.lang_version < (2, 1) {
250                            return Err(Error::UnsupportedAttribute("invariant".to_string()));
251                        }
252
253                        // The 'invariant' attribute may only appear on vertex
254                        // shader outputs, not fragment shader inputs.
255                        if !matches!(mode, LocationMode::VertexOutput) {
256                            *invariant = false;
257                        }
258                    }
259                    crate::BuiltIn::BaseInstance if self.lang_version < (1, 2) => {
260                        return Err(Error::UnsupportedAttribute("base_instance".to_string()));
261                    }
262                    crate::BuiltIn::InstanceIndex if self.lang_version < (1, 2) => {
263                        return Err(Error::UnsupportedAttribute("instance_id".to_string()));
264                    }
265                    // macOS: Since Metal 2.2
266                    // iOS: Since Metal 2.3 (check depends on https://github.com/gfx-rs/naga/issues/2164)
267                    crate::BuiltIn::PrimitiveIndex if self.lang_version < (2, 2) => {
268                        return Err(Error::UnsupportedAttribute("primitive_id".to_string()));
269                    }
270                    _ => {}
271                }
272
273                Ok(ResolvedBinding::BuiltIn(built_in))
274            }
275            crate::Binding::Location {
276                location,
277                interpolation,
278                sampling,
279                second_blend_source,
280            } => match mode {
281                LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
282                LocationMode::FragmentOutput => {
283                    if second_blend_source && self.lang_version < (1, 2) {
284                        return Err(Error::UnsupportedAttribute(
285                            "second_blend_source".to_string(),
286                        ));
287                    }
288                    Ok(ResolvedBinding::Color {
289                        location,
290                        second_blend_source,
291                    })
292                }
293                LocationMode::VertexOutput | LocationMode::FragmentInput => {
294                    Ok(ResolvedBinding::User {
295                        prefix: if self.spirv_cross_compatibility {
296                            "locn"
297                        } else {
298                            "loc"
299                        },
300                        index: location,
301                        interpolation: {
302                            // unwrap: The verifier ensures that vertex shader outputs and fragment
303                            // shader inputs always have fully specified interpolation, and that
304                            // sampling is `None` only for Flat interpolation.
305                            let interpolation = interpolation.unwrap();
306                            let sampling = sampling.unwrap_or(crate::Sampling::Center);
307                            Some(ResolvedInterpolation::from_binding(interpolation, sampling))
308                        },
309                    })
310                }
311                LocationMode::Uniform => Err(Error::GenericValidation(format!(
312                    "Unexpected Binding::Location({}) for the Uniform mode",
313                    location
314                ))),
315            },
316        }
317    }
318
319    fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
320        self.per_entry_point_map.get(&ep.name)
321    }
322
323    fn get_resource_binding_target(
324        &self,
325        ep: &crate::EntryPoint,
326        res_binding: &crate::ResourceBinding,
327    ) -> Option<&BindTarget> {
328        self.get_entry_point_resources(ep)
329            .and_then(|res| res.resources.get(res_binding))
330    }
331
332    fn resolve_resource_binding(
333        &self,
334        ep: &crate::EntryPoint,
335        res_binding: &crate::ResourceBinding,
336    ) -> Result<ResolvedBinding, EntryPointError> {
337        let target = self.get_resource_binding_target(ep, res_binding);
338        match target {
339            Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
340            None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
341                prefix: "fake",
342                index: 0,
343                interpolation: None,
344            }),
345            None => Err(EntryPointError::MissingBindTarget(res_binding.clone())),
346        }
347    }
348
349    fn resolve_push_constants(
350        &self,
351        ep: &crate::EntryPoint,
352    ) -> Result<ResolvedBinding, EntryPointError> {
353        let slot = self
354            .get_entry_point_resources(ep)
355            .and_then(|res| res.push_constant_buffer);
356        match slot {
357            Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
358                buffer: Some(slot),
359                ..Default::default()
360            })),
361            None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
362                prefix: "fake",
363                index: 0,
364                interpolation: None,
365            }),
366            None => Err(EntryPointError::MissingPushConstants),
367        }
368    }
369
370    fn resolve_sizes_buffer(
371        &self,
372        ep: &crate::EntryPoint,
373    ) -> Result<ResolvedBinding, EntryPointError> {
374        let slot = self
375            .get_entry_point_resources(ep)
376            .and_then(|res| res.sizes_buffer);
377        match slot {
378            Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
379                buffer: Some(slot),
380                ..Default::default()
381            })),
382            None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
383                prefix: "fake",
384                index: 0,
385                interpolation: None,
386            }),
387            None => Err(EntryPointError::MissingSizesBuffer),
388        }
389    }
390}
391
392impl ResolvedBinding {
393    fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
394        match *self {
395            Self::Resource(BindTarget {
396                sampler: Some(BindSamplerTarget::Inline(index)),
397                ..
398            }) => Some(&options.inline_samplers[index as usize]),
399            _ => None,
400        }
401    }
402
403    const fn as_bind_target(&self) -> Option<&BindTarget> {
404        match *self {
405            Self::Resource(ref target) => Some(target),
406            _ => None,
407        }
408    }
409
410    fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
411        write!(out, " [[")?;
412        match *self {
413            Self::BuiltIn(built_in) => {
414                use crate::BuiltIn as Bi;
415                let name = match built_in {
416                    Bi::Position { invariant: false } => "position",
417                    Bi::Position { invariant: true } => "position, invariant",
418                    // vertex
419                    Bi::BaseInstance => "base_instance",
420                    Bi::BaseVertex => "base_vertex",
421                    Bi::ClipDistance => "clip_distance",
422                    Bi::InstanceIndex => "instance_id",
423                    Bi::PointSize => "point_size",
424                    Bi::VertexIndex => "vertex_id",
425                    // fragment
426                    Bi::FragDepth => "depth(any)",
427                    Bi::PointCoord => "point_coord",
428                    Bi::FrontFacing => "front_facing",
429                    Bi::PrimitiveIndex => "primitive_id",
430                    Bi::SampleIndex => "sample_id",
431                    Bi::SampleMask => "sample_mask",
432                    // compute
433                    Bi::GlobalInvocationId => "thread_position_in_grid",
434                    Bi::LocalInvocationId => "thread_position_in_threadgroup",
435                    Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
436                    Bi::WorkGroupId => "threadgroup_position_in_grid",
437                    Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
438                    Bi::NumWorkGroups => "threadgroups_per_grid",
439                    // subgroup
440                    Bi::NumSubgroups => "simdgroups_per_threadgroup",
441                    Bi::SubgroupId => "simdgroup_index_in_threadgroup",
442                    Bi::SubgroupSize => "threads_per_simdgroup",
443                    Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
444                    Bi::CullDistance | Bi::ViewIndex => {
445                        return Err(Error::UnsupportedBuiltIn(built_in))
446                    }
447                };
448                write!(out, "{name}")?;
449            }
450            Self::Attribute(index) => write!(out, "attribute({index})")?,
451            Self::Color {
452                location,
453                second_blend_source,
454            } => {
455                if second_blend_source {
456                    write!(out, "color({location}) index(1)")?
457                } else {
458                    write!(out, "color({location})")?
459                }
460            }
461            Self::User {
462                prefix,
463                index,
464                interpolation,
465            } => {
466                write!(out, "user({prefix}{index})")?;
467                if let Some(interpolation) = interpolation {
468                    write!(out, ", ")?;
469                    interpolation.try_fmt(out)?;
470                }
471            }
472            Self::Resource(ref target) => {
473                if let Some(id) = target.buffer {
474                    write!(out, "buffer({id})")?;
475                } else if let Some(id) = target.texture {
476                    write!(out, "texture({id})")?;
477                } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
478                    write!(out, "sampler({id})")?;
479                } else {
480                    return Err(Error::UnimplementedBindTarget(target.clone()));
481                }
482            }
483        }
484        write!(out, "]]")?;
485        Ok(())
486    }
487}
488
489impl ResolvedInterpolation {
490    const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
491        use crate::Interpolation as I;
492        use crate::Sampling as S;
493
494        match (interpolation, sampling) {
495            (I::Perspective, S::Center) => Self::CenterPerspective,
496            (I::Perspective, S::Centroid) => Self::CentroidPerspective,
497            (I::Perspective, S::Sample) => Self::SamplePerspective,
498            (I::Linear, S::Center) => Self::CenterNoPerspective,
499            (I::Linear, S::Centroid) => Self::CentroidNoPerspective,
500            (I::Linear, S::Sample) => Self::SampleNoPerspective,
501            (I::Flat, _) => Self::Flat,
502        }
503    }
504
505    fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
506        let identifier = match self {
507            Self::CenterPerspective => "center_perspective",
508            Self::CenterNoPerspective => "center_no_perspective",
509            Self::CentroidPerspective => "centroid_perspective",
510            Self::CentroidNoPerspective => "centroid_no_perspective",
511            Self::SamplePerspective => "sample_perspective",
512            Self::SampleNoPerspective => "sample_no_perspective",
513            Self::Flat => "flat",
514        };
515        out.write_str(identifier)?;
516        Ok(())
517    }
518}
519
520/// Information about a translated module that is required
521/// for the use of the result.
522pub struct TranslationInfo {
523    /// Mapping of the entry point names. Each item in the array
524    /// corresponds to an entry point index.
525    ///
526    ///Note: Some entry points may fail translation because of missing bindings.
527    pub entry_point_names: Vec<Result<String, EntryPointError>>,
528}
529
530pub fn write_string(
531    module: &crate::Module,
532    info: &ModuleInfo,
533    options: &Options,
534    pipeline_options: &PipelineOptions,
535) -> Result<(String, TranslationInfo), Error> {
536    let mut w = writer::Writer::new(String::new());
537    let info = w.write(module, info, options, pipeline_options)?;
538    Ok((w.finish(), info))
539}
540
/// Pins the size of [`Error`] so that accidental growth is noticed. Any growth
/// would also enlarge every `Result<_, Error>`.
#[test]
fn test_error_size() {
    let actual = std::mem::size_of::<Error>();
    assert_eq!(actual, 32);
}
546
547impl crate::AtomicFunction {
548    fn to_msl(self) -> Result<&'static str, Error> {
549        Ok(match self {
550            Self::Add => "fetch_add",
551            Self::Subtract => "fetch_sub",
552            Self::And => "fetch_and",
553            Self::InclusiveOr => "fetch_or",
554            Self::ExclusiveOr => "fetch_xor",
555            Self::Min => "fetch_min",
556            Self::Max => "fetch_max",
557            Self::Exchange { compare: None } => "exchange",
558            Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
559                "atomic CompareExchange".to_string(),
560            ))?,
561        })
562    }
563}