1use crate::{
2 render_graph::{
3 Edge, Node, NodeRunError, NodeState, RenderGraphContext, RenderGraphError, RenderLabel,
4 SlotInfo, SlotLabel,
5 },
6 renderer::RenderContext,
7};
8use bevy_ecs::{define_label, intern::Interned, prelude::World, system::Resource};
9use bevy_utils::HashMap;
10use std::fmt::Debug;
11
12use super::{EdgeExistence, InternedRenderLabel, IntoRenderNodeArray};
13
14pub use bevy_render_macros::RenderSubGraph;
15
16define_label!(
17 RenderSubGraph,
19 RENDER_SUB_GRAPH_INTERNER
20);
21
22pub type InternedRenderSubGraph = Interned<dyn RenderSubGraph>;
24
25#[derive(Resource, Default)]
71pub struct RenderGraph {
72 nodes: HashMap<InternedRenderLabel, NodeState>,
73 sub_graphs: HashMap<InternedRenderSubGraph, RenderGraph>,
74}
75
76#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
78pub struct GraphInput;
79
80impl RenderGraph {
81 pub fn update(&mut self, world: &mut World) {
83 for node in self.nodes.values_mut() {
84 node.node.update(world);
85 }
86
87 for sub_graph in self.sub_graphs.values_mut() {
88 sub_graph.update(world);
89 }
90 }
91
92 pub fn set_input(&mut self, inputs: Vec<SlotInfo>) {
94 assert!(
95 matches!(
96 self.get_node_state(GraphInput),
97 Err(RenderGraphError::InvalidNode(_))
98 ),
99 "Graph already has an input node"
100 );
101
102 self.add_node(GraphInput, GraphInputNode { inputs });
103 }
104
105 #[inline]
111 pub fn get_input_node(&self) -> Option<&NodeState> {
112 self.get_node_state(GraphInput).ok()
113 }
114
115 #[inline]
125 pub fn input_node(&self) -> &NodeState {
126 self.get_input_node().unwrap()
127 }
128
129 pub fn add_node<T>(&mut self, label: impl RenderLabel, node: T)
132 where
133 T: Node,
134 {
135 let label = label.intern();
136 let node_state = NodeState::new(label, node);
137 self.nodes.insert(label, node_state);
138 }
139
140 pub fn add_node_edges<const N: usize>(&mut self, edges: impl IntoRenderNodeArray<N>) {
145 for window in edges.into_array().windows(2) {
146 let [a, b] = window else {
147 break;
148 };
149 if let Err(err) = self.try_add_node_edge(*a, *b) {
150 match err {
151 RenderGraphError::EdgeAlreadyExists(_) => {}
154 _ => panic!("{err:?}"),
155 }
156 }
157 }
158 }
159
160 pub fn remove_node(&mut self, label: impl RenderLabel) -> Result<(), RenderGraphError> {
163 let label = label.intern();
164 if let Some(node_state) = self.nodes.remove(&label) {
165 for input_edge in node_state.edges.input_edges() {
168 match input_edge {
169 Edge::SlotEdge { output_node, .. }
170 | Edge::NodeEdge {
171 input_node: _,
172 output_node,
173 } => {
174 if let Ok(output_node) = self.get_node_state_mut(*output_node) {
175 output_node.edges.remove_output_edge(input_edge.clone())?;
176 }
177 }
178 }
179 }
180 for output_edge in node_state.edges.output_edges() {
183 match output_edge {
184 Edge::SlotEdge {
185 output_node: _,
186 output_index: _,
187 input_node,
188 input_index: _,
189 }
190 | Edge::NodeEdge {
191 output_node: _,
192 input_node,
193 } => {
194 if let Ok(input_node) = self.get_node_state_mut(*input_node) {
195 input_node.edges.remove_input_edge(output_edge.clone())?;
196 }
197 }
198 }
199 }
200 }
201
202 Ok(())
203 }
204
205 pub fn get_node_state(&self, label: impl RenderLabel) -> Result<&NodeState, RenderGraphError> {
207 let label = label.intern();
208 self.nodes
209 .get(&label)
210 .ok_or(RenderGraphError::InvalidNode(label))
211 }
212
213 pub fn get_node_state_mut(
215 &mut self,
216 label: impl RenderLabel,
217 ) -> Result<&mut NodeState, RenderGraphError> {
218 let label = label.intern();
219 self.nodes
220 .get_mut(&label)
221 .ok_or(RenderGraphError::InvalidNode(label))
222 }
223
224 pub fn get_node<T>(&self, label: impl RenderLabel) -> Result<&T, RenderGraphError>
226 where
227 T: Node,
228 {
229 self.get_node_state(label).and_then(|n| n.node())
230 }
231
232 pub fn get_node_mut<T>(&mut self, label: impl RenderLabel) -> Result<&mut T, RenderGraphError>
234 where
235 T: Node,
236 {
237 self.get_node_state_mut(label).and_then(|n| n.node_mut())
238 }
239
240 pub fn try_add_slot_edge(
249 &mut self,
250 output_node: impl RenderLabel,
251 output_slot: impl Into<SlotLabel>,
252 input_node: impl RenderLabel,
253 input_slot: impl Into<SlotLabel>,
254 ) -> Result<(), RenderGraphError> {
255 let output_slot = output_slot.into();
256 let input_slot = input_slot.into();
257
258 let output_node = output_node.intern();
259 let input_node = input_node.intern();
260
261 let output_index = self
262 .get_node_state(output_node)?
263 .output_slots
264 .get_slot_index(output_slot.clone())
265 .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
266 let input_index = self
267 .get_node_state(input_node)?
268 .input_slots
269 .get_slot_index(input_slot.clone())
270 .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
271
272 let edge = Edge::SlotEdge {
273 output_node,
274 output_index,
275 input_node,
276 input_index,
277 };
278
279 self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
280
281 {
282 let output_node = self.get_node_state_mut(output_node)?;
283 output_node.edges.add_output_edge(edge.clone())?;
284 }
285 let input_node = self.get_node_state_mut(input_node)?;
286 input_node.edges.add_input_edge(edge)?;
287
288 Ok(())
289 }
290
291 pub fn add_slot_edge(
302 &mut self,
303 output_node: impl RenderLabel,
304 output_slot: impl Into<SlotLabel>,
305 input_node: impl RenderLabel,
306 input_slot: impl Into<SlotLabel>,
307 ) {
308 self.try_add_slot_edge(output_node, output_slot, input_node, input_slot)
309 .unwrap();
310 }
311
312 pub fn remove_slot_edge(
315 &mut self,
316 output_node: impl RenderLabel,
317 output_slot: impl Into<SlotLabel>,
318 input_node: impl RenderLabel,
319 input_slot: impl Into<SlotLabel>,
320 ) -> Result<(), RenderGraphError> {
321 let output_slot = output_slot.into();
322 let input_slot = input_slot.into();
323
324 let output_node = output_node.intern();
325 let input_node = input_node.intern();
326
327 let output_index = self
328 .get_node_state(output_node)?
329 .output_slots
330 .get_slot_index(output_slot.clone())
331 .ok_or(RenderGraphError::InvalidOutputNodeSlot(output_slot))?;
332 let input_index = self
333 .get_node_state(input_node)?
334 .input_slots
335 .get_slot_index(input_slot.clone())
336 .ok_or(RenderGraphError::InvalidInputNodeSlot(input_slot))?;
337
338 let edge = Edge::SlotEdge {
339 output_node,
340 output_index,
341 input_node,
342 input_index,
343 };
344
345 self.validate_edge(&edge, EdgeExistence::Exists)?;
346
347 {
348 let output_node = self.get_node_state_mut(output_node)?;
349 output_node.edges.remove_output_edge(edge.clone())?;
350 }
351 let input_node = self.get_node_state_mut(input_node)?;
352 input_node.edges.remove_input_edge(edge)?;
353
354 Ok(())
355 }
356
357 pub fn try_add_node_edge(
366 &mut self,
367 output_node: impl RenderLabel,
368 input_node: impl RenderLabel,
369 ) -> Result<(), RenderGraphError> {
370 let output_node = output_node.intern();
371 let input_node = input_node.intern();
372
373 let edge = Edge::NodeEdge {
374 output_node,
375 input_node,
376 };
377
378 self.validate_edge(&edge, EdgeExistence::DoesNotExist)?;
379
380 {
381 let output_node = self.get_node_state_mut(output_node)?;
382 output_node.edges.add_output_edge(edge.clone())?;
383 }
384 let input_node = self.get_node_state_mut(input_node)?;
385 input_node.edges.add_input_edge(edge)?;
386
387 Ok(())
388 }
389
390 pub fn add_node_edge(&mut self, output_node: impl RenderLabel, input_node: impl RenderLabel) {
401 self.try_add_node_edge(output_node, input_node).unwrap();
402 }
403
404 pub fn remove_node_edge(
407 &mut self,
408 output_node: impl RenderLabel,
409 input_node: impl RenderLabel,
410 ) -> Result<(), RenderGraphError> {
411 let output_node = output_node.intern();
412 let input_node = input_node.intern();
413
414 let edge = Edge::NodeEdge {
415 output_node,
416 input_node,
417 };
418
419 self.validate_edge(&edge, EdgeExistence::Exists)?;
420
421 {
422 let output_node = self.get_node_state_mut(output_node)?;
423 output_node.edges.remove_output_edge(edge.clone())?;
424 }
425 let input_node = self.get_node_state_mut(input_node)?;
426 input_node.edges.remove_input_edge(edge)?;
427
428 Ok(())
429 }
430
431 pub fn validate_edge(
434 &mut self,
435 edge: &Edge,
436 should_exist: EdgeExistence,
437 ) -> Result<(), RenderGraphError> {
438 if should_exist == EdgeExistence::Exists && !self.has_edge(edge) {
439 return Err(RenderGraphError::EdgeDoesNotExist(edge.clone()));
440 } else if should_exist == EdgeExistence::DoesNotExist && self.has_edge(edge) {
441 return Err(RenderGraphError::EdgeAlreadyExists(edge.clone()));
442 }
443
444 match *edge {
445 Edge::SlotEdge {
446 output_node,
447 output_index,
448 input_node,
449 input_index,
450 } => {
451 let output_node_state = self.get_node_state(output_node)?;
452 let input_node_state = self.get_node_state(input_node)?;
453
454 let output_slot = output_node_state
455 .output_slots
456 .get_slot(output_index)
457 .ok_or(RenderGraphError::InvalidOutputNodeSlot(SlotLabel::Index(
458 output_index,
459 )))?;
460 let input_slot = input_node_state.input_slots.get_slot(input_index).ok_or(
461 RenderGraphError::InvalidInputNodeSlot(SlotLabel::Index(input_index)),
462 )?;
463
464 if let Some(Edge::SlotEdge {
465 output_node: current_output_node,
466 ..
467 }) = input_node_state.edges.input_edges().iter().find(|e| {
468 if let Edge::SlotEdge {
469 input_index: current_input_index,
470 ..
471 } = e
472 {
473 input_index == *current_input_index
474 } else {
475 false
476 }
477 }) {
478 if should_exist == EdgeExistence::DoesNotExist {
479 return Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
480 node: input_node,
481 input_slot: input_index,
482 occupied_by_node: *current_output_node,
483 });
484 }
485 }
486
487 if output_slot.slot_type != input_slot.slot_type {
488 return Err(RenderGraphError::MismatchedNodeSlots {
489 output_node,
490 output_slot: output_index,
491 input_node,
492 input_slot: input_index,
493 });
494 }
495 }
496 Edge::NodeEdge { .. } => { }
497 }
498
499 Ok(())
500 }
501
502 pub fn has_edge(&self, edge: &Edge) -> bool {
504 let output_node_state = self.get_node_state(edge.get_output_node());
505 let input_node_state = self.get_node_state(edge.get_input_node());
506 if let Ok(output_node_state) = output_node_state {
507 if output_node_state.edges.output_edges().contains(edge) {
508 if let Ok(input_node_state) = input_node_state {
509 if input_node_state.edges.input_edges().contains(edge) {
510 return true;
511 }
512 }
513 }
514 }
515
516 false
517 }
518
519 pub fn iter_nodes(&self) -> impl Iterator<Item = &NodeState> {
521 self.nodes.values()
522 }
523
524 pub fn iter_nodes_mut(&mut self) -> impl Iterator<Item = &mut NodeState> {
526 self.nodes.values_mut()
527 }
528
529 pub fn iter_sub_graphs(&self) -> impl Iterator<Item = (InternedRenderSubGraph, &RenderGraph)> {
531 self.sub_graphs.iter().map(|(name, graph)| (*name, graph))
532 }
533
534 pub fn iter_sub_graphs_mut(
536 &mut self,
537 ) -> impl Iterator<Item = (InternedRenderSubGraph, &mut RenderGraph)> {
538 self.sub_graphs
539 .iter_mut()
540 .map(|(name, graph)| (*name, graph))
541 }
542
543 pub fn iter_node_inputs(
546 &self,
547 label: impl RenderLabel,
548 ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
549 let node = self.get_node_state(label)?;
550 Ok(node
551 .edges
552 .input_edges()
553 .iter()
554 .map(|edge| (edge, edge.get_output_node()))
555 .map(move |(edge, output_node)| (edge, self.get_node_state(output_node).unwrap())))
556 }
557
558 pub fn iter_node_outputs(
561 &self,
562 label: impl RenderLabel,
563 ) -> Result<impl Iterator<Item = (&Edge, &NodeState)>, RenderGraphError> {
564 let node = self.get_node_state(label)?;
565 Ok(node
566 .edges
567 .output_edges()
568 .iter()
569 .map(|edge| (edge, edge.get_input_node()))
570 .map(move |(edge, input_node)| (edge, self.get_node_state(input_node).unwrap())))
571 }
572
573 pub fn add_sub_graph(&mut self, label: impl RenderSubGraph, sub_graph: RenderGraph) {
576 self.sub_graphs.insert(label.intern(), sub_graph);
577 }
578
579 pub fn remove_sub_graph(&mut self, label: impl RenderSubGraph) {
582 self.sub_graphs.remove(&label.intern());
583 }
584
585 pub fn get_sub_graph(&self, label: impl RenderSubGraph) -> Option<&RenderGraph> {
587 self.sub_graphs.get(&label.intern())
588 }
589
590 pub fn get_sub_graph_mut(&mut self, label: impl RenderSubGraph) -> Option<&mut RenderGraph> {
592 self.sub_graphs.get_mut(&label.intern())
593 }
594
595 pub fn sub_graph(&self, label: impl RenderSubGraph) -> &RenderGraph {
605 let label = label.intern();
606 self.sub_graphs
607 .get(&label)
608 .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
609 }
610
611 pub fn sub_graph_mut(&mut self, label: impl RenderSubGraph) -> &mut RenderGraph {
621 let label = label.intern();
622 self.sub_graphs
623 .get_mut(&label)
624 .unwrap_or_else(|| panic!("Subgraph {label:?} not found"))
625 }
626}
627
628impl Debug for RenderGraph {
629 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
630 for node in self.iter_nodes() {
631 writeln!(f, "{:?}", node.label)?;
632 writeln!(f, " in: {:?}", node.input_slots)?;
633 writeln!(f, " out: {:?}", node.output_slots)?;
634 }
635
636 Ok(())
637 }
638}
639
640pub struct GraphInputNode {
643 inputs: Vec<SlotInfo>,
644}
645
646impl Node for GraphInputNode {
647 fn input(&self) -> Vec<SlotInfo> {
648 self.inputs.clone()
649 }
650
651 fn output(&self) -> Vec<SlotInfo> {
652 self.inputs.clone()
653 }
654
655 fn run(
656 &self,
657 graph: &mut RenderGraphContext,
658 _render_context: &mut RenderContext,
659 _world: &World,
660 ) -> Result<(), NodeRunError> {
661 for i in 0..graph.inputs().len() {
662 let input = graph.inputs()[i].clone();
663 graph.set_output(i, input)?;
664 }
665 Ok(())
666 }
667}
668
#[cfg(test)]
mod tests {
    use crate::{
        render_graph::{
            node::IntoRenderNodeArray, Edge, InternedRenderLabel, Node, NodeRunError, RenderGraph,
            RenderGraphContext, RenderGraphError, RenderLabel, SlotInfo, SlotType,
        },
        renderer::RenderContext,
    };
    use bevy_ecs::world::World;
    use bevy_utils::HashSet;

    #[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
    enum TestLabel {
        A,
        B,
        C,
        D,
    }

    /// A no-op node with `inputs` input slots and `outputs` output slots, all `TextureView`.
    #[derive(Debug)]
    struct TestNode {
        inputs: Vec<SlotInfo>,
        outputs: Vec<SlotInfo>,
    }

    impl TestNode {
        pub fn new(inputs: usize, outputs: usize) -> Self {
            TestNode {
                inputs: (0..inputs)
                    .map(|i| SlotInfo::new(format!("in_{i}"), SlotType::TextureView))
                    .collect(),
                outputs: (0..outputs)
                    .map(|i| SlotInfo::new(format!("out_{i}"), SlotType::TextureView))
                    .collect(),
            }
        }
    }

    impl Node for TestNode {
        fn input(&self) -> Vec<SlotInfo> {
            self.inputs.clone()
        }

        fn output(&self) -> Vec<SlotInfo> {
            self.outputs.clone()
        }

        fn run(
            &self,
            _: &mut RenderGraphContext,
            _: &mut RenderContext,
            _: &World,
        ) -> Result<(), NodeRunError> {
            Ok(())
        }
    }

    /// Labels of the nodes feeding `label` through any edge.
    fn input_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_inputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    /// Labels of the nodes that `label` feeds through any edge.
    fn output_nodes(label: impl RenderLabel, graph: &RenderGraph) -> HashSet<InternedRenderLabel> {
        graph
            .iter_node_outputs(label)
            .unwrap()
            .map(|(_edge, node)| node.label)
            .collect::<HashSet<InternedRenderLabel>>()
    }

    #[test]
    fn test_graph_edges() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));
        graph.add_node(TestLabel::D, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, "out_0", TestLabel::C, "in_0");
        graph.add_node_edge(TestLabel::B, TestLabel::C);
        graph.add_slot_edge(TestLabel::C, 0, TestLabel::D, 0);

        assert!(
            input_nodes(TestLabel::A, &graph).is_empty(),
            "A has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "A outputs to C"
        );

        assert!(
            input_nodes(TestLabel::B, &graph).is_empty(),
            "B has no inputs"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B outputs to C"
        );

        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::A, TestLabel::B).into_array()),
            "A and B input to C"
        );
        assert_eq!(
            output_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::D,).into_array()),
            "C outputs to D"
        );

        assert_eq!(
            input_nodes(TestLabel::D, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "C inputs to D"
        );
        assert!(
            output_nodes(TestLabel::D, &graph).is_empty(),
            "D has no outputs"
        );
    }

    #[test]
    fn test_get_node_typed() {
        struct MyNode {
            value: usize,
        }

        impl Node for MyNode {
            fn run(
                &self,
                _: &mut RenderGraphContext,
                _: &mut RenderContext,
                _: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }

        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, MyNode { value: 42 });

        let node: &MyNode = graph.get_node(TestLabel::A).unwrap();
        assert_eq!(node.value, 42, "node value matches");

        let result: Result<&TestNode, RenderGraphError> = graph.get_node(TestLabel::A);
        assert_eq!(
            result.unwrap_err(),
            RenderGraphError::WrongNodeType,
            "expect a wrong node type error"
        );
    }

    #[test]
    fn test_slot_already_occupied() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(0, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 1));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::C, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::B, 0, TestLabel::C, 0),
            Err(RenderGraphError::NodeInputSlotAlreadyOccupied {
                node: TestLabel::C.intern(),
                input_slot: 0,
                occupied_by_node: TestLabel::A.intern(),
            }),
            "Adding to a slot that is already occupied should return an error"
        );
    }

    #[test]
    fn test_edge_already_exists() {
        let mut graph = RenderGraph::default();

        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        assert_eq!(
            graph.try_add_slot_edge(TestLabel::A, 0, TestLabel::B, 0),
            Err(RenderGraphError::EdgeAlreadyExists(Edge::SlotEdge {
                output_node: TestLabel::A.intern(),
                output_index: 0,
                input_node: TestLabel::B.intern(),
                input_index: 0,
            })),
            "Adding to a duplicate edge should return an error"
        );
    }

    #[test]
    fn test_add_node_edges() {
        struct SimpleNode;
        impl Node for SimpleNode {
            fn run(
                &self,
                _graph: &mut RenderGraphContext,
                _render_context: &mut RenderContext,
                _world: &World,
            ) -> Result<(), NodeRunError> {
                Ok(())
            }
        }

        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, SimpleNode);
        graph.add_node(TestLabel::B, SimpleNode);
        graph.add_node(TestLabel::C, SimpleNode);

        graph.add_node_edges((TestLabel::A, TestLabel::B, TestLabel::C));

        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "A -> B"
        );
        assert_eq!(
            input_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::A,).into_array()),
            "A -> B"
        );
        assert_eq!(
            output_nodes(TestLabel::B, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "B -> C"
        );
        assert_eq!(
            input_nodes(TestLabel::C, &graph),
            HashSet::from_iter((TestLabel::B,).into_array()),
            "B -> C"
        );
    }

    #[test]
    fn test_remove_node() {
        let mut graph = RenderGraph::default();
        graph.add_node(TestLabel::A, TestNode::new(0, 1));
        graph.add_node(TestLabel::B, TestNode::new(1, 1));
        graph.add_node(TestLabel::C, TestNode::new(1, 0));

        graph.add_slot_edge(TestLabel::A, 0, TestLabel::B, 0);
        graph.add_node_edge(TestLabel::B, TestLabel::C);

        assert_eq!(graph.remove_node(TestLabel::B), Ok(()), "B is removed");
        assert!(
            matches!(
                graph.get_node_state(TestLabel::B),
                Err(RenderGraphError::InvalidNode(_))
            ),
            "B is no longer in the graph"
        );
        assert!(
            output_nodes(TestLabel::A, &graph).is_empty(),
            "A no longer outputs to B"
        );
        assert!(
            input_nodes(TestLabel::C, &graph).is_empty(),
            "C no longer receives from B"
        );

        // The slot edge into C's input is free, so A can now be connected to it.
        graph.add_slot_edge(TestLabel::A, 0, TestLabel::C, 0);
        assert_eq!(
            output_nodes(TestLabel::A, &graph),
            HashSet::from_iter((TestLabel::C,).into_array()),
            "A -> C"
        );

        assert_eq!(
            graph.remove_node(TestLabel::D),
            Ok(()),
            "Removing an unknown node is a no-op"
        );
    }
}