
efficient-kan's Introduction

An Efficient Implementation of Kolmogorov-Arnold Network

This repository contains an efficient implementation of the Kolmogorov-Arnold Network (KAN). The original implementation of KAN is available here.

The performance issue of the original implementation is mostly due to its need to expand all intermediate variables to apply the different activation functions. For a layer with in_features inputs and out_features outputs, the original implementation expands the input to a tensor of shape (batch_size, out_features, in_features) to apply the activation functions. However, all activation functions are linear combinations of a fixed set of basis functions, namely B-splines; given that, we can reformulate the computation as activating the input with the different basis functions and then combining them linearly. This reformulation significantly reduces the memory cost, turns the computation into a straightforward matrix multiplication, and naturally works with both the forward and backward passes.
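A minimal sketch of the reformulated forward pass (names and shapes here are illustrative, not the repository's exact code):

import torch

def kan_layer_forward(x, b_splines, spline_weight):
    # x: (batch_size, in_features)
    # b_splines(x) evaluates the shared B-spline bases on each input, giving
    # (batch_size, in_features, n_bases) -- note there is no out_features
    # axis, unlike the original (batch_size, out_features, in_features) expansion.
    bases = b_splines(x)
    # spline_weight: (out_features, in_features * n_bases); a single matmul
    # forms every activation as a linear combination of the shared bases.
    return bases.reshape(x.size(0), -1) @ spline_weight.T  # (batch_size, out_features)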

The remaining problem is the sparsification, which is claimed to be critical to KAN's interpretability. The authors proposed an L1 regularization defined on the input samples, which requires non-linear operations on the (batch_size, out_features, in_features) tensor and is therefore incompatible with the reformulation. I instead replace it with an L1 regularization on the weights, which is more common in neural networks and is compatible with the reformulation. The authors' implementation indeed includes this kind of regularization alongside the one described in the paper, so I think it might help. More experiments are needed to verify this; but at least the original approach is infeasible if efficiency is wanted.
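For illustration, a weight-based L1 penalty in a training loop might look like this sketch (a simplified stand-in for the repository's actual regularization term; the model.layers and spline_weight attribute names are assumptions here):

# Hypothetical training-step fragment: penalize the spline weights directly
# instead of the per-sample activations.
l1_lambda = 1e-4  # hypothetical regularization strength
reg_loss = sum(layer.spline_weight.abs().mean() for layer in model.layers)
loss = task_loss + l1_lambda * reg_loss
loss.backward()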

Another difference is that, besides the learnable activation functions (B-splines), the original implementation also includes a learnable scale on each activation function. I provide an option enable_standalone_scale_spline, defaulting to True, to include this feature; disabling it makes the model more efficient but potentially hurts results. It needs more experiments.

2024-05-04 Update: @xiaol hinted that the constant initialization of the base_weight parameters can be a problem on MNIST. For now I've changed both the base_weight and spline_scaler matrices to be initialized with kaiming_uniform_, following nn.Linear's initialization. It seems to work much better on MNIST (~20% → ~97% accuracy), but I'm not sure whether it's a good idea in general.
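For reference, nn.Linear-style initialization amounts to the following (a sketch on a stand-alone tensor; the repository may additionally scale the result):

import math
import torch

base_weight = torch.empty(64, 32)  # hypothetical (out_features, in_features)
torch.nn.init.kaiming_uniform_(base_weight, a=math.sqrt(5))  # same call nn.Linear uses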

efficient-kan's People

Contributors

akaashdash, blealtan


efficient-kan's Issues

Can MLP replace transformer?

Hello author, I would like to know whether this efficient KAN implementation can replace the MLP module in a transformer. What are the advantages and disadvantages?

CUDA out of memory.

from src.efficient_kan.kan import KAN
import torch
net = KAN([1152,1152*4,1152]).to("cuda")
x = torch.rand(size=(4096*4,1152)).to("cuda")
net(x)

I found that if the hidden layer is too large, a CUDA out-of-memory error occurs.
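If it helps as a workaround, splitting the batch into chunks keeps the peak size of the intermediate B-spline tensors down (a sketch continuing the snippet above, for inference only; during training, activations would still be kept for the backward pass):

with torch.no_grad():
    out = torch.cat([net(chunk) for chunk in x.split(1024, dim=0)])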

flops and params

How can I obtain the FLOPs and params of this model? The result I obtained using thop is 0.
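thop typically reports 0 for layers it has no counting rule for, which would include a custom module like KANLinear. Counting the parameters directly still works (a minimal sketch; the import path follows this repo's examples, and the widths are arbitrary):

from efficient_kan.kan import KAN

model = KAN([28 * 28, 64, 10])
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params}")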

Necessity of the scale parameter?

Hi, I just want to share my experience from when I developed KAN: the scale parameter seemed quite important (but that was at the very beginning of the KAN project, so I could be hallucinating). Would love to hear your experimental results! Great initiative; would love to see a more efficient KAN implementation (with the good features maintained).

Questions

Hello! Thank you for providing this awesome implementation!

I am curious whether this repo also contains the 'grid extension' implementation included in the original KAN paper :)

please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library.

File "tests/test_simple_math.py", line 166, in curve2coeff
solution = torch.linalg.lstsq(  # solve the linear system via least squares
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/native/BatchLinearAlgebra.cpp":1462, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library.

What happened?

Contributing?

Hey, I have a fork, though nothing useful on it yet. (Mostly just profiling showing that forward passes are slow due to the B-splines, plus some comparisons to MLPs.)

Do you want folks to contribute to this?
Are you interested in making b-splines more efficient with something like: https://github.com/GistNoesis/FourierKAN/blob/main/fftKAN.py

Let me know what you think.

Symbolic expression support

Hello, thanks for your work. Are there plans in the near future to support fitting symbolic expressions / manual input and network visualizations, as the original implementation does?

Thanks in advance

How to pip install?

$ pip install efficient-kan
ERROR: Could not find a version that satisfies the requirement efficient-kan (from versions: none)
ERROR: No matching distribution found for efficient-kan
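The package does not appear to be published on PyPI, so pip cannot resolve it by name. Installing directly from the repository should work (assuming the standard GitHub URL for this project):

$ pip install git+https://github.com/Blealtan/efficient-kan.git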

KAN intuition for setting the hidden layers

Hello,

Is there a rule of thumb or intuition for setting the layers_hidden parameter? I'm using it for time series with [input_size, 10, horizon]. The 10 is arbitrary, taken from the MNIST example; do you have a suggestion on setting these for best performance?

Constant loss, network is not learning

I am trying to predict certain function coefficients (output: a, b) from its curve (input: frequency_response) with the help of a Kolmogorov-Arnold Network and your nice library.


Unfortunately my loss is constant and not improving at all. Any idea what I am doing wrong here? This problem was previously approached with an MLP, hence I am hoping KANs can provide a better solution. My code is the following:

import time
import torch
import numpy as np
from tqdm import tqdm
from torchaudio.functional import lfilter
from torch.optim import Adam, lr_scheduler

from efficient_kan.kan import KAN

# Set the device
hardware = "cpu"
device = torch.device(hardware)

    
class FilterNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_batches=1, num_biquads=1, num_layers=1, fs=44100):
        super(FilterNet, self).__init__()
        self.eps = 1e-8
        self.fs = fs
        self.dirac = self.get_dirac(fs, 0, grad=True)  # generate a dirac
        self.kan = KAN([input_size, hidden_size, output_size], grid_size=5, spline_order=3)

    def get_dirac(self, size, index=1, grad=False):
        tensor = torch.zeros(size, requires_grad=grad)
        tensor.data[index] = 1
        return tensor

    def compute_filter_magnitude_and_phase_frequency_response(self, dirac, fs, a, b):
        # filter it 
        filtered_dirac = lfilter(dirac, a, b) 
        freqs_response = torch.fft.fft(filtered_dirac)
        
        # compute the frequency axis (positive frequencies only)
        freqs_rad = torch.fft.rfftfreq(filtered_dirac.shape[-1])
        
        # keep only the positive freqs
        freqs_hz = freqs_rad[:filtered_dirac.shape[-1] // 2] * fs / np.pi
        freqs_response = freqs_response[:len(freqs_hz)]
        
        # magnitude response 
        mag_response_db = 20 * torch.log10(torch.abs(freqs_response))
        
        # Phase Response
        phase_response_rad = torch.angle(freqs_response)
        phase_response_deg = phase_response_rad * 180 / np.pi
        return freqs_hz, mag_response_db, phase_response_deg
    
    def zpk2ba(self, zpk):
        gain    = zpk[0]
        p0_real = zpk[1]
        p0_imag = zpk[2]
        q0_real = zpk[3]
        q0_imag = zpk[4]
        
        zero = torch.complex(q0_real, q0_imag)
        zero_abs = zero.abs()
        zero = ((1 - self.eps) * zero * torch.tanh(zero_abs)) / (zero_abs + self.eps)
                
        pole = torch.complex(p0_real, p0_imag)
        pole_abs = pole.abs()
        pole = ((1 - self.eps) * pole * torch.tanh(pole_abs)) / (pole_abs + self.eps)

        b0 = gain 
        b1 = gain * -2 * zero.real
        b2 = gain * ((zero.real ** 2) + (zero.imag ** 2))
        a0 = 1
        a1 = -2 * pole.real
        a2 = (pole.real ** 2) + (pole.imag ** 2)
        b = torch.tensor([b0, b1, b2], requires_grad=True)
        a = torch.tensor([a0, a1, a2], requires_grad=True)
        return b, a
    
    def forward(self, x):
        zpk = self.kan(x)
        #print("> Zpk: ", zpk)
        
        # extract filter coeffs
        self.a, self.b = self.zpk2ba(zpk)
        
        # get filter reponse 
        freqs_hz, mag_response_db, phase_response_deg = self.compute_filter_magnitude_and_phase_frequency_response(self.dirac, self.fs, self.a, self.b)
        frequency_response = torch.hstack((mag_response_db, phase_response_deg))
        return frequency_response

# Define the target filter variables
fs = 512               # nbr of input points
num_biquads = 1        # Number of biquad filters in the cascade
num_biquad_coeffs = 6  # Number of coefficients per biquad

# Init the optimizer 
n_epochs  = 500
model     = FilterNet(fs, fs*8, 5, 1, 1, 1, fs)
optimizer = Adam(model.parameters(), lr=1e-1, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# define filter coeffs (TARGET)
b = torch.tensor([0.803, -0.132, 0.731])
a = torch.tensor([1.000, -0.426, 0.850])

# compute filter response
freqs_hz, mag_response_db, phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, a, b)
target_frequency_response = torch.hstack((mag_response_db, phase_response_deg))

# Inits
start_time = time.time()    # Start timing the loop
pbar = tqdm(total=n_epochs) # Create a tqdm progress bar
loss_history = []

# Run training
for epoch in range(n_epochs):    
    model.train()
    device = next(model.parameters()).device
    
    target = target_frequency_response.to(device)
    optimizer.zero_grad()
    
    # Compute prediction and loss
    predicted_frequency_response = model(target)
    loss = torch.nn.functional.mse_loss(predicted_frequency_response, target_frequency_response)
    
    # Backpropagation
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())

    # Update the progress bar
    pbar.set_description(f"Epoch: {epoch}, Loss: {loss:.9f}")
    pbar.update(1)
    scheduler.step(loss)
        
# End timing the loop & print duration
elapsed_time = time.time() - start_time
print(f"\nOptimization loop took {elapsed_time:.2f} seconds.")

# Plot predicted filter
freqs_hz, predicted_mag_response_db, predicted_phase_response_deg = model.compute_filter_magnitude_and_phase_frequency_response(model.get_dirac(fs, 0, grad=False), fs, model.a.detach().cpu(), model.b.detach().cpu())
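For what it's worth, one likely culprit is in zpk2ba: torch.tensor([b0, b1, b2], requires_grad=True) builds a new leaf tensor from the values of b0, b1, b2, which disconnects them from the autograd graph, so no gradient ever reaches the KAN and the loss stays flat. Stacking preserves the graph (a sketch I haven't run end to end). Note also that zpk2ba returns (b, a) while the caller unpacks it as self.a, self.b, which swaps the coefficient vectors handed to lfilter.

# Inside zpk2ba, keep b0..b2 and a1, a2 attached to the autograd graph:
b = torch.stack([b0, b1, b2])
a = torch.stack([torch.ones_like(a1), a1, a2])  # a0 = 1 is a constant
return b, a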

time complexity

Hey, I want to use your implementation. Do you know how much slower training can be compared to nn.Linear?

Can configure efficient-kan model for continual learning?

Similar to what the authors showed in the official git repo, can this efficient-kan model be used in continual-learning settings? For using efficient-kan in CL settings, I haven't found the attributes that need to be set as given in the official pykan:

######### cl code from pykan

setting bias_trainable=False, sp_trainable=False, sb_trainable=False is important.

otherwise KAN will have random scaling and shift for samples in previous stages

model = KAN(width=[1,1], grid=200, k=3, noise_scale=0.1, bias_trainable=False, sp_trainable=False, sb_trainable=False)

How can I set bias_trainable=False, sp_trainable=False, sb_trainable=False here? Is there a way?
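efficient-kan does not expose those constructor flags, but the generic PyTorch route is to freeze the corresponding parameters by name after construction. A sketch, assuming the base_weight and spline_scaler parameter names mentioned in the README update above (how exactly they map onto pykan's sb/sp scales is my guess):

for name, param in model.named_parameters():
    if name.endswith("base_weight") or name.endswith("spline_scaler"):
        param.requires_grad = False  # excluded from optimizer updates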

Layer fails on inputs that have been transposed previously

Using your implementation on data that has been transposed previously causes a
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. error.
Just replacing

x = x.view(-1, self.in_features)
with a reshape fixed it. Maybe worth fixing if you want the layer to be a full drop-in for nn.Linear.
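For context, this is standard PyTorch behaviour: transposing makes a tensor non-contiguous, .view() requires compatible strides, and .reshape() falls back to a copy when needed. A tiny repro:

import torch

x = torch.randn(3, 4).t()  # transpose makes x non-contiguous
# x.view(-1, 3)            # would raise the RuntimeError quoted above
y = x.reshape(-1, 3)       # works: reshape copies if it has to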

RuntimeError: false INTERNAL ASSERT FAILED

Getting this error when running test_simple_math.py. Any idea how to resolve it?

  File "/home/chansingh/test/kan.py", line 131, in curve2coeff
    solution = torch.linalg.lstsq(
               ^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/native/BatchLinearAlgebra.cpp":1539, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library.

(pytorch is up-to-date, version '2.3.0+cu121', python 3.11)

Difference in memory usage of mlp and kan

Hello:
If the input tensor size is [64, 28x28] and the hidden layers are [256, 256, 256, 256], the memory usage of the MLP and the KAN is similar, 382 MB and 500 MB respectively, which is consistent with the experimental results.
However, if the input tensor size is [36864, 28x28], the memory usage of the two differs hugely, 844 MB versus 14468 MB. What is the reason for this? The KAN is initialized as given in the example, running on a GPU.

KAN has more learnable parameters?

I am confused about the principle of KAN. From this implementation, does KAN have more learnable parameters?
It seems that the improvement of KAN lies in the learnable activation functions, thus achieving better accuracy. Does KAN have any advantage in computation and memory?

Intel oneMKL ERROR: Parameter 6 was incorrect on entry to SGELSY.

D:\Users\12719\anaconda3\python.exe D:\Users\12719\PycharmProjects\efficient-kan\tests\test_simple_math.py
20%|██ | 20/100 [00:01<00:06, 12.66it/s, mse_loss=nan, reg_loss=nan]
Intel oneMKL ERROR: Parameter 6 was incorrect on entry to SGELSY.

Intel oneMKL ERROR: Parameter 6 was incorrect on entry to SGELSY.
20%|██ | 20/100 [00:02<00:08, 9.82it/s, mse_loss=nan, reg_loss=nan]
Traceback (most recent call last):
File "D:\Users\12719\PycharmProjects\efficient-kan\tests\test_simple_math.py", line 35, in
test_mul()
File "D:\Users\12719\PycharmProjects\efficient-kan\tests\test_simple_math.py", line 29, in test_mul
optimizer.step(closure)
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\optim\optimizer.py", line 459, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\utils_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\optim\lbfgs.py", line 320, in step
orig_loss = closure()
^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\utils_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\PycharmProjects\efficient-kan\tests\test_simple_math.py", line 18, in closure
y = kan(x, update_grid=(i % 20 == 0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1541, in call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\PycharmProjects\efficient-kan\src\efficient_kan\kan.py", line 272, in forward
layer.update_grid(x)
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\utils_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\PycharmProjects\efficient-kan\src\efficient_kan\kan.py", line 210, in update_grid
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
File "D:\Users\12719\PycharmProjects\efficient-kan\src\efficient_kan\kan.py", line 131, in curve2coeff
solution = torch.linalg.lstsq(
^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\BatchLinearAlgebra.cpp":1538, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library.

How to cite

If I would like to cite usage of this implementation, how would I do so?

Train Multi-KANs simultaneously! Coding realization

Hey, guys. I now have the following problem; the structure is:

y_pred = KAN_1(X1)*f1(X1) + KAN_2(X2)*f2(X2) + KAN_3(X3)*f3(X3), where f1, f2, f3 are all fixed, known functions. How can I train KAN_i, i = 1, 2, 3, simultaneously?

I am wondering how to realize this training process, as sketched below.
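Since each KAN_i is an ordinary nn.Module, wrapping all three in one module and handing a single optimizer every parameter trains them jointly. A sketch, assuming scalar inputs and scalar-valued f_i (the widths and example f_i are illustrative):

import torch
from efficient_kan.kan import KAN

class MultiKAN(torch.nn.Module):
    def __init__(self, f1, f2, f3):
        super().__init__()
        # Three independent KANs; ModuleList registers all their parameters.
        self.kans = torch.nn.ModuleList([KAN([1, 8, 1]) for _ in range(3)])
        self.fs = [f1, f2, f3]  # fixed, known functions: nothing to train

    def forward(self, x1, x2, x3):
        # y_pred = sum_i KAN_i(x_i) * f_i(x_i)
        return sum(kan(x) * f(x) for kan, f, x in zip(self.kans, self.fs, (x1, x2, x3)))

model = MultiKAN(f1=torch.sin, f2=torch.cos, f3=torch.exp)  # hypothetical f_i
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # one optimizer updates all three KANs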

Missing license

This repo should have a license to protect its owner and potential users.

ONNX

Can it be converted to ONNX, bro?
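I haven't verified the trace, but the forward pass is mostly B-spline evaluation and matrix multiplications, so the standard export path is worth a try (a sketch; the widths are arbitrary):

import torch
from efficient_kan.kan import KAN

model = KAN([28 * 28, 64, 10]).eval()
dummy = torch.randn(1, 28 * 28)
torch.onnx.export(model, dummy, "kan.onnx", input_names=["x"], output_names=["y"])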

I would like to add a train test function to the KAN class

I can't seem to open a branch for raising a pull request, so I'm adding the code here:

def train_model(self, model, trainloader, valloader, optimizer, scheduler, criterion, device, epochs):
    model.to(device)
    for epoch in range(epochs):
        # Train
        model.train()
        with tqdm(trainloader) as pbar:
            for i, (images, labels) in enumerate(pbar):
                images = images.view(-1, 28 * 28).to(device)
                optimizer.zero_grad()
                output = model(images)
                loss = criterion(output, labels.to(device))
                loss.backward()
                optimizer.step()
                accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
                pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

        # Validation
        model.eval()
        val_loss = 0
        val_accuracy = 0
        with torch.no_grad():
            for images, labels in valloader:
                images = images.view(-1, 28 * 28).to(device)
                output = model(images)
                val_loss += criterion(output, labels.to(device)).item()
                val_accuracy += (
                    (output.argmax(dim=1) == labels.to(device)).float().mean().item()
                )
        val_loss /= len(valloader)
        val_accuracy /= len(valloader)

        # Update learning rate
        scheduler.step()

        print(
            f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
        )

def test_model(self, model, testloader, device, num_samples=10):
    model.to(device)
    model.eval()
    predictions = []
    ground_truths = []
    images_to_show = []

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for i, (images, labels) in enumerate(testloader):
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            predictions.extend(output.argmax(dim=1).cpu().numpy())
            ground_truths.extend(labels.cpu().numpy())
            images_to_show.extend(images.view(-1, 28, 28).cpu().numpy())

            if len(predictions) >= num_samples:
                break

    # Print the predictions for the specified number of samples
    for i in range(num_samples):
        print(f"Ground Truth: {ground_truths[i]}, Prediction: {predictions[i]}")

    return predictions[:num_samples], ground_truths[:num_samples], images_to_show[:num_samples]

Something different from the official results for KAN.

I tried to reproduce the experiments (example 4 in the official KAN). With the official KAN, I get the results below (ground truth at the top, prediction at the bottom):
[images: Ground_task5, Pred_task4]

But with the efficient-kan, I get the results as below:
[images: Ground_task5, Pred_task4]

It shows that the previous peak becomes higher while learning a new peak.
The official model is created by: model = KAN(width=[1, 1], grid=200, k=3, noise_scale=0.1, bias_trainable=False, sp_trainable=False, sb_trainable=False)
The efficient-kan model is created by: model = KAN([1, 1], grid_size=200)
They seem to be the same except for bias_trainable=False, sp_trainable=False, sb_trainable=False.

Supports inputs more than 2 dimensions For efficient-KAN

I don't quite understand KAN's code. Is it possible for KANLinear to behave like torch.nn.Linear, operating only on the last dimension, so that inputs with more than 2 dimensions are allowed?
For example, in multi-head attention the input is shaped like [batch, nhead, dim].

However, this is not allowed in the current KAN ("assert x.dim() == 2 and x.size(1) == self.in_features").
Excuse me! I am very interested in exploring the application of KAN in attention.
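Until that's supported natively, a thin wrapper that flattens the leading dimensions and restores them afterwards behaves like nn.Linear over the last axis (a sketch around this repo's KANLinear; the constructor takes more optional arguments than shown, and I haven't tested it inside attention):

import torch
from efficient_kan.kan import KANLinear

class KANLinearND(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.inner = KANLinear(in_features, out_features)

    def forward(self, x):
        # Collapse all leading dims into one batch dim, apply, then restore.
        lead_shape = x.shape[:-1]
        y = self.inner(x.reshape(-1, x.shape[-1]))
        return y.reshape(*lead_shape, -1)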

AssertionError

I am pretty sure this is really just a dimensionality issue, but I'm trying to use KANLinear as a substitute for nn.Linear to try this approach out. I can use the tutorial to get it to work with MNIST just fine, but it doesn't work out of the box elsewhere, almost certainly because I am missing something.

I keep getting this error in forward:

assert x.dim() == 2 and x.size(1) == self.in_features
AssertionError

All I am doing is dropping KANLinear in for nn.Linear, keeping in_features and out_features at the same hidden size. Is there a way forward can be edited to allow non-image inputs?
