oasis-open / cti-python-stix2

OASIS TC Open Repository: Python APIs for STIX 2
https://stix2.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

How to speed up common queries? #516

Closed: rhaist closed this issue 2 years ago

rhaist commented 3 years ago

I am trying to run a very common query on the data corpus (get all intrusion sets and their corresponding tools), and it takes about 4 minutes 20 seconds on average using the following Python script. Is there any way I am missing to speed up the look-ups? I tried using different data sources without any noticeable impact on the runtime. Happy for any hints.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse

from stix2 import Filter, FileSystemSource
from stix2.utils import get_type_from_id

def get_maltool_by_id(stix_src, maltool_id) -> list:
    # An "id" filter returns every object carrying that STIX ID.
    return stix_src.query([Filter("id", "=", maltool_id)])

def process_mitre_intrusion_sets(path: str):
    src = FileSystemSource(path)
    groups = src.query([Filter("type", "=", "intrusion-set")])

    for group in groups:
        name = group["name"]
        gid = group["id"]
        print(f"Processing I-S {name} {gid}")

        # get the malware and tools that this intrusion-set uses
        group_uses = [
            r
            for r in src.relationships(gid, "uses", source_only=True)
            if get_type_from_id(r.target_ref) in ["malware", "tool"]
        ]

        for item in group_uses:
            maltools = get_maltool_by_id(src, item.target_ref)
            for entry in maltools:
                tn = entry["name"]
                print(f"Tool for {name}: {tn}")

def main():
    parser = argparse.ArgumentParser(description="List Intrusion Sets and Tools")
    parser.add_argument(
        "-m",
        "--mitre",
        help="Path to the MITRE enterprise-attack data",
        default="data/mitre-cti/enterprise-attack/",
    )
    args = parser.parse_args()
    process_mitre_intrusion_sets(args.mitre)

if __name__ == "__main__":
    main()
chisholm commented 3 years ago

I think the slowness here is from the src.relationships() call. It looks for SROs based on source_ref, but there is no quick way to find those. Your code will process all of them for every intrusion-set, and it looks like there are just under 11,000 relationships in enterprise-attack (I just did a fresh git pull). That's a lot of relationships to process over and over again! The filesystem datasource directory layout includes files and directories named after STIX types and IDs, so it can take some shortcuts when filtering based on those things, but not much else.
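
To make the cost concrete, each src.relationships(gid, "uses", source_only=True) call behaves roughly like the query below (a sketch of the effective filtering, using src and gid from your script, not the library's exact internals). The type filter can use the directory layout as a shortcut, but matching source_ref still means scanning every relationship object, once per intrusion set.

from stix2 import Filter

# Roughly what each relationships() call amounts to: the "type" filter
# narrows the search to the relationship directory, but the "source_ref"
# filter is checked against all ~11,000 relationship objects every time.
group_uses = src.query([
    Filter("type", "=", "relationship"),
    Filter("relationship_type", "=", "uses"),
    Filter("source_ref", "=", gid),
])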

So, what you might try doing is creating your own "index" of sorts. If you want to look up an SRO based on source_ref, process the relationships once, creating a mapping from STIX ID to relationships which have that ID as source_ref. Then the group_uses lookup by source_ref will be fast.

Below is what I tried. Hopefully I've got equivalent logic to yours. There is a pause as it does the indexing, but after that it prints results pretty quickly.

import argparse

from stix2 import Filter, FileSystemSource
from stix2.utils import get_type_from_id

def find_intrusion_set_sros(src):
    """
    Find SROs whose source_ref is any intrusion set.  Create a mapping from
    the intrusion set ID to a list of relationships.

    :param src: A stix2 data source
    :return: A dict mapping from STIX ID to a list of relationships whose
        source_ref is that ID.  Only intrusion-set source_refs are included.
    """

    print("Indexing intrusion-set SROs...")
    sros = src.query([Filter("type", "=", "relationship")])

    intrusion_set_sros = {}
    for sro in sros:
        if sro.source_ref.startswith("intrusion-set--"):
            intrusion_set_sros.setdefault(sro.source_ref, []).append(sro)

    return intrusion_set_sros

def get_maltool_by_id(stix_src, maltool_id) -> list:
    # An "id" filter returns every object carrying that STIX ID.
    return stix_src.query([Filter("id", "=", maltool_id)])

def process_mitre_intrusion_sets(path: str):
    src = FileSystemSource(path)

    intrusion_set_sros = find_intrusion_set_sros(src)
    groups = src.query([Filter("type", "=", "intrusion-set")])

    for group in groups:
        name = group["name"]
        gid = group["id"]
        print(f"Processing I-S {name} {gid}")

        if gid in intrusion_set_sros:
            group_uses = [
                r for r in intrusion_set_sros[gid]
                if r.relationship_type == "uses"
                   and get_type_from_id(r.target_ref) in ["malware", "tool"]
            ]

            for item in group_uses:
                maltools = get_maltool_by_id(src, item.target_ref)
                for entry in maltools:
                    tn = entry["name"]
                    print(f"Tool for {name}: {tn}")

def main():
    parser = argparse.ArgumentParser(description="List Intrusion Sets and Tools")
    parser.add_argument(
        "-m",
        "--mitre",
        help="Path to the MITRE enterprise-attack data",
        default="data/mitre-cti/enterprise-attack/",
    )
    args = parser.parse_args()
    process_mitre_intrusion_sets(args.mitre)

if __name__ == "__main__":
    main()
rhaist commented 3 years ago

This is top-shelf help ❤️ My code's runtime improves a lot with your changes, and the results seem to be consistent. I did think about caching but didn't know where to start. Your approach is pragmatic and effective. Love it. If we ever meet, beer is on me :)

maybe-sybr commented 2 years ago

Just spotted this and thought I'd drop a recipe I have for a caching data source which jams everything it sees from a backing source into a memory store to avoid hitting the backing store again for simple fetches. It works well when you have a heavily cross-connected graph of explicit references or when you run multiple passes across a dataset and would otherwise be calling .get() frequently on repeated IDs. Given that I'm typically resolving references from an unambiguous root node rather than doing lots of filtering, this approach works well for me.

I imagine it could be specialised (or genericised better) for your use case. Specifically, I'm thinking that in @chisholm's script above, you could use a single caching source and run the relationship indexing step through that (it would need to be added to the recipe below; a sketch follows the recipe) rather than keeping a separate dictionary to look things up in. It's effectively the same, since a memory store is roughly a dict, but it means you get to maintain the illusion of working with something that implements the DataSource interface.

import stix2.datastore
import stix2.datastore.memory

class CachingSourceProxy(stix2.datastore.DataSource):
    def __init__(self, backing_store):
        super().__init__()
        self.__bs = backing_store
        self.__ms = stix2.datastore.memory.MemoryStore()
        self.__cached = self.__called = 0

    @property
    def called(self):
        return self.__called

    @property
    def cached(self):
        return self.__cached

    @property
    def ratio(self):
        # Cache hit ratio; avoid a ZeroDivisionError before the first get().
        return self.cached / self.called if self.called else 0.0

    def get(self, stix_id):
        # Serve from the memory store when possible; on a miss, fetch from
        # the backing store and cache the result for next time.
        self.__called += 1
        o = self.__ms.get(stix_id)
        if o is None:
            o = self.__bs.get(stix_id)
            if o is not None:
                self.__ms.add(o)
        else:
            self.__cached += 1
        return o

    def all_versions(self, stix_id):
        # Stream all versions through, caching each one as it is seen.
        for obj in self.__bs.all_versions(stix_id):
            self.__ms.add(obj)
            yield obj

    def query(self, query=None):
        # Queries always hit the backing store, but their results seed the
        # cache for later get() calls.
        for obj in self.__bs.query(query):
            self.__ms.add(obj)
            yield obj
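
Here's a minimal sketch of that specialisation, assuming the CachingSourceProxy recipe above plus @chisholm's indexing logic; build_source_ref_index is a hypothetical helper name, not part of the stix2 API. Routing the one-time relationship query through the proxy seeds its memory store, and resolving target_refs with get() means a tool or malware object shared by many intrusion sets is read from the filesystem only once.

from stix2 import Filter, FileSystemSource
from stix2.utils import get_type_from_id

def build_source_ref_index(source, id_prefix="intrusion-set--"):
    # One pass over all relationship SROs, routed through the caching
    # proxy so the objects also land in its memory store.
    index = {}
    for sro in source.query([Filter("type", "=", "relationship")]):
        if sro.source_ref.startswith(id_prefix):
            index.setdefault(sro.source_ref, []).append(sro)
    return index

src = CachingSourceProxy(FileSystemSource("data/mitre-cti/enterprise-attack/"))
intrusion_set_sros = build_source_ref_index(src)

for sros in intrusion_set_sros.values():
    for sro in sros:
        if (sro.relationship_type == "uses"
                and get_type_from_id(sro.target_ref) in ("malware", "tool")):
            maltool = src.get(sro.target_ref)  # cached after the first fetch
            print(f"{sro.source_ref} uses {maltool['name']}")

The dict still does the fast source_ref lookup; the proxy just keeps everything behind the familiar DataSource interface, as described above.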
clenk commented 2 years ago

Looks like this got resolved so I'm closing it. Feel free to reopen or open a new issue if you need further help!