Closed young-x-skyee closed 2 months ago
If you send/link me to an example nkg file, I'm happy to have a look at this, at least to profile it
Thanks a lot, Cai! An example is here: /imaging/projects/cbu/kymata/analyses/tianyi/kymata-toolbox/kymata-toolbox-data/output/model.decoder.layers.31.nkg
It does take a while... Running a profiler it looks like the majority of the time is spent slicing the sparse array, inside this list comprehension:
def load_expression_set(...):
# ...
elif type_identifier == _ExpressionSetTypeIdentifier.sensor:
return SensorExpressionSet(
functions=data_dict[_Keys.functions],
sensors=[SensorDType(c) for c in data_dict[_Keys.channels][BLOCK_SCALP]],
latencies=data_dict[_Keys.latencies],
# ,--------here
# v
data=[data_dict[_Keys.data][BLOCK_SCALP][:, :, i]
for i in range(len(data_dict[_Keys.functions]))],
)
The slice is done once for each function (i.e. 1000+ times in this example).
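For reference, hotspots like this can be located with Python's standard-library profiler. A minimal, self-contained sketch (`slow_slicing` is a dummy stand-in; in the real case you would profile the `load_expression_set(path)` call instead):

```python
import cProfile
import io
import pstats


def slow_slicing(n: int = 500) -> list:
    """Dummy stand-in for the expensive per-function slicing."""
    data = [list(range(100)) for _ in range(n)]
    return [row[:] for row in data]


profiler = cProfile.Profile()
profiler.enable()
slow_slicing()
profiler.disable()

# Report the five most expensive calls by cumulative time
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(5)
report = stream.getvalue()
print("slow_slicing" in report)
```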
This slicing is done because of the way the class constructor works. The constructor asks for a list of blocks of data, one for each function being passed. However, once inside the constructor, these are all concatenated back into the same data block. So really the work is wasted: a single data block is loaded from disk, sliced in the function dimension to pass as an argument to the constructor, then re-concatenated inside the constructor.
Clearly there's a way to bypass this: rewrite the constructor so that the whole block can be passed in as-is, without slicing and concatenating. However, this will require a bit of thought: the data block is a sparse matrix which is rather opaque to the outside world, and the constructor signature should make logical sense and allow data to be passed in a form which isn't a monolithic block, for cases where that isn't possible. In other words, the user of the class shouldn't need to know how to concatenate sparse matrices along the right axis.
I think the best solution will be a bit of carefully validated argument inspection, so the constructor can accept either form. The arguments would be inspected so that a single data block can be either the data for one function (no need to wrap it in a list when there's only one function), a list as now, or a larger block, which must then have the correct dimensions to pass validation. Should be doable with some additional tests!
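One way that argument inspection could look, as a minimal sketch. The helper name `normalise_data_blocks`, the shapes, and the use of dense numpy arrays (standing in for the toolbox's sparse blocks) are all illustrative assumptions, not the actual API:

```python
import numpy as np


def normalise_data_blocks(data, n_functions: int) -> np.ndarray:
    """Hypothetical helper: accept a single 2-d block (one function),
    a list of 2-d blocks (one per function), or a full 3-d block, and
    return one (channels, latencies, functions) array."""
    if isinstance(data, (list, tuple)):
        # List of 2-d blocks: stack along a new trailing function axis
        if len(data) != n_functions:
            raise ValueError("Expected one data block per function")
        return np.stack(data, axis=-1)
    data = np.asarray(data)
    if data.ndim == 2:
        # A bare 2-d block is only valid when there is exactly one function
        if n_functions != 1:
            raise ValueError("2-d data requires exactly one function")
        return data[:, :, np.newaxis]
    if data.ndim == 3:
        # Monolithic block: just validate the function dimension
        if data.shape[-1] != n_functions:
            raise ValueError("Function dimension mismatch")
        return data
    raise ValueError("Unsupported data shape")
```

With something like this in the constructor, the loader could pass the whole block straight through, and existing call sites passing lists would keep working.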
I'm removing myself as assignee as I'm focussing on the paper revisions for now, so someone else can have a go if they like (I'm happy to review). Or I can pick this up when I'm back on code stuff.
Thanks for the explanation - yes, let's keep it off your plate for the moment.
I tried to make it possible to do this when loading the expression set
elif type_identifier == _ExpressionSetTypeIdentifier.sensor:
return SensorExpressionSet(
functions=data_dict[_Keys.functions],
sensors=[SensorDType(c) for c in data_dict[_Keys.channels][BLOCK_SCALP]],
latencies=data_dict[_Keys.latencies],
data=data_dict[_Keys.data][BLOCK_SCALP]
)
which I believe deals with the problem @caiw described. However, it is still very slow to load a large number of functions. For example, I tried to load the expression sets for all the layers of a Whisper large model (66 layers, each with one .nkg file containing 1280 functions) by:
# Initialize expression_data
expression_data = None
# Loop through the file paths and load the data
for file_path in tqdm(nkg_files):
data = load_expression_set(file_path)
if expression_data is None:
expression_data = data
else:
expression_data += data
It seems that loading the first expression set is not very slow, but it gets tremendously slower from the second one onwards.
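That slowdown pattern is what you would expect if each `+=` rebuilds the combined block from scratch: the k-th addition copies everything accumulated so far, so total work grows quadratically in the number of files. A toy sketch of the two strategies (function names are illustrative, and dense numpy arrays stand in for the sparse blocks):

```python
import numpy as np


def combine_incrementally(blocks):
    """Quadratic: each step re-copies all previously accumulated data,
    mirroring what repeated `expression_data += data` does."""
    combined = blocks[0]
    for block in blocks[1:]:
        combined = np.concatenate([combined, block], axis=-1)
    return combined


def combine_once(blocks):
    """Linear: collect every block first, then concatenate in one call."""
    return np.concatenate(blocks, axis=-1)
```

So one workaround on the caller's side is to collect all loaded sets in a list and combine them in a single operation at the end.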
This is a great start, thanks Tianyi 👍
Another thing I noticed that may slow things down tremendously is where the addition is defined:
def __add__(self, other: SensorExpressionSet) -> SensorExpressionSet:
assert array_equal(self.sensors, other.sensors), "Sensors mismatch"
assert array_equal(self.latencies, other.latencies), "Latencies mismatch"
# constructor expects a sequence of function names and sequences of 2d matrices
functions = []
data = []
for expr_set in [self, other]:
for i, function in enumerate(expr_set.functions):
functions.append(function)
data.append(expr_set._data[BLOCK_SCALP].data[:, :, i])
return SensorExpressionSet(
functions=functions,
sensors=self.sensors, latencies=self.latencies,
data=data,
)
I'll have a think about how to improve that as well.
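One possible shape for that improvement, as a hedged sketch with a toy class (names are illustrative; dense numpy arrays stand in for the toolbox's sparse blocks, for which pydata-sparse's `sparse.concatenate` would play the role of `np.concatenate`):

```python
import numpy as np


class FakeExpressionSet:
    """Minimal stand-in for SensorExpressionSet; `data` is a
    (channels, latencies, functions) block."""

    def __init__(self, functions, data):
        self.functions = list(functions)
        self.data = data

    def __add__(self, other: "FakeExpressionSet") -> "FakeExpressionSet":
        # Concatenate the two whole blocks along the function axis in one
        # call, instead of slicing out each function and re-stacking them.
        return FakeExpressionSet(
            functions=self.functions + other.functions,
            data=np.concatenate([self.data, other.data], axis=-1),
        )
```

This avoids the per-function slicing loop entirely, though the real `__add__` would keep the existing sensor/latency validation.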
I think I've more or less solved the problem. But now I notice that it's very slow to execute expression_plot. Maybe I should close this issue, open another one for that, and submit a pull request when I've tackled everything?
Great idea, well done 👍
@young-x-skyee Is this fixed on main? If you link me to the branch where you've fixed it, I can have a look at what might be going wrong with expression_plot().
(Reopening temporarily until it's confirmed fixed on main.)
The branch I fixed it on is kymata-language (sorry I was too busy to reply until now)
You may also see other changes there because I'm also trying to look at #354
When 1000+ functions are contained in an nkg file, it can be very slow to load it.