1use std::collections::VecDeque;
2
3use bevy_ecs::{
4 entity::Entity,
5 query::{QueryData, QueryFilter, WorldQuery},
6 system::Query,
7};
8
9use crate::{Children, Parent};
10
11pub trait HierarchyQueryExt<'w, 's, D: QueryData, F: QueryFilter> {
13 fn iter_descendants(&'w self, entity: Entity) -> DescendantIter<'w, 's, D, F>
34 where
35 D::ReadOnly: WorldQuery<Item<'w> = &'w Children>;
36
37 fn iter_ancestors(&'w self, entity: Entity) -> AncestorIter<'w, 's, D, F>
56 where
57 D::ReadOnly: WorldQuery<Item<'w> = &'w Parent>;
58}
59
60impl<'w, 's, D: QueryData, F: QueryFilter> HierarchyQueryExt<'w, 's, D, F> for Query<'w, 's, D, F> {
61 fn iter_descendants(&'w self, entity: Entity) -> DescendantIter<'w, 's, D, F>
62 where
63 D::ReadOnly: WorldQuery<Item<'w> = &'w Children>,
64 {
65 DescendantIter::new(self, entity)
66 }
67
68 fn iter_ancestors(&'w self, entity: Entity) -> AncestorIter<'w, 's, D, F>
69 where
70 D::ReadOnly: WorldQuery<Item<'w> = &'w Parent>,
71 {
72 AncestorIter::new(self, entity)
73 }
74}
75
76pub struct DescendantIter<'w, 's, D: QueryData, F: QueryFilter>
80where
81 D::ReadOnly: WorldQuery<Item<'w> = &'w Children>,
82{
83 children_query: &'w Query<'w, 's, D, F>,
84 vecdeque: VecDeque<Entity>,
85}
86
87impl<'w, 's, D: QueryData, F: QueryFilter> DescendantIter<'w, 's, D, F>
88where
89 D::ReadOnly: WorldQuery<Item<'w> = &'w Children>,
90{
91 pub fn new(children_query: &'w Query<'w, 's, D, F>, entity: Entity) -> Self {
93 DescendantIter {
94 children_query,
95 vecdeque: children_query
96 .get(entity)
97 .into_iter()
98 .flatten()
99 .copied()
100 .collect(),
101 }
102 }
103}
104
105impl<'w, 's, D: QueryData, F: QueryFilter> Iterator for DescendantIter<'w, 's, D, F>
106where
107 D::ReadOnly: WorldQuery<Item<'w> = &'w Children>,
108{
109 type Item = Entity;
110
111 fn next(&mut self) -> Option<Self::Item> {
112 let entity = self.vecdeque.pop_front()?;
113
114 if let Ok(children) = self.children_query.get(entity) {
115 self.vecdeque.extend(children);
116 }
117
118 Some(entity)
119 }
120}
121
122pub struct AncestorIter<'w, 's, D: QueryData, F: QueryFilter>
124where
125 D::ReadOnly: WorldQuery<Item<'w> = &'w Parent>,
126{
127 parent_query: &'w Query<'w, 's, D, F>,
128 next: Option<Entity>,
129}
130
131impl<'w, 's, D: QueryData, F: QueryFilter> AncestorIter<'w, 's, D, F>
132where
133 D::ReadOnly: WorldQuery<Item<'w> = &'w Parent>,
134{
135 pub fn new(parent_query: &'w Query<'w, 's, D, F>, entity: Entity) -> Self {
137 AncestorIter {
138 parent_query,
139 next: Some(entity),
140 }
141 }
142}
143
144impl<'w, 's, D: QueryData, F: QueryFilter> Iterator for AncestorIter<'w, 's, D, F>
145where
146 D::ReadOnly: WorldQuery<Item<'w> = &'w Parent>,
147{
148 type Item = Entity;
149
150 fn next(&mut self) -> Option<Self::Item> {
151 self.next = self.parent_query.get(self.next?).ok().map(|p| p.get());
152 self.next
153 }
154}
155
#[cfg(test)]
mod tests {
    use bevy_ecs::{
        prelude::Component,
        system::{Query, SystemState},
        world::World,
    };

    use crate::{query_extension::HierarchyQueryExt, BuildWorldChildren, Children, Parent};

    #[derive(Component, PartialEq, Debug)]
    struct A(usize);

    #[test]
    fn descendant_iter() {
        let world = &mut World::new();

        let [a, b, c, d] = std::array::from_fn(|i| world.spawn(A(i)).id());

        world.entity_mut(a).push_children(&[b, c]);
        world.entity_mut(c).push_children(&[d]);

        let mut system_state = SystemState::<(Query<&Children>, Query<&A>)>::new(world);
        let (children_query, a_query) = system_state.get(world);

        let result: Vec<_> = a_query
            .iter_many(children_query.iter_descendants(a))
            .collect();

        assert_eq!([&A(1), &A(2), &A(3)], result.as_slice());
    }

    #[test]
    fn descendant_iter_is_breadth_first() {
        let world = &mut World::new();

        let [a, b, c, d] = std::array::from_fn(|i| world.spawn(A(i)).id());

        // `d` hangs off the *first* child, so a depth-first walk would yield b, d, c.
        world.entity_mut(a).push_children(&[b, c]);
        world.entity_mut(b).push_children(&[d]);

        let mut system_state = SystemState::<(Query<&Children>, Query<&A>)>::new(world);
        let (children_query, a_query) = system_state.get(world);

        let result: Vec<_> = a_query
            .iter_many(children_query.iter_descendants(a))
            .collect();

        assert_eq!([&A(1), &A(2), &A(3)], result.as_slice());
    }

    #[test]
    fn descendant_iter_of_leaf_is_empty() {
        let world = &mut World::new();

        let [a, b] = std::array::from_fn(|i| world.spawn(A(i)).id());

        world.entity_mut(a).push_children(&[b]);

        let mut system_state = SystemState::<Query<&Children>>::new(world);
        let children_query = system_state.get(world);

        assert_eq!(children_query.iter_descendants(b).next(), None);
    }

    #[test]
    fn ancestor_iter() {
        let world = &mut World::new();

        let [a, b, c] = std::array::from_fn(|i| world.spawn(A(i)).id());

        world.entity_mut(a).push_children(&[b]);
        world.entity_mut(b).push_children(&[c]);

        let mut system_state = SystemState::<(Query<&Parent>, Query<&A>)>::new(world);
        let (parent_query, a_query) = system_state.get(world);

        let result: Vec<_> = a_query.iter_many(parent_query.iter_ancestors(c)).collect();

        assert_eq!([&A(1), &A(0)], result.as_slice());
    }

    #[test]
    fn ancestor_iter_of_root_is_empty() {
        let world = &mut World::new();

        let [a, b] = std::array::from_fn(|i| world.spawn(A(i)).id());

        world.entity_mut(a).push_children(&[b]);

        let mut system_state = SystemState::<Query<&Parent>>::new(world);
        let parent_query = system_state.get(world);

        assert_eq!(parent_query.iter_ancestors(a).next(), None);
    }
}