CG80499 / KAN-GPT-2

Training small GPT-2 style models using Kolmogorov-Arnold networks.

Any suggestions for KAN-ViT? #1

Open: snoop2head opened this issue 1 month ago

snoop2head commented 1 month ago

Dear @CG80499, thank you for your contribution.

Using your implementation of the ChebyKAN layer, I am currently training a ViT-Tiny (Vision Transformer) model on ImageNet1K, but it seems to underperform the original implementation.

Any suggestions? Also, why did you set the polynomial degree to 8 for GPT?

[Screenshot, 2024-05-10 at 2:37:15 PM]

Here are some minor details of my ViT implementation.

Thank you.

CG80499 commented 1 month ago

Thanks for your interest! I chose 8 after experimenting with different values when fitting toy functions (see toy_functions.py). Have you checked that the MLP layer and the ChebyKAN layer have the same number of parameters? Does the model converge correctly in the ChebyKAN case?
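To illustrate the kind of degree sweep described above, here is a minimal sketch; it is not the contents of toy_functions.py. It fits a 1-D toy function with a Chebyshev basis of increasing degree and reports the RMS error. The target sin(3x) on [-1, 1], the 512 sample points, and the use of least-squares fitting (NumPy's numpy.polynomial.chebyshev) instead of gradient descent are all assumptions. A single ChebyKAN edge computes a weighted sum of T_0 .. T_d of tanh(x), so this roughly mirrors what one edge can represent.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def toy_fit_errors(f, degrees, n_samples=512):
    """Least-squares Chebyshev fit of f on [-1, 1]; returns RMS error per degree.

    Hypothetical helper for illustration only: the sample count and the
    least-squares fit are assumptions, not the settings of toy_functions.py.
    """
    x = np.linspace(-1.0, 1.0, n_samples)
    y = f(x)
    errors = {}
    for d in degrees:
        coeffs = C.chebfit(x, y, d)    # coefficients of T_0 .. T_d
        y_hat = C.chebval(x, coeffs)   # evaluate sum_k c_k * T_k(x)
        errors[d] = float(np.sqrt(np.mean((y - y_hat) ** 2)))
    return errors


if __name__ == "__main__":
    errs = toy_fit_errors(lambda x: np.sin(3 * x), degrees=[2, 4, 8, 16])
    for d, e in errs.items():
        print(f"degree {d:2d}: RMS error {e:.2e}")
```

For a smooth target like this one the error falls quickly with degree; how far it is worth going depends on the target function and, inside a full model, on optimization as well.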

snoop2head commented 1 month ago

Thank you for your advice.

Strangely enough, the ChebyKAN layer turns out to have more parameters than the MLP: about 0.33M params per feedforward (ff) module, whereas the standard MLP module has about 0.30M params per ff module. I think the input/output shapes are identical for both modules. Can this difference be considered negligible?

ViT-Tiny w/ ChebyKAN

Below is the module.tabulate() output for ViT-Tiny with ChebyKAN, using hidden dimension 192 and sequence length 197.

Component                     Type      Input                     Output              Parameters
model/layer_10/ff             KANLayer  float32[1,197,192], True  float32[1,197,192]
model/layer_10/ff/ChebyKAN_0  ChebyKAN  float32[197,192]          float32[197,192]    coefficients: float32[192,192,9]; 331,776 (1.3 MB)

ViT-Tiny w/ Standard MLP

Below is the module.tabulate() output for ViT-Tiny with a standard MLP, using the same dimension and sequence-length configuration.

Component               Type         Input                     Output              Parameters
model/layer_10/ff       FeedForward  float32[1,197,192], True  float32[1,197,192]
model/layer_10/ff/w1    Dense        float32[1,197,192]        float32[1,197,768]  bias: float32[768], kernel: float32[192,768]; 148,224 (592.9 KB)
model/layer_10/ff/drop  Dropout      float32[1,197,768], True  float32[1,197,768]
model/layer_10/ff/w2    Dense        float32[1,197,768]        float32[1,197,192]  bias: float32[192], kernel: float32[768,192]; 147,648 (590.6 KB)
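Both counts follow from the layer shapes in the tabulate() output: ChebyKAN stores in_features x out_features x (degree + 1) coefficients with no bias, while the MLP's two Dense layers store dim x hidden weights twice plus their biases. The sketch below reproduces that arithmetic; the helper names are hypothetical. It gives a gap of about 12% at degree 8, and shows that degree 7 (294,912 parameters) would fall just under the MLP's 295,872.

```python
def chebykan_params(in_features, out_features, degree):
    # One coefficient per (input, output, T_k) triple, k = 0..degree; no bias.
    return in_features * out_features * (degree + 1)


def mlp_params(dim, hidden):
    # Two Dense layers with biases: dim -> hidden (w1) and hidden -> dim (w2).
    w1 = dim * hidden + hidden
    w2 = hidden * dim + dim
    return w1 + w2


def max_degree_within(in_features, out_features, budget):
    # Largest degree whose ChebyKAN parameter count does not exceed `budget`.
    return budget // (in_features * out_features) - 1


if __name__ == "__main__":
    dim, hidden = 192, 768  # ViT-Tiny sizes taken from the tabulate() output
    kan = chebykan_params(dim, dim, degree=8)
    mlp = mlp_params(dim, hidden)
    print(kan, mlp, f"{kan / mlp:.3f}")       # prints 331776 295872 1.121
    d = max_degree_within(dim, dim, mlp)
    print(d, chebykan_params(dim, dim, d))     # prints 7 294912
```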

The ViT with ChebyKAN layers does converge, but it underperforms by a significant margin (more than 10 percentage points). I will adjust the polynomial degree and learning rate. Feel free to offer any additional suggestions or ask questions! I will keep updating.

CG80499 commented 1 month ago

If you try a ChebyKAN layer with a much higher degree, does that work?

snoop2head commented 1 month ago

I don't think that is currently feasible on my device (TPU v3-8); below is the log from running with polynomial degree 64. Degree 16 doesn't work either; I will look into it.

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Compilation failure: Aborting compilation early because it's unlikely to have enough memory. Requires 595.95G, has 14.71G available. If more detailed logging is desired, set --xla_tpu_impure_oom_fast_exit_threshold=-1
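One reason a higher degree runs out of memory is that the original ChebyKAN forward pass materializes a basis tensor of shape (tokens, in_features, degree + 1), which grows linearly with the degree. The sketch below estimates only that one tensor for one layer. The batch size of 1024 is a hypothetical value (the thread does not state it), and the estimate ignores what JAX keeps for backpropagation, the other layers, and XLA's own buffers, so it will not reproduce the 595.95G figure in the log.

```python
def cheby_tensor_bytes(tokens, in_features, degree, bytes_per_elem=4):
    # Size of the (tokens, in_features, degree + 1) float32 basis tensor that
    # the original ChebyKAN implementation builds before its einsum.
    return tokens * in_features * (degree + 1) * bytes_per_elem


if __name__ == "__main__":
    batch, seq_len, dim = 1024, 197, 192  # batch size is a hypothetical example
    for degree in (8, 16, 64):
        gib = cheby_tensor_bytes(batch * seq_len, dim, degree) / 2**30
        print(f"degree {degree:2d}: {gib:.2f} GiB per ChebyKAN layer (forward tensor only)")
```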
snoop2head commented 1 month ago

I updated the code. Please feel free to comment or give feedback at any time: https://github.com/snoop2head/KAN-ViT

demon2036 commented 1 month ago

@snoop2head @CG80499 Hi, thanks for your contributions. I think we can optimize the original code so that a higher degree becomes feasible. We don't actually need the entire large tensor; we can accumulate the output in a for loop instead. This can save d times the memory (d being the polynomial degree), and when tested on KAN-ViT, it runs approximately 3 times faster than the original implementation.

import jax
import jax.numpy as jnp
import flax.linen as nn


def kan_ops2(x, coefficients):
    # x: (batch_size, in_features)
    # coefficients: (in_features, out_features, degree + 1)
    # Normalize x into (-1, 1), the domain of the Chebyshev polynomials.
    x = nn.tanh(x)

    degree = coefficients.shape[-1] - 1

    # T_0(x) = 1 and T_1(x) = x; only the two most recent polynomials are kept.
    prev_values = jnp.ones_like(x)
    values = x

    # Contributions of T_0 and T_1.
    out = (jnp.einsum('bi,ij->bj', prev_values, coefficients[:, :, 0])
           + jnp.einsum('bi,ij->bj', values, coefficients[:, :, 1]))

    # Recurrence T_i = 2x * T_{i-1} - T_{i-2}; accumulate T_i @ C_i instead of
    # materializing the full (batch_size, in_features, degree + 1) tensor.
    for i in range(2, degree + 1):
        temp = 2 * x * values - prev_values
        prev_values = values
        values = temp
        out += jnp.einsum('bi,ij->bj', temp, coefficients[:, :, i])

    # Alternative using jax.lax.fori_loop, which traces the body once instead
    # of unrolling the Python loop:
    # def loop_body(i, carry):
    #     prev_values, values, out = carry
    #     temp = 2 * x * values - prev_values
    #     prev_values = values
    #     values = temp
    #     out += jnp.einsum('bi,ij->bj', temp, coefficients[:, :, i])
    #     return (prev_values, values, out)
    #
    # prev_values, values, out = jax.lax.fori_loop(2, degree + 1, body_fun=loop_body, init_val=(prev_values, values, out))

    return out

# Inspired by https://github.com/SynodicMonth/ChebyKAN/blob/main/ChebyKANLayer.py
# Imported from https://github.com/CG80499/KAN-GPT-2/blob/master/transformer.py

class ChebyKAN(nn.Module):
    in_features: int
    out_features: int
    degree: int  # degree of the basis polynomials

    def setup(self):
        assert self.degree > 0, "Degree of the Chebyshev polynomials must be greater than 0"
        mean, std = 0.0, 1 / (self.in_features * (self.degree + 1))
        self.coefficients = self.param(
            "coefficients",
            lambda key, shape: mean + std * jax.random.normal(key, shape),
            (self.in_features, self.out_features, self.degree + 1),
        )

    def __call__(self, x):
        # Accumulate over degrees without building the full basis tensor.
        return kan_ops2(x, self.coefficients)

    # Original implementation (materializes the full basis tensor), kept for reference:
    # def __call__(self, x):
    #     # x: (batch_size, in_features)
    #     # normalize x between -1 and 1
    #     x = jnp.tanh(x)
    #     cheby_values = jnp.ones((x.shape[0], self.in_features, self.degree+1))
    #     cheby_values = cheby_values.at[:, :, 1].set(x)
    #     for i in range(2, self.degree+1):
    #         next_value = 2 * x * cheby_values[:, :, i-1] - cheby_values[:, :, i-2]
    #         cheby_values = cheby_values.at[:, :, i].set(next_value)
    #     # cheby_values: (batch_size, in_features, degree+1)
    #     # multiply by coefficients (in_features, out_features, degree+1)
    #     return jnp.einsum('bid,ijd->bj', cheby_values, self.coefficients)
snoop2head commented 1 month ago

@demon2036 @CG80499 Let's check out the performance! I will run it on ImageNet1K right away.

snoop2head commented 1 month ago

Hm... It seems like the performance has deteriorated. I will investigate further.

demon2036 commented 1 month ago

> Hm... It seems like the performance has deteriorated. I will investigate further

That's quite strange. My code should increase speed and reduce memory usage; I've tested it on TPU v4-8 and TPU v4-32. Are you running kan-deit-b16-224-in1k-300ep.sh? I don't think you should run the tiny or small versions: the runtimes for tiny and small on jax-deit are nearly the same, which suggests a serious I/O bottleneck, possibly related to webdataset. We can discuss this problem further.

I think you should run the base version. If the original base version causes an out-of-memory (OOM) error, you can try reducing the batch size, for example to 512. On TPU v4-8, the original version causes an OOM error, whereas my code does not. On TPU v4-32, the original code runs at around 4 it/s, whereas mine runs at 11-12 it/s.