juhaku / utoipa

Simple, Fast, Code first and Compile time generated OpenAPI documentation for Rust
Apache License 2.0
2.23k stars 173 forks source link

Filter by tag #879

Open Sytten opened 7 months ago

Sytten commented 7 months ago

Sometimes we want to expose only part of the API (say, the public part). It would be nice if we could filter the generated OpenAPI object by tag or by some similar criterion.

Sytten commented 7 months ago

I put together a hackish way to do it. There are likely cases it doesn't support, but it works for me :) Component schemas are retained only if at least one retained path references them (and nested references are followed too!)

/// Returns `open_api` restricted to the operations tagged with `tag`.
///
/// Within every path, operations whose `tags` do not contain `tag` are removed, and
/// paths left without any operation are dropped entirely. Component schemas are then
/// retained only if they are reachable, directly or through nested references, from a
/// retained operation. Other component kinds are left untouched.
pub fn filter_tag(mut open_api: OpenApi, tag: &str) -> OpenApi {
    let mut references = HashSet::new();

    open_api.paths.paths.retain(|_, path| {
        // Filter per operation, not per path: otherwise an untagged operation that
        // shares a path with a tagged one (e.g. an internal DELETE next to a public
        // GET) would leak into the filtered document.
        path.operations.retain(|_, operation| {
            operation
                .tags
                .iter()
                .flatten()
                .any(|operation_tag| operation_tag == tag)
        });

        // Only surviving operations contribute schema references.
        for operation in path.operations.values() {
            collect_operation_references(&mut references, operation);
        }

        !path.operations.is_empty()
    });

    // Keep only component schemas reachable from the retained operations.
    if let Some(components) = &mut open_api.components {
        collect_components_references(&mut references, components);
        components
            .schemas
            .retain(|name, _| references.contains(name));
    }

    open_api
}

/// Expands `references` with every component schema transitively referenced by the
/// names it already contains.
///
/// Works in rounds. Each round scans only the component schemas discovered in the
/// previous round, which are initially the names passed in. Names not found in
/// `components.schemas` are kept in `references` but contribute nothing further. The
/// loop ends once a round discovers no new names, so reference cycles terminate.
fn collect_components_references(references: &mut HashSet<String>, components: &Components) {
    // Names whose schemas have not been scanned yet.
    let mut pending: HashSet<String> = references.clone();

    while !pending.is_empty() {
        let mut discovered = HashSet::new();

        for name in &pending {
            match components.schemas.get(name) {
                // A component that is itself a `$ref` (an alias) must also feed the next
                // round, so that the alias target's own references are followed.
                Some(RefOr::Ref(reference)) => collect_ref(&mut discovered, reference),
                Some(RefOr::T(schema)) => collect_schema_references(&mut discovered, schema),
                None => {}
            }
        }

        // Only names never seen before need scanning in the next round.
        pending = discovered.difference(references).cloned().collect();
        references.extend(discovered);
    }
}

fn collect_operation_references(references: &mut HashSet<String>, operation: &Operation) {
    // Collect parameters
    if let Some(parameters) = &operation.parameters {
        parameters.iter().for_each(|parameter| {
            if let Some(RefOr::Ref(reference)) = &parameter.schema {
                collect_ref(references, reference);
            }
        })
    }

    // Collect requests
    if let Some(request) = &operation.request_body {
        request.content.iter().for_each(|(_, content)| {
            if let RefOr::Ref(reference) = &content.schema {
                collect_ref(references, reference);
            }
        })
    }

    // Collect responses
    operation
        .responses
        .responses
        .iter()
        .for_each(|(_, schema)| match &schema {
            RefOr::Ref(reference) => {
                collect_ref(references, reference);
            }
            RefOr::T(response) => {
                response.content.iter().for_each(|(_, content)| {
                    if let RefOr::Ref(reference) = &content.schema {
                        collect_ref(references, reference);
                    }
                });
            }
        });
}

fn collect_schema_references(references: &mut HashSet<String>, schema: &Schema) {
    match schema {
        Schema::Array(a) => match &*a.items {
            RefOr::Ref(reference) => {
                collect_ref(references, reference);
            }
            RefOr::T(schema) => {
                collect_schema_references(references, schema);
            }
        },
        Schema::Object(o) => {
            o.properties.values().for_each(|schema| match schema {
                RefOr::Ref(reference) => {
                    collect_ref(references, reference);
                }
                RefOr::T(schema) => {
                    collect_schema_references(references, schema);
                }
            });
        }
        Schema::OneOf(o) => {
            o.items.iter().for_each(|schema| match schema {
                RefOr::Ref(reference) => {
                    collect_ref(references, reference);
                }
                RefOr::T(schema) => {
                    collect_schema_references(references, schema);
                }
            });
        }
        Schema::AllOf(a) => {
            a.items.iter().for_each(|schema| match schema {
                RefOr::Ref(reference) => {
                    collect_ref(references, reference);
                }
                RefOr::T(schema) => {
                    collect_schema_references(references, schema);
                }
            });
        }
        Schema::AnyOf(o) => {
            o.items.iter().for_each(|schema| match schema {
                RefOr::Ref(reference) => {
                    collect_ref(references, reference);
                }
                RefOr::T(schema) => {
                    collect_schema_references(references, schema);
                }
            });
        }
        _ => {}
    }
}

/// Inserts the final `/`-separated segment of `reference.ref_location` into
/// `references` (e.g. `Pet` for `#/components/schemas/Pet`).
///
/// # Panics
///
/// Panics with "Not a valid ref location" when the location yields no segment at all,
/// e.g. an empty string.
fn collect_ref(references: &mut HashSet<String>, reference: &Ref) {
    let location = reference.ref_location.as_str();

    // `rsplit_terminator` skips one trailing empty piece, so `.../Pet/` still gives `Pet`.
    let name = match location.rsplit_terminator('/').next() {
        Some(segment) => segment,
        None => panic!("Not a valid ref location"),
    };

    references.insert(name.to_owned());
}