Open snoop2head opened 1 month ago
Dear @CG80499,

Thank you for your contribution. Using your implementation of the ChebyKAN layer, I am currently training a Vision Transformer tiny model on ImageNet-1K, but it seems to be underperforming the original implementation. Any suggestions? Why did you set the polynomial degree to 8 for GPT? Some minor details of my ViT implementation follow in the thread below.

Thank you.
Thanks for your interest! I chose 8 after experimenting with different values for fitting toy functions (see toy_functions.py). Have you checked that the MLP layer and the ChebyKAN layer have the same number of parameters? Does the model converge correctly in the ChebyKAN case?
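(For context, here is a hypothetical sketch of what such a degree sweep could look like, using the ChebyKAN module posted later in this thread. The toy target, optimizer, and step count below are illustrative assumptions, not the contents of toy_functions.py.)

```python
import jax
import jax.numpy as jnp
import optax

def sweep_degrees(degrees=(2, 4, 8, 16), steps=2000):
    key = jax.random.PRNGKey(0)
    x = jnp.linspace(-1.0, 1.0, 256).reshape(-1, 1)
    y = jnp.sin(5 * x) * jnp.exp(-x ** 2)  # assumed 1-D toy target

    for degree in degrees:
        model = ChebyKAN(in_features=1, out_features=1, degree=degree)
        params = model.init(key, x)
        tx = optax.adam(1e-3)
        opt_state = tx.init(params)

        @jax.jit
        def train_step(params, opt_state):
            # mean-squared error between the KAN output and the toy target
            loss, grads = jax.value_and_grad(
                lambda p: jnp.mean((model.apply(p, x) - y) ** 2))(params)
            updates, opt_state = tx.update(grads, opt_state)
            return optax.apply_updates(params, updates), opt_state, loss

        for _ in range(steps):
            params, opt_state, loss = train_step(params, opt_state)
        print(f"degree={degree:2d}  final_mse={loss:.5f}")
```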
Thank you for your advice.
Strangely enough, the ChebyKAN layer turns out to have more parameters than the MLP: about 0.33M params per feedforward (ff) module, whereas the standard MLP module has about 0.30M params per ff module. The input/output dimensions are the same for both modules. Can this difference be considered negligible?
Below is the result of `module.tabulate()` for ViT-tiny with ChebyKAN, using a hidden dimension of 192 and a sequence length of 197.
| Component | Type | Input | Output | Parameters |
|---|---|---|---|---|
| model/layer_10/ff | KANLayer | float32[1,197,192], True | float32[1,197,192] | |
| model/layer_10/ff/ChebyKAN_0 | ChebyKAN | float32[197,192] | float32[197,192] | coefficients: float32[192,192,9] = 331,776 (1.3 MB) |
Below is the result of `module.tabulate()` for ViT-tiny with the standard MLP, using the same dim/seq_len configuration.
| Component | Type | Input | Output | Parameters |
|---|---|---|---|---|
| model/layer_10/ff | FeedForward | float32[1,197,192], True | float32[1,197,192] | |
| model/layer_10/ff/w1 | Dense | float32[1,197,192] | float32[1,197,768] | bias: float32[768], kernel: float32[192,768] = 148,224 (592.9 KB) |
| model/layer_10/ff/drop | Dropout | float32[1,197,768], True | float32[1,197,768] | |
| model/layer_10/ff/w2 | Dense | float32[1,197,768] | float32[1,197,192] | bias: float32[192], kernel: float32[768,192] = 147,648 (590.6 KB) |
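For reference, the gap above follows directly from the shapes: the ChebyKAN block stores in_features × out_features × (degree + 1) = 192 × 192 × 9 = 331,776 coefficients, while the MLP block stores (192 × 768 + 768) + (768 × 192 + 192) = 148,224 + 147,648 = 295,872 parameters. So the ChebyKAN version is roughly 12% larger at degree 8, matching the ~0.33M vs ~0.30M figures quoted earlier.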
ViT with the ChebyKAN layer does converge, but it underperforms by a significant margin (more than -10 percentage points). I will make some adjustments to the polynomial degree and learning rate. Feel free to offer any additional suggestions or ask questions! I will keep updating.
If you try a ChebyKAN layer with a much higher degree, does that work?
I don't think that is currently feasible on my device (TPU v3-8); below is the log from running with a polynomial degree of 64. Degree 16 doesn't work either, and I will look into it.
```
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Compilation failure: Aborting compilation early because it's unlikely to have enough memory. Requires 595.95G, has 14.71G available. If more detailed logging is desired, set --xla_tpu_impure_oom_fast_exit_threshold=-1
```
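(For intuition, a rough back-of-envelope sketch, not a reproduction of the compiler's 595.95G figure, which also covers einsum intermediates and gradients: the dense formulation keeps a cheby_values tensor of shape (tokens, in_features, degree + 1) alive in every block for the backward pass. The per-device batch size of 128 below is an assumption, not a value from this thread; the 12 blocks correspond to ViT-tiny.)

```python
def cheby_values_gb(per_device_batch=128, seq_len=197, in_features=192,
                    degree=64, blocks=12, bytes_per_elem=4):
    # fp32 footprint of the dense (tokens, in_features, degree + 1) activations
    tokens = per_device_batch * seq_len
    per_block = tokens * in_features * (degree + 1) * bytes_per_elem
    return blocks * per_block / 1e9

print(f"{cheby_values_gb():.1f} GB")  # ~15.1 GB at degree 64, vs ~2.1 GB at degree 8
```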
I updated the code as follows: https://github.com/snoop2head/KAN-ViT. Please comment or give feedback any time.
@snoop2head @CG80499 Hi, thanks for your contributions. I think we can optimize the original code so that we can set a higher degree. In fact, we don't need the entire large matrix; we can accumulate the output with a for loop instead. This saves a factor of d in memory and, when tested on KAN-ViT, is approximately 3 times faster than the original implementation.
```python
import jax
import jax.numpy as jnp
import flax.linen as nn


def kan_ops2(x, coefficients):
    # x: (batch_size, in_features)
    # normalize x between -1 and 1
    x = nn.tanh(x)
    in_features = x.shape[1]
    degree = coefficients.shape[-1] - 1

    # First two Chebyshev basis values: T_0(x) = 1 and T_1(x) = x
    cheby_values = jnp.ones((x.shape[0], in_features, 2), dtype=x.dtype)
    cheby_values = cheby_values.at[:, :, 1].set(x)
    prev_values = cheby_values[:, :, 0]
    values = cheby_values[:, :, 1]

    # Accumulate the output term by term instead of materializing the full
    # (batch_size, in_features, degree + 1) tensor
    out = jnp.einsum('bi,ij->bj', prev_values, coefficients[:, :, 0]) \
        + jnp.einsum('bi,ij->bj', values, coefficients[:, :, 1])

    for i in range(2, degree + 1):
        # Chebyshev recurrence: T_i(x) = 2x * T_{i-1}(x) - T_{i-2}(x)
        temp = 2 * x * values - prev_values
        prev_values = values
        values = temp
        out += jnp.einsum('bi,ij->bj', temp, coefficients[:, :, i])

    # def loop_body(i, carry):
    #     prev_values, values, out = carry
    #     temp = 2 * x * values - prev_values
    #     prev_values = values
    #     values = temp
    #     out += jnp.einsum('bi,ij->bj', temp, coefficients[:, :, i])
    #     return (prev_values, values, out)
    #
    # prev_values, values, out = jax.lax.fori_loop(
    #     2, degree + 1, body_fun=loop_body, init_val=(prev_values, values, out))

    return out


# Inspired by https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py
# Imported from https://github.com/CG80499/KAN-GPT-2/blob/master/transformer.py
class ChebyKAN(nn.Module):
    in_features: int
    out_features: int
    degree: int  # degree of the basis polynomials

    def setup(self):
        assert self.degree > 0, "Degree of the Chebyshev polynomials must be greater than 0"
        mean, std = 0.0, 1 / (self.in_features * (self.degree + 1))
        self.coefficients = self.param(
            "coefficients",
            lambda key, shape: mean + std * jax.random.normal(key, shape),
            (self.in_features, self.out_features, self.degree + 1),
        )

    def __call__(self, x):
        return kan_ops2(x, self.coefficients)
        # Original dense implementation, kept for reference:
        """
        # x: (batch_size, in_features)
        # normalize x between -1 and 1
        x = jnp.tanh(x)
        cheby_values = jnp.ones((x.shape[0], self.in_features, self.degree + 1))
        cheby_values = cheby_values.at[:, :, 1].set(x)
        for i in range(2, self.degree + 1):
            next_value = 2 * x * cheby_values[:, :, i - 1] - cheby_values[:, :, i - 2]
            cheby_values = cheby_values.at[:, :, i].set(next_value)
        # cheby_values: (batch_size, in_features, degree+1)
        # multiply by coefficients (in_features, out_features, degree+1)
        return jnp.einsum('bid,ijd->bj', cheby_values, self.coefficients)
        """
```
@demon2036 @CG80499 Let's check out the performance! I will run it on ImageNet-1K right away.
Hm... It seems like the performance has deteriorated. I will investigate further
It's quite strange. My code should increase speed and reduce memory usage. I've tested it on TPU v4-8 and TPU v4-32. Are you running kan-deit-b16-224-in1k-300ep.sh? I don't think you should run the tiny or small versions, as you can see that the runtimes for tiny and small on jax-deit are nearly the same, which suggests a serious I/O issue, possibly related to webdataset. We can also discuss this problem further.
I think you should run the base version. If the original base version causes an out-of-memory (OOM) error, you can try reducing the batch size, for example, to 512. On TPU v4-8, the original version causes an OOM error, whereas my code does not. On TPU v4-32, the original code runs at around 4 it/s, whereas mine runs at 11-12 it/s.