kymata-atlas / kymata-core

Core Kymata codebase, including statistical analysis and plotting tools
https://kymata.org
MIT License

Loading .nkg files is slow for 1000+ functions #271

Closed: young-x-skyee closed this 1 month ago

young-x-skyee commented 5 months ago

When an nkg file contains 1000+ functions, loading it can be very slow.

caiw commented 5 months ago

If you send/link me to an example nkg file, I'm happy to have a look at this, at least to profile it.

young-x-skyee commented 5 months ago

Thanks a lot, Cai! An example is here: /imaging/projects/cbu/kymata/analyses/tianyi/kymata-toolbox/kymata-toolbox-data/output/model.decoder.layers.31.nkg

caiw commented 5 months ago

The slowness

It does take a while. Running a profiler, it looks like most of the time is spent slicing the sparse array inside this list comprehension:

def load_expression_set(...):
    # ...
    elif type_identifier == _ExpressionSetTypeIdentifier.sensor:
        return SensorExpressionSet(
            functions=data_dict[_Keys.functions],
            sensors=[SensorDType(c) for c in data_dict[_Keys.channels][BLOCK_SCALP]],
            latencies=data_dict[_Keys.latencies],
            #                                        ,--------here
            #                                        v 
            data=[data_dict[_Keys.data][BLOCK_SCALP][:, :, i]
                  for i in range(len(data_dict[_Keys.functions]))],
        )

The slice is done once for each function (i.e. 1000+ times in this example).

Directions for a solution

This slicing is done because of the way the class constructor works. The constructor asks for a list of blocks of data, one for each function being passed. However, once inside the constructor, these are all concatenated back into a single data block. So the work is really wasted: a single data block is loaded from disk, sliced along the function dimension to pass as an argument to the constructor, then re-concatenated inside the constructor.
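To make the wasted work concrete, here is a small sketch of that round trip in NumPy. It uses a dense array with made-up dimensions as a stand-in for the sparse block (an assumption for illustration only): slicing out every function and stacking the slices back along the function axis gives back exactly the block we started from.

```python
import numpy as np

# Stand-in for one data block: (sensors, latencies, functions); sizes are made up
rng = np.random.default_rng(0)
block = rng.random((4, 5, 6))

# What the loader does: take one slice per function...
per_function = [block[:, :, i] for i in range(block.shape[2])]

# ...and what the constructor then does: join them back along the function axis
rebuilt = np.stack(per_function, axis=2)

# The round trip changes nothing, so both steps are pure overhead
assert np.array_equal(rebuilt, block)
```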

Clearly there's a way to bypass this: rewrite the constructor so that the whole block can be passed in as-is, without slicing and re-concatenating. However, this will require a bit of thought. The data block is a sparse matrix, which is rather opaque to the outside world, and the constructor signature should still make logical sense and allow data to be passed in a form other than a monolithic block in cases where that isn't possible. In other words, the user of the class shouldn't need to know how to concatenate sparse matrices along the right axis.

I think the best solution will be a bit of carefully validated argument inspection, so the constructor can accept any of: a single data block holding the data for one function (so there's no need to wrap it in a list when there's only one function), a list of blocks like now, or one larger block covering all functions, which must then have the correct dimensions to pass validation. Should be doable, with some additional tests!
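A rough sketch of what that argument inspection could look like. `normalise_data` is a hypothetical helper (not part of Kymata), and dense NumPy arrays stand in for the sparse block. It accepts one 2D block for a single function, a list of 2D blocks, or a full 3D block, and returns one validated (sensors, latencies, functions) array.

```python
import numpy as np


def normalise_data(data, n_functions: int, n_sensors: int, n_latencies: int) -> np.ndarray:
    """Hypothetical helper: coerce the accepted input forms into a single
    (sensors, latencies, functions) block, validating its shape."""
    expected_2d = (n_sensors, n_latencies)
    if isinstance(data, np.ndarray) and data.ndim == 3:
        # A whole 3D block: no slicing or concatenating, just validate and return it
        expected_3d = (*expected_2d, n_functions)
        if data.shape != expected_3d:
            raise ValueError(f"expected shape {expected_3d}, got {data.shape}")
        return data
    if isinstance(data, np.ndarray) and data.ndim == 2:
        # A single 2D block: the data for exactly one function, no list needed
        if n_functions != 1:
            raise ValueError("a single 2D block is only valid for one function")
        blocks = [data]
    else:
        # A sequence of 2D blocks, one per function (the current behaviour)
        blocks = list(data)
        if len(blocks) != n_functions:
            raise ValueError(f"expected {n_functions} blocks, got {len(blocks)}")
    for b in blocks:
        if np.shape(b) != expected_2d:
            raise ValueError(f"expected 2D blocks of shape {expected_2d}, got {np.shape(b)}")
    # Only the per-function forms need joining along the function axis
    return np.stack(blocks, axis=2)
```

Only the list form pays for a concatenation; a full 3D block (the case the loader would hit) passes straight through after the shape check.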

caiw commented 5 months ago

I'm removing myself as assignee as I'm focussing on the paper revisions for now, so someone else can have a go if they like (I'm happy to review). Or I can pick this up when I'm back on code stuff.

neukym commented 5 months ago

Thanks for the explanation - yes, let's keep it off your plate for the moment.

young-x-skyee commented 3 months ago

I tried making it possible to pass the whole block in as-is when loading the expression set:

    elif type_identifier == _ExpressionSetTypeIdentifier.sensor:
        return SensorExpressionSet(
            functions=data_dict[_Keys.functions],
            sensors=[SensorDType(c) for c in data_dict[_Keys.channels][BLOCK_SCALP]],
            latencies=data_dict[_Keys.latencies],
            data=data_dict[_Keys.data][BLOCK_SCALP]
        )

which I believe deals with the problem @caiw identified. However, loading a large number of functions is still very slow. For example, I tried to load the expression sets for all the layers of a Whisper large model (66 layers, each with one .nkg file containing 1280 functions) with:

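        from tqdm import tqdm
        from kymata.io.nkg import load_expression_set  # import path assumed

        # nkg_files: list of paths to the 66 per-layer .nkg files
        # Initialize expression_data
        expression_data = None

        # Loop through the file paths, load each one and accumulate with +=
        for file_path in tqdm(nkg_files):
            data = load_expression_set(file_path)
            if expression_data is None:
                expression_data = data
            else:
                expression_data += data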

Loading the first expression set isn't too slow, but it gets dramatically slower from the second one onwards.
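The growing slowdown fits a quadratic cost. If the class defines no `__iadd__`, Python evaluates `expression_data += data` via `__add__`, which builds a brand-new set. If that `__add__` slices every function of both operands (as the one quoted later in this thread does), the j-th addition re-slices everything accumulated so far. A quick count under those assumptions, using the numbers above (both helper functions are hypothetical):

```python
def slices_with_repeated_add(n_sets: int, k: int) -> int:
    """Function slices taken when n_sets expression sets of k functions each
    are combined one at a time, assuming every + re-slices all functions
    of both operands."""
    # The first set is simply assigned; the j-th addition (j = 2..n_sets)
    # slices the (j - 1) * k accumulated functions plus the k new ones.
    return sum(j * k for j in range(2, n_sets + 1))


def slices_with_single_combine(n_sets: int, k: int) -> int:
    """Each function is touched once if all sets are combined in one step."""
    return n_sets * k


n, k = 66, 1280  # 66 layers x 1280 functions each
print(slices_with_repeated_add(n, k))    # 2828800
print(slices_with_single_combine(n, k))  # 84480
```

That is roughly 33 times as many slices as there are functions, which suggests collecting all the sets first and combining them in a single step.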

neukym commented 3 months ago

This is a great start, thanks Tianyi 👍

young-x-skyee commented 3 months ago

Another thing I noticed that may slow things down considerably is how addition is defined:

    def __add__(self, other: SensorExpressionSet) -> SensorExpressionSet:
        assert array_equal(self.sensors, other.sensors), "Sensors mismatch"
        assert array_equal(self.latencies, other.latencies), "Latencies mismatch"
        # constructor expects a sequence of function names and sequences of 2d matrices
        functions = []
        data = []
        for expr_set in [self, other]:
            for i, function in enumerate(expr_set.functions):
                functions.append(function)
                data.append(expr_set._data[BLOCK_SCALP].data[:, :, i])
        return SensorExpressionSet(
            functions=functions,
            sensors=self.sensors, latencies=self.latencies,
            data=data,
        )

I'll have a think about how to improve that as well.
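One way `__add__` could avoid per-function slicing is to concatenate the two data blocks along the function axis in a single call. The sketch below uses a hypothetical `ToyExpressionSet` with dense NumPy data rather than the real `SensorExpressionSet`; its `combine` classmethod (also hypothetical) handles the many-files case with one concatenation instead of repeated `+`. For the real sparse block, the analogous call would presumably be `sparse.concatenate` from the pydata `sparse` package, assuming that is the array type in use.

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class ToyExpressionSet:
    """Hypothetical stand-in for SensorExpressionSet with dense data."""
    functions: List[str]
    sensors: np.ndarray
    latencies: np.ndarray
    data: np.ndarray  # shape (sensors, latencies, functions)

    def __add__(self, other: "ToyExpressionSet") -> "ToyExpressionSet":
        assert np.array_equal(self.sensors, other.sensors), "Sensors mismatch"
        assert np.array_equal(self.latencies, other.latencies), "Latencies mismatch"
        # One concatenation along the function axis, no per-function slicing
        return ToyExpressionSet(
            functions=self.functions + other.functions,
            sensors=self.sensors,
            latencies=self.latencies,
            data=np.concatenate([self.data, other.data], axis=2),
        )

    @classmethod
    def combine(cls, sets: List["ToyExpressionSet"]) -> "ToyExpressionSet":
        # Many sets joined with a single concatenation instead of repeated +
        first = sets[0]
        for s in sets[1:]:
            assert np.array_equal(first.sensors, s.sensors), "Sensors mismatch"
            assert np.array_equal(first.latencies, s.latencies), "Latencies mismatch"
        return cls(
            functions=[f for s in sets for f in s.functions],
            sensors=first.sensors,
            latencies=first.latencies,
            data=np.concatenate([s.data for s in sets], axis=2),
        )
```

With `combine`, the loading loop would collect the loaded sets in a list and join them once at the end, so each function's data is copied once rather than once per remaining addition.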

young-x-skyee commented 3 months ago

I think I've more or less solved the problem. But now I notice that expression_plot is very slow to execute. Maybe I should close this issue, open another one for that, and submit a pull request once I've tackled everything?

neukym commented 3 months ago

Great idea, well done 👍

caiw commented 3 months ago

@young-x-skyee Is this fixed on main? If you link me to the branch where you've fixed it, I can have a look at what might be going wrong with expression_plot().

caiw commented 3 months ago

(Reopening temporarily until it's confirmed fixed on main)

young-x-skyee commented 3 months ago

The branch where I fixed it is kymata-language (sorry I was too busy to reply until now).

young-x-skyee commented 3 months ago

You may also see other changes there, because I'm also trying to look at #354.